Commit 0df135ea authored by StyleZhang's avatar StyleZhang

Merge branch 'main' into feat/header-ssr

parents 4af95d70 ecd6cbae
...@@ -17,9 +17,15 @@ A single API encompassing plugin capabilities, context enhancement, and more, sa ...@@ -17,9 +17,15 @@ A single API encompassing plugin capabilities, context enhancement, and more, sa
Visual data analysis, log review, and annotation for applications Visual data analysis, log review, and annotation for applications
Dify is compatible with Langchain, meaning we'll gradually support multiple LLMs, currently supported: Dify is compatible with Langchain, meaning we'll gradually support multiple LLMs, currently supported:
- GPT 3 (text-davinci-003) * **OpenAI** :GPT4、GPT3.5-turbo、GPT3.5-turbo-16k、text-davinci-003
- GPT 3.5 Turbo(ChatGPT)
- GPT-4 * **Azure OpenAI**
* **Antropic**:Claude2、Claude-instant
> We've got 1000 free trial credits available for all cloud service users to try out the Claude model.Visit [Dify.ai](https://dify.ai) and
try it now.
* **hugging face Hub**:Coming soon.
## Use Cloud Services ## Use Cloud Services
......
...@@ -17,11 +17,16 @@ ...@@ -17,11 +17,16 @@
- 一套 API 即可包含插件、上下文增强等能力,替你省下了后端代码的编写工作 - 一套 API 即可包含插件、上下文增强等能力,替你省下了后端代码的编写工作
- 可视化的对应用进行数据分析,查阅日志或进行标注 - 可视化的对应用进行数据分析,查阅日志或进行标注
Dify 兼容 Langchain,这意味着我们将逐步支持多种 LLMs ,目前已支持 Dify 兼容 Langchain,这意味着我们将逐步支持多种 LLMs ,目前支持的模型供应商
- GPT 3 (text-davinci-003) * **OpenAI**:GPT4、GPT3.5-turbo、GPT3.5-turbo-16k、text-davinci-003
- GPT 3.5 Turbo(ChatGPT)
- GPT-4 * **Azure OpenAI Service**
* **Anthropic**:Claude2、Claude-instant
> 我们为所有注册云端版的用户免费提供了 1000 次 Claude 模型的消息调用额度,登录 [dify.ai](https://cloud.dify.ai) 即可使用。
* **Hugging Face Hub**(即将推出)
## 使用云服务 ## 使用云服务
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
import os import os
from datetime import datetime from datetime import datetime
from werkzeug.exceptions import Forbidden
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true': if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey from gevent import monkey
monkey.patch_all() monkey.patch_all()
...@@ -27,7 +29,7 @@ from events import event_handlers ...@@ -27,7 +29,7 @@ from events import event_handlers
import core import core
from config import Config, CloudEditionConfig from config import Config, CloudEditionConfig
from commands import register_commands from commands import register_commands
from models.account import TenantAccountJoin from models.account import TenantAccountJoin, AccountStatus
from models.model import Account, EndUser, App from models.model import Account, EndUser, App
import warnings import warnings
...@@ -101,6 +103,9 @@ def load_user(user_id): ...@@ -101,6 +103,9 @@ def load_user(user_id):
account = db.session.query(Account).filter(Account.id == account_id).first() account = db.session.query(Account).filter(Account.id == account_id).first()
if account: if account:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
raise Forbidden('Account is banned or closed.')
workspace_id = session.get('workspace_id') workspace_id = session.get('workspace_id')
if workspace_id: if workspace_id:
tenant_account_join = db.session.query(TenantAccountJoin).filter( tenant_account_join = db.session.query(TenantAccountJoin).filter(
......
...@@ -18,7 +18,8 @@ from models.model import Account ...@@ -18,7 +18,8 @@ from models.model import Account
import secrets import secrets
import base64 import base64
from models.provider import Provider from models.provider import Provider, ProviderName
from services.provider_service import ProviderService
@click.command('reset-password', help='Reset the account password.') @click.command('reset-password', help='Reset the account password.')
...@@ -193,9 +194,40 @@ def recreate_all_dataset_indexes(): ...@@ -193,9 +194,40 @@ def recreate_all_dataset_indexes():
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green')) click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
def sync_anthropic_hosted_providers():
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
count = 0
page = 1
while True:
try:
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for tenant in tenants:
try:
click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)
count += 1
except Exception as e:
click.echo(click.style('Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
def register_commands(app): def register_commands(app):
app.cli.add_command(reset_password) app.cli.add_command(reset_password)
app.cli.add_command(reset_email) app.cli.add_command(reset_email)
app.cli.add_command(generate_invitation_codes) app.cli.add_command(generate_invitation_codes)
app.cli.add_command(reset_encrypt_key_pair) app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(recreate_all_dataset_indexes) app.cli.add_command(recreate_all_dataset_indexes)
app.cli.add_command(sync_anthropic_hosted_providers)
...@@ -50,7 +50,10 @@ DEFAULTS = { ...@@ -50,7 +50,10 @@ DEFAULTS = {
'PDF_PREVIEW': 'True', 'PDF_PREVIEW': 'True',
'LOG_LEVEL': 'INFO', 'LOG_LEVEL': 'INFO',
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False', 'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
'DEFAULT_LLM_PROVIDER': 'openai' 'DEFAULT_LLM_PROVIDER': 'openai',
'OPENAI_HOSTED_QUOTA_LIMIT': 200,
'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
'TENANT_DOCUMENT_COUNT': 100
} }
...@@ -86,7 +89,7 @@ class Config: ...@@ -86,7 +89,7 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL') self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL') self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL') self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.8" self.CURRENT_VERSION = "0.3.9"
self.COMMIT_SHA = get_env('COMMIT_SHA') self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED" self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV') self.DEPLOY_ENV = get_env('DEPLOY_ENV')
...@@ -191,6 +194,10 @@ class Config: ...@@ -191,6 +194,10 @@ class Config:
# hosted provider credentials # hosted provider credentials
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY') self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
# By default it is False # By default it is False
# You could disable it for compatibility with certain OpenAPI providers # You could disable it for compatibility with certain OpenAPI providers
...@@ -207,6 +214,8 @@ class Config: ...@@ -207,6 +214,8 @@ class Config:
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET') self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN') self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
class CloudEditionConfig(Config): class CloudEditionConfig(Config):
......
...@@ -50,8 +50,8 @@ class ChatMessageAudioApi(Resource): ...@@ -50,8 +50,8 @@ class ChatMessageAudioApi(Resource):
raise UnsupportedAudioTypeError() raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError: except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError() raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource): ...@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -133,8 +133,8 @@ class ChatMessageApi(Resource): ...@@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response: ...@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException): ...@@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException):
class ProviderQuotaExceededError(BaseHTTPException): class ProviderQuotaExceededError(BaseHTTPException):
error_code = 'provider_quota_exceeded' error_code = 'provider_quota_exceeded'
description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
"Please go to Settings -> Model Provider to complete your own provider credentials." "Please go to Settings -> Model Provider to complete your own provider credentials."
code = 400 code = 400
......
...@@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource): ...@@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource):
account.current_tenant_id, account.current_tenant_id,
args['prompt_template'] args['prompt_template']
) )
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -58,8 +58,8 @@ class RuleGenerateApi(Resource): ...@@ -58,8 +58,8 @@ class RuleGenerateApi(Resource):
args['audiences'], args['audiences'],
args['hoping_to_solve'] args['hoping_to_solve']
) )
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource): ...@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError: except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError() raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response: ...@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError: except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource): ...@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
raise NotFound("Message not found") raise NotFound("Message not found")
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation not found") raise NotFound("Conversation not found")
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -279,8 +279,8 @@ class DatasetDocumentListApi(Resource): ...@@ -279,8 +279,8 @@ class DatasetDocumentListApi(Resource):
try: try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user) documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -324,8 +324,8 @@ class DatasetInitApi(Resource): ...@@ -324,8 +324,8 @@ class DatasetInitApi(Resource):
document_data=args, document_data=args,
account=current_user account=current_user
) )
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -95,8 +95,8 @@ class HitTestingApi(Resource): ...@@ -95,8 +95,8 @@ class HitTestingApi(Resource):
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)} return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError: except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError() raise DatasetNotInitializedError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -47,8 +47,8 @@ class ChatAudioApi(InstalledAppResource): ...@@ -47,8 +47,8 @@ class ChatAudioApi(InstalledAppResource):
raise UnsupportedAudioTypeError() raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError: except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError() raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource): ...@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource): ...@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response: ...@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource): ...@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError: except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError() raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response: ...@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError: except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource): ...@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
raise NotFound("Conversation not found") raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError: except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError() raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -3,6 +3,7 @@ import base64 ...@@ -3,6 +3,7 @@ import base64
import json import json
import logging import logging
from flask import current_app
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, abort from flask_restful import Resource, reqparse, abort
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
...@@ -34,7 +35,7 @@ class ProviderListApi(Resource): ...@@ -34,7 +35,7 @@ class ProviderListApi(Resource):
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
""" """
ProviderService.init_supported_provider(current_user.current_tenant, "cloud") ProviderService.init_supported_provider(current_user.current_tenant)
providers = Provider.query.filter_by(tenant_id=tenant_id).all() providers = Provider.query.filter_by(tenant_id=tenant_id).all()
provider_list = [ provider_list = [
...@@ -50,7 +51,8 @@ class ProviderListApi(Resource): ...@@ -50,7 +51,8 @@ class ProviderListApi(Resource):
'quota_used': p.quota_used 'quota_used': p.quota_used
} if p.provider_type == ProviderType.SYSTEM.value else {}), } if p.provider_type == ProviderType.SYSTEM.value else {}),
'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant, 'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
ProviderName(p.provider_name)) ProviderName(p.provider_name), only_custom=True)
if p.provider_type == ProviderType.CUSTOM.value else None
} }
for p in providers for p in providers
] ]
...@@ -121,9 +123,10 @@ class ProviderTokenApi(Resource): ...@@ -121,9 +123,10 @@ class ProviderTokenApi(Resource):
is_valid=token_is_valid) is_valid=token_is_valid)
db.session.add(provider_model) db.session.add(provider_model)
if provider_model.is_valid: if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
other_providers = db.session.query(Provider).filter( other_providers = db.session.query(Provider).filter(
Provider.tenant_id == tenant.id, Provider.tenant_id == tenant.id,
Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
Provider.provider_name != provider, Provider.provider_name != provider,
Provider.provider_type == ProviderType.CUSTOM.value Provider.provider_type == ProviderType.CUSTOM.value
).all() ).all()
...@@ -133,7 +136,7 @@ class ProviderTokenApi(Resource): ...@@ -133,7 +136,7 @@ class ProviderTokenApi(Resource):
db.session.commit() db.session.commit()
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value, if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]: ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201 return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
...@@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource): ...@@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource):
args = parser.parse_args() args = parser.parse_args()
# todo: remove this when the provider is supported # todo: remove this when the provider is supported
if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value, if provider in [ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]: ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'} return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
...@@ -203,7 +206,19 @@ class ProviderSystemApi(Resource): ...@@ -203,7 +206,19 @@ class ProviderSystemApi(Resource):
provider_model.is_valid = args['is_enabled'] provider_model.is_valid = args['is_enabled']
db.session.commit() db.session.commit()
elif not provider_model: elif not provider_model:
ProviderService.create_system_provider(tenant, provider, args['is_enabled']) if provider == ProviderName.OPENAI.value:
quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
elif provider == ProviderName.ANTHROPIC.value:
quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
else:
quota_limit = 0
ProviderService.create_system_provider(
tenant,
provider,
quota_limit,
args['is_enabled']
)
else: else:
abort(403) abort(403)
......
...@@ -43,8 +43,8 @@ class AudioApi(AppApiResource): ...@@ -43,8 +43,8 @@ class AudioApi(AppApiResource):
raise UnsupportedAudioTypeError() raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError: except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError() raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource): ...@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -115,8 +115,8 @@ class ChatApi(AppApiResource): ...@@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response: ...@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource): ...@@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource):
dataset_process_rule=dataset.latest_process_rule, dataset_process_rule=dataset.latest_process_rule,
created_from='api' created_from='api'
) )
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
document = documents[0] document = documents[0]
if doc_type and doc_metadata: if doc_type and doc_metadata:
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
......
...@@ -45,8 +45,8 @@ class AudioApi(WebApiResource): ...@@ -45,8 +45,8 @@ class AudioApi(WebApiResource):
raise UnsupportedAudioTypeError() raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError: except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError() raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource): ...@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -109,8 +109,8 @@ class ChatApi(WebApiResource): ...@@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response: ...@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource): ...@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError: except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError() raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response: ...@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError: except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
...@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource): ...@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
raise NotFound("Conversation not found") raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError: except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError() raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
......
...@@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel): ...@@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel):
api_key: str api_key: str
class HostedAnthropicCredential(BaseModel):
api_key: str
class HostedLLMCredentials(BaseModel): class HostedLLMCredentials(BaseModel):
openai: Optional[HostedOpenAICredential] = None openai: Optional[HostedOpenAICredential] = None
anthropic: Optional[HostedAnthropicCredential] = None
hosted_llm_credentials = HostedLLMCredentials() hosted_llm_credentials = HostedLLMCredentials()
...@@ -26,3 +31,6 @@ def init_app(app: Flask): ...@@ -26,3 +31,6 @@ def init_app(app: Flask):
if app.config.get("OPENAI_API_KEY"): if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY")) hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
if app.config.get("ANTHROPIC_API_KEY"):
hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
...@@ -48,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler): ...@@ -48,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
}) })
self.llm_message.prompt = real_prompts self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0]) self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
......
...@@ -118,6 +118,7 @@ class Completion: ...@@ -118,6 +118,7 @@ class Completion:
prompt, stop_words = cls.get_main_llm_prompt( prompt, stop_words = cls.get_main_llm_prompt(
mode=mode, mode=mode,
llm=final_llm, llm=final_llm,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt, pre_prompt=app_model_config.pre_prompt,
query=query, query=query,
inputs=inputs, inputs=inputs,
...@@ -129,6 +130,7 @@ class Completion: ...@@ -129,6 +130,7 @@ class Completion:
cls.recale_llm_max_tokens( cls.recale_llm_max_tokens(
final_llm=final_llm, final_llm=final_llm,
model=app_model_config.model_dict,
prompt=prompt, prompt=prompt,
mode=mode mode=mode
) )
...@@ -138,7 +140,8 @@ class Completion: ...@@ -138,7 +140,8 @@ class Completion:
return response return response
@classmethod @classmethod
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
pre_prompt: str, query: str, inputs: dict,
chain_output: Optional[str], chain_output: Optional[str],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \ memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]: Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
...@@ -151,10 +154,11 @@ class Completion: ...@@ -151,10 +154,11 @@ class Completion:
if mode == 'completion': if mode == 'completion':
prompt_template = JinjaPromptTemplate.from_template( prompt_template = JinjaPromptTemplate.from_template(
template=("""Use the following CONTEXT as your learned knowledge: template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
[CONTEXT]
<context>
{{context}} {{context}}
[END CONTEXT] </context>
When answer to user: When answer to user:
- If you don't know, just say that you don't know. - If you don't know, just say that you don't know.
...@@ -204,10 +208,11 @@ And answer according to the language of the user's question. ...@@ -204,10 +208,11 @@ And answer according to the language of the user's question.
if chain_output: if chain_output:
human_inputs['context'] = chain_output human_inputs['context'] = chain_output
human_message_prompt += """Use the following CONTEXT as your learned knowledge. human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
[CONTEXT]
<context>
{{context}} {{context}}
[END CONTEXT] </context>
When answer to user: When answer to user:
- If you don't know, just say that you don't know. - If you don't know, just say that you don't know.
...@@ -219,7 +224,7 @@ And answer according to the language of the user's question. ...@@ -219,7 +224,7 @@ And answer according to the language of the user's question.
if pre_prompt: if pre_prompt:
human_message_prompt += pre_prompt human_message_prompt += pre_prompt
query_prompt = "\nHuman: {{query}}\nAI: " query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
if memory: if memory:
# append chat histories # append chat histories
...@@ -228,9 +233,11 @@ And answer according to the language of the user's question. ...@@ -228,9 +233,11 @@ And answer according to the language of the user's question.
inputs=human_inputs inputs=human_inputs
) )
curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message]) curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \ model_name = model['name']
- memory.llm.max_tokens - curr_message_tokens max_tokens = model.get("completion_params").get('max_tokens')
rest_tokens = llm_constant.max_context_token_length[model_name] \
- max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0) rest_tokens = max(rest_tokens, 0)
histories = cls.get_history_messages_from_memory(memory, rest_tokens) histories = cls.get_history_messages_from_memory(memory, rest_tokens)
...@@ -241,7 +248,10 @@ And answer according to the language of the user's question. ...@@ -241,7 +248,10 @@ And answer according to the language of the user's question.
# if histories_param not in human_inputs: # if histories_param not in human_inputs:
# human_inputs[histories_param] = '{{' + histories_param + '}}' # human_inputs[histories_param] = '{{' + histories_param + '}}'
human_message_prompt += "\n\n" + histories human_message_prompt += "\n\n" if human_message_prompt else ""
human_message_prompt += "Here is the chat histories between human and assistant, " \
"inside <histories></histories> XML tags.\n\n<histories>"
human_message_prompt += histories + "</histories>"
human_message_prompt += query_prompt human_message_prompt += query_prompt
...@@ -307,13 +317,15 @@ And answer according to the language of the user's question. ...@@ -307,13 +317,15 @@ And answer according to the language of the user's question.
model=app_model_config.model_dict model=app_model_config.model_dict
) )
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name] model_name = app_model_config.model_dict.get("name")
max_tokens = llm.max_tokens model_limited_tokens = llm_constant.max_context_token_length[model_name]
max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
# get prompt without memory and context # get prompt without memory and context
prompt, _ = cls.get_main_llm_prompt( prompt, _ = cls.get_main_llm_prompt(
mode=mode, mode=mode,
llm=llm, llm=llm,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt, pre_prompt=app_model_config.pre_prompt,
query=query, query=query,
inputs=inputs, inputs=inputs,
...@@ -332,16 +344,17 @@ And answer according to the language of the user's question. ...@@ -332,16 +344,17 @@ And answer according to the language of the user's question.
return rest_tokens return rest_tokens
@classmethod @classmethod
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI], def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
prompt: Union[str, List[BaseMessage]], mode: str): prompt: Union[str, List[BaseMessage]], mode: str):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name] model_name = model.get("name")
max_tokens = final_llm.max_tokens model_limited_tokens = llm_constant.max_context_token_length[model_name]
max_tokens = model.get("completion_params").get('max_tokens')
if mode == 'completion' and isinstance(final_llm, BaseLLM): if mode == 'completion' and isinstance(final_llm, BaseLLM):
prompt_tokens = final_llm.get_num_tokens(prompt) prompt_tokens = final_llm.get_num_tokens(prompt)
else: else:
prompt_tokens = final_llm.get_messages_tokens(prompt) prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
if prompt_tokens + max_tokens > model_limited_tokens: if prompt_tokens + max_tokens > model_limited_tokens:
max_tokens = max(model_limited_tokens - prompt_tokens, 16) max_tokens = max(model_limited_tokens - prompt_tokens, 16)
...@@ -350,9 +363,10 @@ And answer according to the language of the user's question. ...@@ -350,9 +363,10 @@ And answer according to the language of the user's question.
@classmethod @classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str, def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool): app_model_config: AppModelConfig, user: Account, streaming: bool):
llm: StreamableOpenAI = LLMBuilder.to_llm(
llm = LLMBuilder.to_llm_from_model(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
model_name='gpt-3.5-turbo', model=app_model_config.model_dict,
streaming=streaming streaming=streaming
) )
...@@ -360,6 +374,7 @@ And answer according to the language of the user's question. ...@@ -360,6 +374,7 @@ And answer according to the language of the user's question.
original_prompt, _ = cls.get_main_llm_prompt( original_prompt, _ = cls.get_main_llm_prompt(
mode="completion", mode="completion",
llm=llm, llm=llm,
model=app_model_config.model_dict,
pre_prompt=pre_prompt, pre_prompt=pre_prompt,
query=message.query, query=message.query,
inputs=message.inputs, inputs=message.inputs,
...@@ -390,6 +405,7 @@ And answer according to the language of the user's question. ...@@ -390,6 +405,7 @@ And answer according to the language of the user's question.
cls.recale_llm_max_tokens( cls.recale_llm_max_tokens(
final_llm=llm, final_llm=llm,
model=app_model_config.model_dict,
prompt=prompt, prompt=prompt,
mode='completion' mode='completion'
) )
......
from _decimal import Decimal from _decimal import Decimal
models = { models = {
'claude-instant-1': 'anthropic', # 100,000 tokens
'claude-2': 'anthropic', # 100,000 tokens
'gpt-4': 'openai', # 8,192 tokens 'gpt-4': 'openai', # 8,192 tokens
'gpt-4-32k': 'openai', # 32,768 tokens 'gpt-4-32k': 'openai', # 32,768 tokens
'gpt-3.5-turbo': 'openai', # 4,096 tokens 'gpt-3.5-turbo': 'openai', # 4,096 tokens
...@@ -10,10 +12,13 @@ models = { ...@@ -10,10 +12,13 @@ models = {
'text-curie-001': 'openai', # 2,049 tokens 'text-curie-001': 'openai', # 2,049 tokens
'text-babbage-001': 'openai', # 2,049 tokens 'text-babbage-001': 'openai', # 2,049 tokens
'text-ada-001': 'openai', # 2,049 tokens 'text-ada-001': 'openai', # 2,049 tokens
'text-embedding-ada-002': 'openai' # 8191 tokens, 1536 dimensions 'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions
'whisper-1': 'openai'
} }
max_context_token_length = { max_context_token_length = {
'claude-instant-1': 100000,
'claude-2': 100000,
'gpt-4': 8192, 'gpt-4': 8192,
'gpt-4-32k': 32768, 'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096, 'gpt-3.5-turbo': 4096,
...@@ -23,17 +28,21 @@ max_context_token_length = { ...@@ -23,17 +28,21 @@ max_context_token_length = {
'text-curie-001': 2049, 'text-curie-001': 2049,
'text-babbage-001': 2049, 'text-babbage-001': 2049,
'text-ada-001': 2049, 'text-ada-001': 2049,
'text-embedding-ada-002': 8191 'text-embedding-ada-002': 8191,
} }
models_by_mode = { models_by_mode = {
'chat': [ 'chat': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens 'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens 'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens 'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens 'gpt-3.5-turbo-16k', # 16,384 tokens
], ],
'completion': [ 'completion': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens 'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens 'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens 'gpt-3.5-turbo', # 4,096 tokens
...@@ -52,6 +61,14 @@ models_by_mode = { ...@@ -52,6 +61,14 @@ models_by_mode = {
model_currency = 'USD' model_currency = 'USD'
model_prices = { model_prices = {
'claude-instant-1': {
'prompt': Decimal('0.00163'),
'completion': Decimal('0.00551'),
},
'claude-2': {
'prompt': Decimal('0.01102'),
'completion': Decimal('0.03268'),
},
'gpt-4': { 'gpt-4': {
'prompt': Decimal('0.03'), 'prompt': Decimal('0.03'),
'completion': Decimal('0.06'), 'completion': Decimal('0.06'),
......
...@@ -56,7 +56,7 @@ class ConversationMessageTask: ...@@ -56,7 +56,7 @@ class ConversationMessageTask:
) )
def init(self): def init(self):
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id) provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
self.model_dict['provider'] = provider_name self.model_dict['provider'] = provider_name
override_model_configs = None override_model_configs = None
...@@ -89,7 +89,7 @@ class ConversationMessageTask: ...@@ -89,7 +89,7 @@ class ConversationMessageTask:
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs) system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
system_instruction = system_message.content system_instruction = system_message.content
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name) llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
system_instruction_tokens = llm.get_messages_tokens([system_message]) system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
if not self.conversation: if not self.conversation:
self.is_new_conversation = True self.is_new_conversation = True
...@@ -185,6 +185,7 @@ class ConversationMessageTask: ...@@ -185,6 +185,7 @@ class ConversationMessageTask:
if provider and provider.provider_type == ProviderType.SYSTEM.value: if provider and provider.provider_type == ProviderType.SYSTEM.value:
db.session.query(Provider).filter( db.session.query(Provider).filter(
Provider.tenant_id == self.app.tenant_id, Provider.tenant_id == self.app.tenant_id,
Provider.provider_name == provider.provider_name,
Provider.quota_limit > Provider.quota_used Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + 1}) ).update({'quota_used': Provider.quota_used + 1})
......
...@@ -4,6 +4,7 @@ from typing import List ...@@ -4,6 +4,7 @@ from typing import List
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from extensions.ext_database import db from extensions.ext_database import db
from libs import helper from libs import helper
from models.dataset import Embedding from models.dataset import Embedding
...@@ -49,6 +50,7 @@ class CacheEmbedding(Embeddings): ...@@ -49,6 +50,7 @@ class CacheEmbedding(Embeddings):
text_embeddings.extend(embedding_results) text_embeddings.extend(embedding_results)
return text_embeddings return text_embeddings
@handle_openai_exceptions
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Embed query text.""" """Embed query text."""
# use doc embedding cache or store if not exists # use doc embedding cache or store if not exists
......
...@@ -23,6 +23,10 @@ class LLMGenerator: ...@@ -23,6 +23,10 @@ class LLMGenerator:
@classmethod @classmethod
def generate_conversation_name(cls, tenant_id: str, query, answer): def generate_conversation_name(cls, tenant_id: str, query, answer):
prompt = CONVERSATION_TITLE_PROMPT prompt = CONVERSATION_TITLE_PROMPT
if len(query) > 2000:
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
prompt = prompt.format(query=query) prompt = prompt.format(query=query)
llm: StreamableOpenAI = LLMBuilder.to_llm( llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id, tenant_id=tenant_id,
...@@ -52,7 +56,17 @@ class LLMGenerator: ...@@ -52,7 +56,17 @@ class LLMGenerator:
if not message.answer: if not message.answer:
continue continue
message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n" if len(message.query) > 2000:
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
else:
query = message.query
if len(message.answer) > 2000:
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
else:
answer = message.answer
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0: if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
context += message_qa_text context += message_qa_text
......
...@@ -17,7 +17,7 @@ class IndexBuilder: ...@@ -17,7 +17,7 @@ class IndexBuilder:
model_credentials = LLMBuilder.get_model_credentials( model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id), model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002' model_name='text-embedding-ada-002'
) )
......
...@@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception): ...@@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception):
""" """
description = "Provider Token Not Init" description = "Provider Token Not Init"
def __init__(self, *args, **kwargs):
self.description = args[0] if args else self.description
class QuotaExceededError(Exception): class QuotaExceededError(Exception):
""" """
......
...@@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider ...@@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI from core.llm.streamable_open_ai import StreamableOpenAI
from models.provider import ProviderType from models.provider import ProviderType, ProviderName
class LLMBuilder: class LLMBuilder:
...@@ -32,43 +33,43 @@ class LLMBuilder: ...@@ -32,43 +33,43 @@ class LLMBuilder:
@classmethod @classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]: def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
provider = cls.get_default_provider(tenant_id) provider = cls.get_default_provider(tenant_id, model_name)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name) model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
llm_cls = None
mode = cls.get_mode_by_model(model_name) mode = cls.get_mode_by_model(model_name)
if mode == 'chat': if mode == 'chat':
if provider == 'openai': if provider == ProviderName.OPENAI.value:
llm_cls = StreamableChatOpenAI llm_cls = StreamableChatOpenAI
else: elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureChatOpenAI llm_cls = StreamableAzureChatOpenAI
elif provider == ProviderName.ANTHROPIC.value:
llm_cls = StreamableChatAnthropic
elif mode == 'completion': elif mode == 'completion':
if provider == 'openai': if provider == ProviderName.OPENAI.value:
llm_cls = StreamableOpenAI llm_cls = StreamableOpenAI
else: elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureOpenAI llm_cls = StreamableAzureOpenAI
else:
raise ValueError(f"model name {model_name} is not supported.")
if not llm_cls:
raise ValueError(f"model name {model_name} is not supported.")
model_kwargs = { model_kwargs = {
'model_name': model_name,
'temperature': kwargs.get('temperature', 0),
'max_tokens': kwargs.get('max_tokens', 256),
'top_p': kwargs.get('top_p', 1), 'top_p': kwargs.get('top_p', 1),
'frequency_penalty': kwargs.get('frequency_penalty', 0), 'frequency_penalty': kwargs.get('frequency_penalty', 0),
'presence_penalty': kwargs.get('presence_penalty', 0), 'presence_penalty': kwargs.get('presence_penalty', 0),
'callbacks': kwargs.get('callbacks', None),
'streaming': kwargs.get('streaming', False),
} }
model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs} model_kwargs.update(model_credentials)
model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
return llm_cls( return llm_cls(**model_kwargs)
model_name=model_name,
temperature=kwargs.get('temperature', 0),
max_tokens=kwargs.get('max_tokens', 256),
**model_extras_kwargs,
callbacks=kwargs.get('callbacks', None),
streaming=kwargs.get('streaming', False),
# request_timeout=None
**model_credentials
)
@classmethod @classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False, def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
...@@ -118,14 +119,29 @@ class LLMBuilder: ...@@ -118,14 +119,29 @@ class LLMBuilder:
return provider_service.get_credentials(model_name) return provider_service.get_credentials(model_name)
@classmethod @classmethod
def get_default_provider(cls, tenant_id: str) -> str: def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
provider = BaseProvider.get_valid_provider(tenant_id) provider_name = llm_constant.models[model_name]
if not provider:
raise ProviderTokenNotInitError() if provider_name == 'openai':
# get the default provider (openai / azure_openai) for the tenant
if provider.provider_type == ProviderType.SYSTEM.value: openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
provider_name = 'openai' azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
else:
provider_name = provider.provider_name provider = None
if openai_provider:
provider = openai_provider
elif azure_openai_provider:
provider = azure_openai_provider
if not provider:
raise ProviderTokenNotInitError(
f"No valid {provider_name} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if provider.provider_type == ProviderType.SYSTEM.value:
provider_name = 'openai'
else:
provider_name = provider.provider_name
return provider_name return provider_name
from typing import Optional import json
import logging
from typing import Optional, Union
import anthropic
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider from core.llm.provider.base import BaseProvider
from models.provider import ProviderName from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName, ProviderType
class AnthropicProvider(BaseProvider): class AnthropicProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]: def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id) return [
# todo {
return [] 'id': 'claude-instant-1',
'name': 'claude-instant-1',
},
{
'id': 'claude-2',
'name': 'claude-2',
},
]
def get_credentials(self, model_id: Optional[str] = None) -> dict: def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.get_provider_api_key(model_id=model_id)
def get_provider_name(self):
return ProviderName.ANTHROPIC
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
""" """
Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id. Returns the provider configs.
The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
""" """
return { try:
'anthropic_api_key': self.get_provider_api_key(model_id=model_id) config = self.get_provider_api_key(only_custom=only_custom)
} except:
config = {
'anthropic_api_key': ''
}
def get_provider_name(self): if obfuscated:
return ProviderName.ANTHROPIC if not config.get('anthropic_api_key'):
\ No newline at end of file config = {
'anthropic_api_key': ''
}
config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
return config
return config
def get_encrypted_token(self, config: Union[dict | str]):
"""
Returns the encrypted token.
"""
return json.dumps({
'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
})
def get_decrypted_token(self, token: str):
"""
Returns the decrypted token.
"""
config = json.loads(token)
config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
return config
def get_token_type(self):
return dict
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
# check OpenAI / Azure OpenAI credential is valid
openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
provider = None
if openai_provider:
provider = openai_provider
elif azure_openai_provider:
provider = azure_openai_provider
if not provider:
raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
if quota_used >= quota_limit:
raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
f"please configure OpenAI or Azure OpenAI provider first.")
try:
if not isinstance(config, dict):
raise ValueError('Config must be a object.')
if 'anthropic_api_key' not in config:
raise ValueError('anthropic_api_key must be provided.')
chat_llm = ChatAnthropic(
model='claude-instant-1',
anthropic_api_key=config['anthropic_api_key'],
max_tokens_to_sample=10,
temperature=0,
default_request_timeout=60
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except anthropic.APIConnectionError as ex:
raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
f"{ex.body['error']['type']}: {ex.body['error']['message']}")
except Exception as ex:
logging.exception('Anthropic config validation failed')
raise ex
def get_hosted_credentials(self) -> Union[str | dict]:
if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}
...@@ -52,12 +52,12 @@ class AzureProvider(BaseProvider): ...@@ -52,12 +52,12 @@ class AzureProvider(BaseProvider):
def get_provider_name(self): def get_provider_name(self):
return ProviderName.AZURE_OPENAI return ProviderName.AZURE_OPENAI
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
""" """
Returns the provider configs. Returns the provider configs.
""" """
try: try:
config = self.get_provider_api_key() config = self.get_provider_api_key(only_custom=only_custom)
except: except:
config = { config = {
'openai_api_type': 'azure', 'openai_api_type': 'azure',
...@@ -81,7 +81,6 @@ class AzureProvider(BaseProvider): ...@@ -81,7 +81,6 @@ class AzureProvider(BaseProvider):
return config return config
def get_token_type(self): def get_token_type(self):
# TODO: change to dict when implemented
return dict return dict
def config_validate(self, config: Union[dict | str]): def config_validate(self, config: Union[dict | str]):
......
...@@ -2,7 +2,7 @@ import base64 ...@@ -2,7 +2,7 @@ import base64
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional, Union from typing import Optional, Union
from core import hosted_llm_credentials from core.constant import llm_constant
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from extensions.ext_database import db from extensions.ext_database import db
from libs import rsa from libs import rsa
...@@ -14,15 +14,18 @@ class BaseProvider(ABC): ...@@ -14,15 +14,18 @@ class BaseProvider(ABC):
def __init__(self, tenant_id: str): def __init__(self, tenant_id: str):
self.tenant_id = tenant_id self.tenant_id = tenant_id
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]: def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
""" """
Returns the decrypted API key for the given tenant_id and provider_name. Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError. If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
If the provider is not found or not valid, raises a ProviderTokenNotInitError. If the provider is not found or not valid, raises a ProviderTokenNotInitError.
""" """
provider = self.get_provider(prefer_custom) provider = self.get_provider(only_custom)
if not provider: if not provider:
raise ProviderTokenNotInitError() raise ProviderTokenNotInitError(
f"No valid {llm_constant.models[model_id]} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if provider.provider_type == ProviderType.SYSTEM.value: if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0 quota_used = provider.quota_used if provider.quota_used is not None else 0
...@@ -38,18 +41,19 @@ class BaseProvider(ABC): ...@@ -38,18 +41,19 @@ class BaseProvider(ABC):
else: else:
return self.get_decrypted_token(provider.encrypted_config) return self.get_decrypted_token(provider.encrypted_config)
def get_provider(self, prefer_custom: bool) -> Optional[Provider]: def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
""" """
Returns the Provider instance for the given tenant_id and provider_name. Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
""" """
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom) return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
@classmethod @classmethod
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]: def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
Provider]:
""" """
Returns the Provider instance for the given tenant_id and provider_name. Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. If both CUSTOM and System providers exist.
""" """
query = db.session.query(Provider).filter( query = db.session.query(Provider).filter(
Provider.tenant_id == tenant_id Provider.tenant_id == tenant_id
...@@ -58,39 +62,31 @@ class BaseProvider(ABC): ...@@ -58,39 +62,31 @@ class BaseProvider(ABC):
if provider_name: if provider_name:
query = query.filter(Provider.provider_name == provider_name) query = query.filter(Provider.provider_name == provider_name)
providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all() if only_custom:
query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
custom_provider = None providers = query.order_by(Provider.provider_type.asc()).all()
system_provider = None
for provider in providers: for provider in providers:
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config: if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
custom_provider = provider return provider
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid: elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
system_provider = provider return provider
if custom_provider:
return custom_provider
elif system_provider:
return system_provider
else:
return None
def get_hosted_credentials(self) -> str: return None
if self.get_provider_name() != ProviderName.OPENAI:
raise ProviderTokenNotInitError()
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key: def get_hosted_credentials(self) -> Union[str | dict]:
raise ProviderTokenNotInitError() raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
return hosted_llm_credentials.openai.api_key f"Please go to Settings -> Model Provider to complete your provider credentials."
)
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
""" """
Returns the provider configs. Returns the provider configs.
""" """
try: try:
config = self.get_provider_api_key() config = self.get_provider_api_key(only_custom=only_custom)
except: except:
config = '' config = ''
......
...@@ -31,11 +31,11 @@ class LLMProviderService: ...@@ -31,11 +31,11 @@ class LLMProviderService:
def get_credentials(self, model_id: Optional[str] = None) -> dict: def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.provider.get_credentials(model_id) return self.provider.get_credentials(model_id)
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
return self.provider.get_provider_configs(obfuscated) return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]: def get_provider_db_record(self) -> Optional[Provider]:
return self.provider.get_provider(prefer_custom) return self.provider.get_provider()
def config_validate(self, config: Union[dict | str]): def config_validate(self, config: Union[dict | str]):
""" """
......
...@@ -4,6 +4,8 @@ from typing import Optional, Union ...@@ -4,6 +4,8 @@ from typing import Optional, Union
import openai import openai
from openai.error import AuthenticationError, OpenAIError from openai.error import AuthenticationError, OpenAIError
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.moderation import Moderation from core.llm.moderation import Moderation
from core.llm.provider.base import BaseProvider from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError from core.llm.provider.errors import ValidateFailedError
...@@ -42,3 +44,12 @@ class OpenAIProvider(BaseProvider): ...@@ -42,3 +44,12 @@ class OpenAIProvider(BaseProvider):
except Exception as ex: except Exception as ex:
logging.exception('OpenAI config validation failed') logging.exception('OpenAI config validation failed')
raise ex raise ex
def get_hosted_credentials(self) -> Union[str | dict]:
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return hosted_llm_credentials.openai.api_key
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, ChatResult, LLMResult from langchain.schema import BaseMessage, LLMResult
from langchain.chat_models import AzureChatOpenAI from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
from pydantic import root_validator from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableAzureChatOpenAI(AzureChatOpenAI): class StreamableAzureChatOpenAI(AzureChatOpenAI):
...@@ -46,30 +46,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI): ...@@ -46,30 +46,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
"organization": self.openai_organization if self.openai_organization else None, "organization": self.openai_organization if self.openai_organization else None,
} }
def get_messages_tokens(self, messages: List[BaseMessage]) -> int: @handle_openai_exceptions
"""Get the number of tokens in a list of messages.
Args:
messages: The messages to count the tokens of.
Returns:
The number of tokens in the messages.
"""
tokens_per_message = 5
tokens_per_request = 3
message_tokens = tokens_per_request
message_strs = ''
for message in messages:
message_strs += message.content
message_tokens += tokens_per_message
# calc once
message_tokens += self.get_num_tokens(message_strs)
return message_tokens
@handle_llm_exceptions
def generate( def generate(
self, self,
messages: List[List[BaseMessage]], messages: List[List[BaseMessage]],
...@@ -79,12 +56,18 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI): ...@@ -79,12 +56,18 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
) -> LLMResult: ) -> LLMResult:
return super().generate(messages, stop, callbacks, **kwargs) return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async @classmethod
async def agenerate( def get_kwargs_from_model_params(cls, params: dict):
self, model_kwargs = {
messages: List[List[BaseMessage]], 'top_p': params.get('top_p', 1),
stop: Optional[List[str]] = None, 'frequency_penalty': params.get('frequency_penalty', 0),
callbacks: Callbacks = None, 'presence_penalty': params.get('presence_penalty', 0),
**kwargs: Any, }
) -> LLMResult:
return await super().agenerate(messages, stop, callbacks, **kwargs) del params['top_p']
del params['frequency_penalty']
del params['presence_penalty']
params['model_kwargs'] = model_kwargs
return params
...@@ -5,7 +5,7 @@ from typing import Optional, List, Dict, Mapping, Any ...@@ -5,7 +5,7 @@ from typing import Optional, List, Dict, Mapping, Any
from pydantic import root_validator from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableAzureOpenAI(AzureOpenAI): class StreamableAzureOpenAI(AzureOpenAI):
...@@ -50,7 +50,7 @@ class StreamableAzureOpenAI(AzureOpenAI): ...@@ -50,7 +50,7 @@ class StreamableAzureOpenAI(AzureOpenAI):
"organization": self.openai_organization if self.openai_organization else None, "organization": self.openai_organization if self.openai_organization else None,
}} }}
@handle_llm_exceptions @handle_openai_exceptions
def generate( def generate(
self, self,
prompts: List[str], prompts: List[str],
...@@ -60,12 +60,6 @@ class StreamableAzureOpenAI(AzureOpenAI): ...@@ -60,12 +60,6 @@ class StreamableAzureOpenAI(AzureOpenAI):
) -> LLMResult: ) -> LLMResult:
return super().generate(prompts, stop, callbacks, **kwargs) return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async @classmethod
async def agenerate( def get_kwargs_from_model_params(cls, params: dict):
self, return params
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(prompts, stop, callbacks, **kwargs)
from typing import List, Optional, Any, Dict
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
class StreamableChatAnthropic(ChatAnthropic):
"""
Wrapper around Anthropic's large language model.
"""
@handle_anthropic_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
params['model'] = params.get('model_name')
del params['model_name']
params['max_tokens_to_sample'] = params.get('max_tokens')
del params['max_tokens']
del params['frequency_penalty']
del params['presence_penalty']
return params
...@@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Any ...@@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Any
from pydantic import root_validator from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableChatOpenAI(ChatOpenAI): class StreamableChatOpenAI(ChatOpenAI):
...@@ -48,30 +48,7 @@ class StreamableChatOpenAI(ChatOpenAI): ...@@ -48,30 +48,7 @@ class StreamableChatOpenAI(ChatOpenAI):
"organization": self.openai_organization if self.openai_organization else None, "organization": self.openai_organization if self.openai_organization else None,
} }
def get_messages_tokens(self, messages: List[BaseMessage]) -> int: @handle_openai_exceptions
"""Get the number of tokens in a list of messages.
Args:
messages: The messages to count the tokens of.
Returns:
The number of tokens in the messages.
"""
tokens_per_message = 5
tokens_per_request = 3
message_tokens = tokens_per_request
message_strs = ''
for message in messages:
message_strs += message.content
message_tokens += tokens_per_message
# calc once
message_tokens += self.get_num_tokens(message_strs)
return message_tokens
@handle_llm_exceptions
def generate( def generate(
self, self,
messages: List[List[BaseMessage]], messages: List[List[BaseMessage]],
...@@ -81,12 +58,18 @@ class StreamableChatOpenAI(ChatOpenAI): ...@@ -81,12 +58,18 @@ class StreamableChatOpenAI(ChatOpenAI):
) -> LLMResult: ) -> LLMResult:
return super().generate(messages, stop, callbacks, **kwargs) return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async @classmethod
async def agenerate( def get_kwargs_from_model_params(cls, params: dict):
self, model_kwargs = {
messages: List[List[BaseMessage]], 'top_p': params.get('top_p', 1),
stop: Optional[List[str]] = None, 'frequency_penalty': params.get('frequency_penalty', 0),
callbacks: Callbacks = None, 'presence_penalty': params.get('presence_penalty', 0),
**kwargs: Any, }
) -> LLMResult:
return await super().agenerate(messages, stop, callbacks, **kwargs) del params['top_p']
del params['frequency_penalty']
del params['presence_penalty']
params['model_kwargs'] = model_kwargs
return params
...@@ -6,7 +6,7 @@ from typing import Optional, List, Dict, Any, Mapping ...@@ -6,7 +6,7 @@ from typing import Optional, List, Dict, Any, Mapping
from langchain import OpenAI from langchain import OpenAI
from pydantic import root_validator from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableOpenAI(OpenAI): class StreamableOpenAI(OpenAI):
...@@ -49,7 +49,7 @@ class StreamableOpenAI(OpenAI): ...@@ -49,7 +49,7 @@ class StreamableOpenAI(OpenAI):
"organization": self.openai_organization if self.openai_organization else None, "organization": self.openai_organization if self.openai_organization else None,
}} }}
@handle_llm_exceptions @handle_openai_exceptions
def generate( def generate(
self, self,
prompts: List[str], prompts: List[str],
...@@ -59,12 +59,6 @@ class StreamableOpenAI(OpenAI): ...@@ -59,12 +59,6 @@ class StreamableOpenAI(OpenAI):
) -> LLMResult: ) -> LLMResult:
return super().generate(prompts, stop, callbacks, **kwargs) return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async @classmethod
async def agenerate( def get_kwargs_from_model_params(cls, params: dict):
self, return params
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(prompts, stop, callbacks, **kwargs)
import openai import openai
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from models.provider import ProviderName from models.provider import ProviderName
from core.llm.error_handle_wraps import handle_llm_exceptions
from core.llm.provider.base import BaseProvider from core.llm.provider.base import BaseProvider
...@@ -13,7 +14,7 @@ class Whisper: ...@@ -13,7 +14,7 @@ class Whisper:
self.client = openai.Audio self.client = openai.Audio
self.credentials = provider.get_credentials() self.credentials = provider.get_credentials()
@handle_llm_exceptions @handle_openai_exceptions
def transcribe(self, file): def transcribe(self, file):
return self.client.transcribe( return self.client.transcribe(
model='whisper-1', model='whisper-1',
......
import logging
from functools import wraps
import anthropic
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
LLMBadRequestError
def handle_anthropic_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except anthropic.APIConnectionError as e:
logging.exception("Failed to connect to Anthropic API.")
raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
except anthropic.RateLimitError:
raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
except anthropic.AuthenticationError as e:
raise LLMAuthorizationError(f"Anthropic: {e.message}")
except anthropic.BadRequestError as e:
raise LLMBadRequestError(f"Anthropic: {e.message}")
except anthropic.APIStatusError as e:
raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
return wrapper
...@@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRat ...@@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRat
LLMBadRequestError LLMBadRequestError
def handle_llm_exceptions(func): def handle_openai_exceptions(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
try: try:
...@@ -29,27 +29,3 @@ def handle_llm_exceptions(func): ...@@ -29,27 +29,3 @@ def handle_llm_exceptions(func):
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e)) raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper return wrapper
def handle_llm_exceptions_async(func):
@wraps(func)
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except openai.error.InvalidRequestError as e:
logging.exception("Invalid request to OpenAI API.")
raise LLMBadRequestError(str(e))
except openai.error.APIConnectionError as e:
logging.exception("Failed to connect to OpenAI API.")
raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
logging.exception("OpenAI service unavailable.")
raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
except openai.error.RateLimitError as e:
raise LLMRateLimitError(str(e))
except openai.error.AuthenticationError as e:
raise LLMAuthorizationError(str(e))
except openai.error.OpenAIError as e:
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper
from typing import Any, List, Dict, Union from typing import Any, List, Dict, Union
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI from core.llm.streamable_open_ai import StreamableOpenAI
...@@ -12,8 +12,8 @@ from models.model import Conversation, Message ...@@ -12,8 +12,8 @@ from models.model import Conversation, Message
class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
conversation: Conversation conversation: Conversation
human_prefix: str = "Human" human_prefix: str = "Human"
ai_prefix: str = "AI" ai_prefix: str = "Assistant"
llm: Union[StreamableChatOpenAI | StreamableOpenAI] llm: BaseLanguageModel
memory_key: str = "chat_history" memory_key: str = "chat_history"
max_token_limit: int = 2000 max_token_limit: int = 2000
message_limit: int = 10 message_limit: int = 10
...@@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): ...@@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
return chat_messages return chat_messages
# prune the chat message if it exceeds the max token limit # prune the chat message if it exceeds the max token limit
curr_buffer_length = self.llm.get_messages_tokens(chat_messages) curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
if curr_buffer_length > self.max_token_limit: if curr_buffer_length > self.max_token_limit:
pruned_memory = [] pruned_memory = []
while curr_buffer_length > self.max_token_limit and chat_messages: while curr_buffer_length > self.max_token_limit and chat_messages:
pruned_memory.append(chat_messages.pop(0)) pruned_memory.append(chat_messages.pop(0))
curr_buffer_length = self.llm.get_messages_tokens(chat_messages) curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
return chat_messages return chat_messages
......
...@@ -30,7 +30,7 @@ class DatasetTool(BaseTool): ...@@ -30,7 +30,7 @@ class DatasetTool(BaseTool):
else: else:
model_credentials = LLMBuilder.get_model_credentials( model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id, tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id), model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002' model_name='text-embedding-ada-002'
) )
...@@ -60,7 +60,7 @@ class DatasetTool(BaseTool): ...@@ -60,7 +60,7 @@ class DatasetTool(BaseTool):
async def _arun(self, tool_input: str) -> str: async def _arun(self, tool_input: str) -> str:
model_credentials = LLMBuilder.get_model_credentials( model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id, tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id), model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002' model_name='text-embedding-ada-002'
) )
......
from flask import current_app
from events.tenant_event import tenant_was_updated from events.tenant_event import tenant_was_updated
from models.provider import ProviderName
from services.provider_service import ProviderService from services.provider_service import ProviderService
...@@ -6,4 +9,16 @@ from services.provider_service import ProviderService ...@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs): def handle(sender, **kwargs):
tenant = sender tenant = sender
if tenant.status == 'normal': if tenant.status == 'normal':
ProviderService.create_system_provider(tenant) ProviderService.create_system_provider(
tenant,
ProviderName.OPENAI.value,
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
True
)
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)
from flask import current_app
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from models.provider import ProviderName
from services.provider_service import ProviderService from services.provider_service import ProviderService
...@@ -6,4 +9,16 @@ from services.provider_service import ProviderService ...@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs): def handle(sender, **kwargs):
tenant = sender tenant = sender
if tenant.status == 'normal': if tenant.status == 'normal':
ProviderService.create_system_provider(tenant) ProviderService.create_system_provider(
tenant,
ProviderName.OPENAI.value,
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
True
)
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)
...@@ -10,7 +10,7 @@ flask-session2==1.3.1 ...@@ -10,7 +10,7 @@ flask-session2==1.3.1
flask-cors==3.0.10 flask-cors==3.0.10
gunicorn~=20.1.0 gunicorn~=20.1.0
gevent~=22.10.2 gevent~=22.10.2
langchain==0.0.209 langchain==0.0.230
openai~=0.27.5 openai~=0.27.5
psycopg2-binary~=2.9.6 psycopg2-binary~=2.9.6
pycryptodome==3.17 pycryptodome==3.17
...@@ -35,3 +35,4 @@ docx2txt==0.8 ...@@ -35,3 +35,4 @@ docx2txt==0.8
pypdfium2==4.16.0 pypdfium2==4.16.0
resend~=0.5.1 resend~=0.5.1
pyjwt~=2.6.0 pyjwt~=2.6.0
anthropic~=0.3.4
...@@ -6,6 +6,30 @@ from models.account import Account ...@@ -6,6 +6,30 @@ from models.account import Account
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from core.llm.llm_builder import LLMBuilder from core.llm.llm_builder import LLMBuilder
MODEL_PROVIDERS = [
'openai',
'anthropic',
]
MODELS_BY_APP_MODE = {
'chat': [
'claude-instant-1',
'claude-2',
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k',
],
'completion': [
'claude-instant-1',
'claude-2',
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k',
'text-davinci-003',
]
}
class AppModelConfigService: class AppModelConfigService:
@staticmethod @staticmethod
...@@ -125,7 +149,7 @@ class AppModelConfigService: ...@@ -125,7 +149,7 @@ class AppModelConfigService:
if not isinstance(config["speech_to_text"]["enabled"], bool): if not isinstance(config["speech_to_text"]["enabled"], bool):
raise ValueError("enabled in speech_to_text must be of boolean type") raise ValueError("enabled in speech_to_text must be of boolean type")
provider_name = LLMBuilder.get_default_provider(account.current_tenant_id) provider_name = LLMBuilder.get_default_provider(account.current_tenant_id, 'whisper-1')
if config["speech_to_text"]["enabled"] and provider_name != 'openai': if config["speech_to_text"]["enabled"] and provider_name != 'openai':
raise ValueError("provider not support speech to text") raise ValueError("provider not support speech to text")
...@@ -153,14 +177,14 @@ class AppModelConfigService: ...@@ -153,14 +177,14 @@ class AppModelConfigService:
raise ValueError("model must be of object type") raise ValueError("model must be of object type")
# model.provider # model.provider
if 'provider' not in config["model"] or config["model"]["provider"] != "openai": if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS:
raise ValueError("model.provider must be 'openai'") raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}")
# model.name # model.name
if 'name' not in config["model"]: if 'name' not in config["model"]:
raise ValueError("model.name is required") raise ValueError("model.name is required")
if config["model"]["name"] not in llm_constant.models_by_mode[mode]: if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]:
raise ValueError("model.name must be in the specified model list") raise ValueError("model.name must be in the specified model list")
# model.completion_params # model.completion_params
......
...@@ -27,7 +27,7 @@ class AudioService: ...@@ -27,7 +27,7 @@ class AudioService:
message = f"Audio size larger than {FILE_SIZE} mb" message = f"Audio size larger than {FILE_SIZE} mb"
raise AudioTooLargeServiceError(message) raise AudioTooLargeServiceError(message)
provider_name = LLMBuilder.get_default_provider(tenant_id) provider_name = LLMBuilder.get_default_provider(tenant_id, 'whisper-1')
if provider_name != ProviderName.OPENAI.value: if provider_name != ProviderName.OPENAI.value:
raise ProviderNotSupportSpeechToTextServiceError() raise ProviderNotSupportSpeechToTextServiceError()
...@@ -37,8 +37,3 @@ class AudioService: ...@@ -37,8 +37,3 @@ class AudioService:
buffer.name = 'temp.mp3' buffer.name = 'temp.mp3'
return Whisper(provider_service.provider).transcribe(buffer) return Whisper(provider_service.provider).transcribe(buffer)
\ No newline at end of file
...@@ -4,6 +4,9 @@ import datetime ...@@ -4,6 +4,9 @@ import datetime
import time import time
import random import random
from typing import Optional, List from typing import Optional, List
from flask import current_app
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from flask_login import current_user from flask_login import current_user
...@@ -374,6 +377,12 @@ class DocumentService: ...@@ -374,6 +377,12 @@ class DocumentService:
def save_document_with_dataset_id(dataset: Dataset, document_data: dict, def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
created_from: str = 'web'): created_from: str = 'web'):
# check document limit
if current_app.config['EDITION'] == 'CLOUD':
documents_count = DocumentService.get_tenant_documents_count()
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if documents_count > tenant_document_count:
raise ValueError(f"over document limit {tenant_document_count}.")
# if dataset is empty, update dataset data_source_type # if dataset is empty, update dataset data_source_type
if not dataset.data_source_type: if not dataset.data_source_type:
dataset.data_source_type = document_data["data_source"]["type"] dataset.data_source_type = document_data["data_source"]["type"]
...@@ -521,6 +530,14 @@ class DocumentService: ...@@ -521,6 +530,14 @@ class DocumentService:
) )
return document return document
@staticmethod
def get_tenant_documents_count():
documents_count = Document.query.filter(Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id).count()
return documents_count
@staticmethod @staticmethod
def update_document_with_dataset_id(dataset: Dataset, document_data: dict, def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
...@@ -616,6 +633,12 @@ class DocumentService: ...@@ -616,6 +633,12 @@ class DocumentService:
@staticmethod @staticmethod
def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account): def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):
# check document limit
if current_app.config['EDITION'] == 'CLOUD':
documents_count = DocumentService.get_tenant_documents_count()
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
if documents_count > tenant_document_count:
raise ValueError(f"over document limit {tenant_document_count}.")
# save dataset # save dataset
dataset = Dataset( dataset = Dataset(
tenant_id=tenant_id, tenant_id=tenant_id,
......
...@@ -31,7 +31,7 @@ class HitTestingService: ...@@ -31,7 +31,7 @@ class HitTestingService:
model_credentials = LLMBuilder.get_model_credentials( model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id), model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002' model_name='text-embedding-ada-002'
) )
......
...@@ -10,50 +10,40 @@ from models.provider import * ...@@ -10,50 +10,40 @@ from models.provider import *
class ProviderService: class ProviderService:
@staticmethod @staticmethod
def init_supported_provider(tenant, edition): def init_supported_provider(tenant):
"""Initialize the model provider, check whether the supported provider has a record""" """Initialize the model provider, check whether the supported provider has a record"""
providers = Provider.query.filter_by(tenant_id=tenant.id).all() need_init_provider_names = [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value, ProviderName.ANTHROPIC.value]
openai_provider_exists = False providers = db.session.query(Provider).filter(
azure_openai_provider_exists = False Provider.tenant_id == tenant.id,
Provider.provider_type == ProviderType.CUSTOM.value,
# TODO: The cloud version needs to construct the data of the SYSTEM type Provider.provider_name.in_(need_init_provider_names)
).all()
exists_provider_names = []
for provider in providers: for provider in providers:
if provider.provider_name == ProviderName.OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value: exists_provider_names.append(provider.provider_name)
openai_provider_exists = True
if provider.provider_name == ProviderName.AZURE_OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
azure_openai_provider_exists = True
# Initialize the model provider, check whether the supported provider has a record not_exists_provider_names = list(set(need_init_provider_names) - set(exists_provider_names))
# Create default providers if they don't exist if not_exists_provider_names:
if not openai_provider_exists: # Initialize the model provider, check whether the supported provider has a record
openai_provider = Provider( for provider_name in not_exists_provider_names:
tenant_id=tenant.id, provider = Provider(
provider_name=ProviderName.OPENAI.value, tenant_id=tenant.id,
provider_type=ProviderType.CUSTOM.value, provider_name=provider_name,
is_valid=False provider_type=ProviderType.CUSTOM.value,
) is_valid=False
db.session.add(openai_provider) )
db.session.add(provider)
if not azure_openai_provider_exists:
azure_openai_provider = Provider(
tenant_id=tenant.id,
provider_name=ProviderName.AZURE_OPENAI.value,
provider_type=ProviderType.CUSTOM.value,
is_valid=False
)
db.session.add(azure_openai_provider)
if not openai_provider_exists or not azure_openai_provider_exists:
db.session.commit() db.session.commit()
@staticmethod @staticmethod
def get_obfuscated_api_key(tenant, provider_name: ProviderName): def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False):
llm_provider_service = LLMProviderService(tenant.id, provider_name.value) llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
return llm_provider_service.get_provider_configs(obfuscated=True) return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom)
@staticmethod @staticmethod
def get_token_type(tenant, provider_name: ProviderName): def get_token_type(tenant, provider_name: ProviderName):
...@@ -73,7 +63,7 @@ class ProviderService: ...@@ -73,7 +63,7 @@ class ProviderService:
return llm_provider_service.get_encrypted_token(configs) return llm_provider_service.get_encrypted_token(configs)
@staticmethod @staticmethod
def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200,
is_valid: bool = True): is_valid: bool = True):
if current_app.config['EDITION'] != 'CLOUD': if current_app.config['EDITION'] != 'CLOUD':
return return
...@@ -90,7 +80,7 @@ class ProviderService: ...@@ -90,7 +80,7 @@ class ProviderService:
provider_name=provider_name, provider_name=provider_name,
provider_type=ProviderType.SYSTEM.value, provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value, quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=200, quota_limit=quota_limit,
encrypted_config='', encrypted_config='',
is_valid=is_valid, is_valid=is_valid,
) )
......
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Tenant from models.account import Tenant
from models.provider import Provider, ProviderType from models.provider import Provider, ProviderType, ProviderName
class WorkspaceService: class WorkspaceService:
...@@ -33,7 +33,7 @@ class WorkspaceService: ...@@ -33,7 +33,7 @@ class WorkspaceService:
if provider.is_valid and provider.encrypted_config: if provider.is_valid and provider.encrypted_config:
custom_provider = provider custom_provider = provider
elif provider.provider_type == ProviderType.SYSTEM.value: elif provider.provider_type == ProviderType.SYSTEM.value:
if provider.is_valid: if provider.provider_name == ProviderName.OPENAI.value and provider.is_valid:
system_provider = provider system_provider = provider
if system_provider and not custom_provider: if system_provider and not custom_provider:
......
...@@ -2,7 +2,7 @@ version: '3.1' ...@@ -2,7 +2,7 @@ version: '3.1'
services: services:
# API service # API service
api: api:
image: langgenius/dify-api:0.3.8 image: langgenius/dify-api:0.3.9
restart: always restart: always
environment: environment:
# Startup mode, 'api' starts the API server. # Startup mode, 'api' starts the API server.
...@@ -124,7 +124,7 @@ services: ...@@ -124,7 +124,7 @@ services:
# worker service # worker service
# The Celery worker for processing the queue. # The Celery worker for processing the queue.
worker: worker:
image: langgenius/dify-api:0.3.8 image: langgenius/dify-api:0.3.9
restart: always restart: always
environment: environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue. # Startup mode, 'worker' starts the Celery worker for processing the queue.
...@@ -176,7 +176,7 @@ services: ...@@ -176,7 +176,7 @@ services:
# Frontend web application. # Frontend web application.
web: web:
image: langgenius/dify-web:0.3.8 image: langgenius/dify-web:0.3.9
restart: always restart: always
environment: environment:
EDITION: SELF_HOSTED EDITION: SELF_HOSTED
......
...@@ -65,6 +65,7 @@ export type IChatProps = { ...@@ -65,6 +65,7 @@ export type IChatProps = {
isShowSuggestion?: boolean isShowSuggestion?: boolean
suggestionList?: string[] suggestionList?: string[]
isShowSpeechToText?: boolean isShowSpeechToText?: boolean
answerIconClassName?: string
} }
export type MessageMore = { export type MessageMore = {
...@@ -174,10 +175,11 @@ type IAnswerProps = { ...@@ -174,10 +175,11 @@ type IAnswerProps = {
onSubmitAnnotation?: SubmitAnnotationFunc onSubmitAnnotation?: SubmitAnnotationFunc
displayScene: DisplayScene displayScene: DisplayScene
isResponsing?: boolean isResponsing?: boolean
answerIconClassName?: string
} }
// The component needs to maintain its own state to control whether to display input component // The component needs to maintain its own state to control whether to display input component
const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedbackEdit = false, onFeedback, onSubmitAnnotation, displayScene = 'web', isResponsing }) => { const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedbackEdit = false, onFeedback, onSubmitAnnotation, displayScene = 'web', isResponsing, answerIconClassName }) => {
const { id, content, more, feedback, adminFeedback, annotation: initAnnotation } = item const { id, content, more, feedback, adminFeedback, annotation: initAnnotation } = item
const [showEdit, setShowEdit] = useState(false) const [showEdit, setShowEdit] = useState(false)
const [loading, setLoading] = useState(false) const [loading, setLoading] = useState(false)
...@@ -292,7 +294,7 @@ const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedba ...@@ -292,7 +294,7 @@ const Answer: FC<IAnswerProps> = ({ item, feedbackDisabled = false, isHideFeedba
return ( return (
<div key={id}> <div key={id}>
<div className='flex items-start'> <div className='flex items-start'>
<div className={`${s.answerIcon} w-10 h-10 shrink-0`}> <div className={`${s.answerIcon} ${answerIconClassName} w-10 h-10 shrink-0`}>
{isResponsing {isResponsing
&& <div className={s.typeingIcon}> && <div className={s.typeingIcon}>
<LoadingAnim type='avatar' /> <LoadingAnim type='avatar' />
...@@ -428,6 +430,7 @@ const Chat: FC<IChatProps> = ({ ...@@ -428,6 +430,7 @@ const Chat: FC<IChatProps> = ({
isShowSuggestion, isShowSuggestion,
suggestionList, suggestionList,
isShowSpeechToText, isShowSpeechToText,
answerIconClassName,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const { notify } = useContext(ToastContext) const { notify } = useContext(ToastContext)
...@@ -520,6 +523,7 @@ const Chat: FC<IChatProps> = ({ ...@@ -520,6 +523,7 @@ const Chat: FC<IChatProps> = ({
onSubmitAnnotation={onSubmitAnnotation} onSubmitAnnotation={onSubmitAnnotation}
displayScene={displayScene ?? 'web'} displayScene={displayScene ?? 'web'}
isResponsing={isResponsing && isLast} isResponsing={isResponsing && isLast}
answerIconClassName={answerIconClassName}
/> />
} }
return <Question key={item.id} id={item.id} content={item.content} more={item.more} useCurrentUserAvatar={useCurrentUserAvatar} /> return <Question key={item.id} id={item.id} content={item.content} more={item.more} useCurrentUserAvatar={useCurrentUserAvatar} />
......
...@@ -372,7 +372,7 @@ const Debug: FC<IDebug> = ({ ...@@ -372,7 +372,7 @@ const Debug: FC<IDebug> = ({
{/* Chat */} {/* Chat */}
{mode === AppType.chat && ( {mode === AppType.chat && (
<div className="mt-[34px] h-full flex flex-col"> <div className="mt-[34px] h-full flex flex-col">
<div className={cn(doShowSuggestion ? 'pb-[140px]' : (isResponsing ? 'pb-[113px]' : 'pb-[66px]'), 'relative mt-1.5 grow h-[200px] overflow-hidden')}> <div className={cn(doShowSuggestion ? 'pb-[140px]' : (isResponsing ? 'pb-[113px]' : 'pb-[76px]'), 'relative mt-1.5 grow h-[200px] overflow-hidden')}>
<div className="h-full overflow-y-auto overflow-x-hidden" ref={chatListDomRef}> <div className="h-full overflow-y-auto overflow-x-hidden" ref={chatListDomRef}>
<Chat <Chat
chatList={chatList} chatList={chatList}
......
...@@ -16,6 +16,7 @@ import ConfigModel from '@/app/components/app/configuration/config-model' ...@@ -16,6 +16,7 @@ import ConfigModel from '@/app/components/app/configuration/config-model'
import Config from '@/app/components/app/configuration/config' import Config from '@/app/components/app/configuration/config'
import Debug from '@/app/components/app/configuration/debug' import Debug from '@/app/components/app/configuration/debug'
import Confirm from '@/app/components/base/confirm' import Confirm from '@/app/components/base/confirm'
import { ProviderType } from '@/types/app'
import type { AppDetailResponse } from '@/models/app' import type { AppDetailResponse } from '@/models/app'
import { ToastContext } from '@/app/components/base/toast' import { ToastContext } from '@/app/components/base/toast'
import { fetchTenantInfo } from '@/service/common' import { fetchTenantInfo } from '@/service/common'
...@@ -67,7 +68,7 @@ const Configuration: FC = () => { ...@@ -67,7 +68,7 @@ const Configuration: FC = () => {
frequency_penalty: 1, // -2-2 frequency_penalty: 1, // -2-2
}) })
const [modelConfig, doSetModelConfig] = useState<ModelConfig>({ const [modelConfig, doSetModelConfig] = useState<ModelConfig>({
provider: 'openai', provider: ProviderType.openai,
model_id: 'gpt-3.5-turbo', model_id: 'gpt-3.5-turbo',
configs: { configs: {
prompt_template: '', prompt_template: '',
...@@ -84,8 +85,9 @@ const Configuration: FC = () => { ...@@ -84,8 +85,9 @@ const Configuration: FC = () => {
doSetModelConfig(newModelConfig) doSetModelConfig(newModelConfig)
} }
const setModelId = (modelId: string) => { const setModelId = (modelId: string, provider: ProviderType) => {
const newModelConfig = produce(modelConfig, (draft: any) => { const newModelConfig = produce(modelConfig, (draft: any) => {
draft.provider = provider
draft.model_id = modelId draft.model_id = modelId
}) })
setModelConfig(newModelConfig) setModelConfig(newModelConfig)
......
...@@ -184,7 +184,11 @@ const GenerationItem: FC<IGenerationItemProps> = ({ ...@@ -184,7 +184,11 @@ const GenerationItem: FC<IGenerationItemProps> = ({
{taskId} {taskId}
</div>) </div>)
} }
<Markdown content={content} /> <div className='flex'>
<div className='grow w-0'>
<Markdown content={content} />
</div>
</div>
{messageId && ( {messageId && (
<div className='flex items-center justify-between mt-3'> <div className='flex items-center justify-between mt-3'>
<div className='flex items-center'> <div className='flex items-center'>
......
...@@ -19,6 +19,7 @@ const AutoHeightTextarea = forwardRef( ...@@ -19,6 +19,7 @@ const AutoHeightTextarea = forwardRef(
{ value, onChange, placeholder, className, minHeight = 36, maxHeight = 96, autoFocus, controlFocus, onKeyDown, onKeyUp }: IProps, { value, onChange, placeholder, className, minHeight = 36, maxHeight = 96, autoFocus, controlFocus, onKeyDown, onKeyUp }: IProps,
outerRef: any, outerRef: any,
) => { ) => {
// eslint-disable-next-line react-hooks/rules-of-hooks
const ref = outerRef || useRef<HTMLTextAreaElement>(null) const ref = outerRef || useRef<HTMLTextAreaElement>(null)
const doFocus = () => { const doFocus = () => {
...@@ -54,13 +55,20 @@ const AutoHeightTextarea = forwardRef( ...@@ -54,13 +55,20 @@ const AutoHeightTextarea = forwardRef(
return ( return (
<div className='relative'> <div className='relative'>
<div className={cn(className, 'invisible whitespace-pre-wrap break-all overflow-y-auto')} style={{ minHeight, maxHeight }}> <div className={cn(className, 'invisible whitespace-pre-wrap break-all overflow-y-auto')} style={{
minHeight,
maxHeight,
paddingRight: (value && value.trim().length > 10000) ? 140 : 130,
}}>
{!value ? placeholder : value.replace(/\n$/, '\n ')} {!value ? placeholder : value.replace(/\n$/, '\n ')}
</div> </div>
<textarea <textarea
ref={ref} ref={ref}
autoFocus={autoFocus} autoFocus={autoFocus}
className={cn(className, 'absolute inset-0 resize-none overflow-hidden')} className={cn(className, 'absolute inset-0 resize-none overflow-auto')}
style={{
paddingRight: (value && value.trim().length > 10000) ? 140 : 130,
}}
placeholder={placeholder} placeholder={placeholder}
onChange={onChange} onChange={onChange}
onKeyDown={onKeyDown} onKeyDown={onKeyDown}
......
'use client'
class StorageMock {
data: Record<string, string>
constructor() {
this.data = {} as Record<string, string>
}
setItem(name: string, value: string) {
this.data[name] = value
}
getItem(name: string) {
return this.data[name] || null
}
removeItem(name: string) {
delete this.data[name]
}
clear() {
this.data = {}
}
}
let localStorage, sessionStorage
try {
localStorage = globalThis.localStorage
sessionStorage = globalThis.sessionStorage
}
catch (e) {
localStorage = new StorageMock()
sessionStorage = new StorageMock()
}
Object.defineProperty(globalThis, 'localStorage', {
value: localStorage,
})
Object.defineProperty(globalThis, 'sessionStorage', {
value: sessionStorage,
})
const BrowerInitor = ({
children,
}: { children: React.ReactElement }) => {
return children
}
export default BrowerInitor
...@@ -13,6 +13,7 @@ import { useAppContext } from '@/context/app-context' ...@@ -13,6 +13,7 @@ import { useAppContext } from '@/context/app-context'
import { ToastContext } from '@/app/components/base/toast' import { ToastContext } from '@/app/components/base/toast'
import AppIcon from '@/app/components/base/app-icon' import AppIcon from '@/app/components/base/app-icon'
import Avatar from '@/app/components/base/avatar' import Avatar from '@/app/components/base/avatar'
import { IS_CE_EDITION } from '@/config'
const titleClassName = ` const titleClassName = `
text-sm font-medium text-gray-900 text-sm font-medium text-gray-900
...@@ -136,11 +137,13 @@ export default function AccountPage() { ...@@ -136,11 +137,13 @@ export default function AccountPage() {
<div className={titleClassName}>{t('common.account.email')}</div> <div className={titleClassName}>{t('common.account.email')}</div>
<div className={classNames(inputClassName, 'cursor-pointer')}>{userProfile.email}</div> <div className={classNames(inputClassName, 'cursor-pointer')}>{userProfile.email}</div>
</div> </div>
<div className='mb-8'> {IS_CE_EDITION && (
<div className='mb-1 text-sm font-medium text-gray-900'>{t('common.account.password')}</div> <div className='mb-8'>
<div className='mb-2 text-xs text-gray-500'>{t('common.account.passwordTip')}</div> <div className='mb-1 text-sm font-medium text-gray-900'>{t('common.account.password')}</div>
<Button className='font-medium !text-gray-700 !px-3 !py-[7px] !text-[13px]' onClick={() => setEditPasswordModalVisible(true)}>{userProfile.is_password_set ? t('common.account.resetPassword') : t('common.account.setPassword')}</Button> <div className='mb-2 text-xs text-gray-500'>{t('common.account.passwordTip')}</div>
</div> <Button className='font-medium !text-gray-700 !px-3 !py-[7px] !text-[13px]' onClick={() => setEditPasswordModalVisible(true)}>{userProfile.is_password_set ? t('common.account.resetPassword') : t('common.account.setPassword')}</Button>
</div>
)}
{!!apps.length && ( {!!apps.length && (
<> <>
<div className='mb-6 border-[0.5px] border-gray-100' /> <div className='mb-6 border-[0.5px] border-gray-100' />
......
...@@ -5,6 +5,8 @@ import InvitationLink from './invitation-link' ...@@ -5,6 +5,8 @@ import InvitationLink from './invitation-link'
import s from './index.module.css' import s from './index.module.css'
import Modal from '@/app/components/base/modal' import Modal from '@/app/components/base/modal'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import { IS_CE_EDITION } from '@/config'
type IInvitedModalProps = { type IInvitedModalProps = {
invitationLink: string invitationLink: string
onCancel: () => void onCancel: () => void
...@@ -29,11 +31,18 @@ const InvitedModal = ({ ...@@ -29,11 +31,18 @@ const InvitedModal = ({
<XMarkIcon className='w-4 h-4 cursor-pointer' onClick={onCancel} /> <XMarkIcon className='w-4 h-4 cursor-pointer' onClick={onCancel} />
</div> </div>
<div className='mb-1 text-xl font-semibold text-gray-900'>{t('common.members.invitationSent')}</div> <div className='mb-1 text-xl font-semibold text-gray-900'>{t('common.members.invitationSent')}</div>
<div className='mb-5 text-sm text-gray-500'>{t('common.members.invitationSentTip')}</div> {!IS_CE_EDITION && (
<div className='mb-9'> <div className='mb-10 text-sm text-gray-500'>{t('common.members.invitationSentTip')}</div>
<div className='py-2 text-sm font-Medium text-gray-900'>{t('common.members.invitationLink')}</div> )}
<InvitationLink value={invitationLink} /> {IS_CE_EDITION && (
</div> <>
<div className='mb-5 text-sm text-gray-500'>{t('common.members.invitationSentTip')}</div>
<div className='mb-9'>
<div className='py-2 text-sm font-Medium text-gray-900'>{t('common.members.invitationLink')}</div>
<InvitationLink value={invitationLink} />
</div>
</>
)}
<div className='flex justify-end'> <div className='flex justify-end'>
<Button <Button
className='w-[96px] text-sm font-medium' className='w-[96px] text-sm font-medium'
......
.icon {
width: 24px;
height: 24px;
margin-right: 12px;
background: url(../../../assets/anthropic.svg) center center no-repeat;
background-size: contain;
}
.bar {
background: linear-gradient(90deg, rgba(41, 112, 255, 0.9) 0%, rgba(21, 94, 239, 0.9) 100%);
}
.bar-error {
background: linear-gradient(90deg, rgba(240, 68, 56, 0.72) 0%, rgba(217, 45, 32, 0.9) 100%);
}
.bar-item {
width: 10%;
border-right: 1px solid rgba(255, 255, 255, 0.5);
}
.bar-item:last-of-type {
border-right: 0;
}
\ No newline at end of file
import { useTranslation } from 'react-i18next'
import cn from 'classnames'
import s from './index.module.css'
import type { ProviderHosted } from '@/models/common'
type AnthropicHostedProviderProps = {
provider: ProviderHosted
}
const AnthropicHostedProvider = ({
provider,
}: AnthropicHostedProviderProps) => {
const { t } = useTranslation()
const exhausted = provider.quota_used > provider.quota_limit
return (
<div className={`
border-[0.5px] border-gray-200 rounded-xl
${exhausted ? 'bg-[#FFFBFA]' : 'bg-gray-50'}
`}>
<div className='pt-4 px-4 pb-3'>
<div className='flex items-center mb-3'>
<div className={s.icon} />
<div className='grow text-sm font-medium text-gray-800'>
{t('common.provider.anthropicHosted.anthropicHosted')}
</div>
<div className={`
px-2 h-[22px] flex items-center rounded-md border
text-xs font-semibold
${exhausted ? 'border-[#D92D20] text-[#D92D20]' : 'border-primary-600 text-primary-600'}
`}>
{exhausted ? t('common.provider.anthropicHosted.exhausted') : t('common.provider.anthropicHosted.onTrial')}
</div>
</div>
<div className='text-[13px] text-gray-500'>{t('common.provider.anthropicHosted.desc')}</div>
</div>
<div className='flex items-center h-[42px] px-4 border-t-[0.5px] border-t-[rgba(0, 0, 0, 0.05)]'>
<div className='text-[13px] text-gray-700'>{t('common.provider.anthropicHosted.callTimes')}</div>
<div className='relative grow h-2 flex bg-gray-200 rounded-md mx-2 overflow-hidden'>
<div
className={cn(s.bar, exhausted && s['bar-error'], 'absolute top-0 left-0 right-0 bottom-0')}
style={{ width: `${(provider.quota_used / provider.quota_limit * 100).toFixed(2)}%` }}
/>
{Array(10).fill(0).map((i, k) => (
<div key={k} className={s['bar-item']} />
))}
</div>
<div className={`
text-[13px] font-medium ${exhausted ? 'text-[#D92D20]' : 'text-gray-700'}
`}>{provider.quota_used}/{provider.quota_limit}</div>
</div>
{
exhausted && (
<div className='
px-4 py-3 leading-[18px] flex items-center text-[13px] text-gray-700 font-medium
bg-[#FFFAEB] border-t border-t-[rgba(0, 0, 0, 0.05)] rounded-b-xl
'>
{t('common.provider.anthropicHosted.usedUp')}
</div>
)
}
</div>
)
}
export default AnthropicHostedProvider
import { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Link from 'next/link'
import { ArrowTopRightOnSquareIcon } from '@heroicons/react/24/outline'
import ProviderInput from '../provider-input'
import type { ValidatedStatusState } from '../provider-input/useValidateToken'
import useValidateToken, { ValidatedStatus } from '../provider-input/useValidateToken'
import {
ValidatedErrorIcon,
ValidatedErrorOnOpenaiTip,
ValidatedSuccessIcon,
ValidatingTip,
} from '../provider-input/Validate'
import type { Provider, ProviderAnthropicToken } from '@/models/common'
type AnthropicProviderProps = {
provider: Provider
onValidatedStatus: (status?: ValidatedStatusState) => void
onTokenChange: (token: ProviderAnthropicToken) => void
}
const AnthropicProvider = ({
provider,
onValidatedStatus,
onTokenChange,
}: AnthropicProviderProps) => {
const { t } = useTranslation()
const [token, setToken] = useState<ProviderAnthropicToken>((provider.token as ProviderAnthropicToken) || { anthropic_api_key: '' })
const [validating, validatedStatus, setValidatedStatus, validate] = useValidateToken(provider.provider_name)
const handleFocus = () => {
if (token.anthropic_api_key === (provider.token as ProviderAnthropicToken).anthropic_api_key) {
setToken({ anthropic_api_key: '' })
onTokenChange({ anthropic_api_key: '' })
setValidatedStatus({})
}
}
const handleChange = (v: string) => {
const apiKey = { anthropic_api_key: v }
setToken(apiKey)
onTokenChange(apiKey)
validate(apiKey, {
beforeValidating: () => {
if (!v) {
setValidatedStatus({})
return false
}
return true
},
})
}
useEffect(() => {
if (typeof onValidatedStatus === 'function')
onValidatedStatus(validatedStatus)
}, [validatedStatus])
const getValidatedIcon = () => {
if (validatedStatus?.status === ValidatedStatus.Error || validatedStatus.status === ValidatedStatus.Exceed)
return <ValidatedErrorIcon />
if (validatedStatus.status === ValidatedStatus.Success)
return <ValidatedSuccessIcon />
}
const getValidatedTip = () => {
if (validating)
return <ValidatingTip />
if (validatedStatus?.status === ValidatedStatus.Error)
return <ValidatedErrorOnOpenaiTip errorMessage={validatedStatus.message ?? ''} />
}
return (
<div className='px-4 pt-3 pb-4'>
<ProviderInput
value={token.anthropic_api_key}
name={t('common.provider.apiKey')}
placeholder={t('common.provider.enterYourKey')}
onChange={handleChange}
onFocus={handleFocus}
validatedIcon={getValidatedIcon()}
validatedTip={getValidatedTip()}
/>
<Link className="inline-flex items-center mt-3 text-xs font-normal cursor-pointer text-primary-600 w-fit" href="https://docs.anthropic.com/claude/reference/getting-started-with-the-api" target={'_blank'}>
{t('common.provider.anthropic.keyFrom')}
<ArrowTopRightOnSquareIcon className='w-3 h-3 ml-1 text-primary-600' aria-hidden="true" />
</Link>
</div>
)
}
export default AnthropicProvider
...@@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next' ...@@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next'
import Link from 'next/link' import Link from 'next/link'
import ProviderItem from './provider-item' import ProviderItem from './provider-item'
import OpenaiHostedProvider from './openai-hosted-provider' import OpenaiHostedProvider from './openai-hosted-provider'
import AnthropicHostedProvider from './anthropic-hosted-provider'
import type { ProviderHosted } from '@/models/common' import type { ProviderHosted } from '@/models/common'
import { fetchProviders } from '@/service/common' import { fetchProviders } from '@/service/common'
import { IS_CE_EDITION } from '@/config' import { IS_CE_EDITION } from '@/config'
...@@ -18,6 +19,10 @@ const providersMap: { [k: string]: any } = { ...@@ -18,6 +19,10 @@ const providersMap: { [k: string]: any } = {
icon: 'azure', icon: 'azure',
name: 'Azure OpenAI Service', name: 'Azure OpenAI Service',
}, },
'anthropic-custom': {
icon: 'anthropic',
name: 'Anthropic',
},
} }
// const providersList = [ // const providersList = [
...@@ -65,6 +70,8 @@ const ProviderPage = () => { ...@@ -65,6 +70,8 @@ const ProviderPage = () => {
} }
}) })
const providerHosted = data?.filter(provider => provider.provider_name === 'openai' && provider.provider_type === 'system')?.[0] const providerHosted = data?.filter(provider => provider.provider_name === 'openai' && provider.provider_type === 'system')?.[0]
const anthropicHosted = data?.filter(provider => provider.provider_name === 'anthropic' && provider.provider_type === 'system')?.[0]
const providedOpenaiProvider = data?.find(provider => provider.is_enabled && (provider.provider_name === 'openai' || provider.provider_name === 'azure_openai'))
return ( return (
<div className='pb-7'> <div className='pb-7'>
...@@ -78,6 +85,16 @@ const ProviderPage = () => { ...@@ -78,6 +85,16 @@ const ProviderPage = () => {
</> </>
) )
} }
{
anthropicHosted && !IS_CE_EDITION && (
<>
<div>
<AnthropicHostedProvider provider={anthropicHosted as ProviderHosted} />
</div>
<div className='my-5 w-full h-0 border-[0.5px] border-gray-100' />
</>
)
}
<div> <div>
{ {
providers?.map(providerItem => ( providers?.map(providerItem => (
...@@ -89,11 +106,12 @@ const ProviderPage = () => { ...@@ -89,11 +106,12 @@ const ProviderPage = () => {
activeId={activeProviderId} activeId={activeProviderId}
onActive={aid => setActiveProviderId(aid)} onActive={aid => setActiveProviderId(aid)}
onSave={() => mutate()} onSave={() => mutate()}
providedOpenaiProvider={providedOpenaiProvider}
/> />
)) ))
} }
</div> </div>
<div className='absolute bottom-0 w-full h-[42px] flex items-center bg-white text-xs text-gray-500'> <div className='fixed bottom-0 w-[472px] h-[42px] flex items-center bg-white text-xs text-gray-500'>
<LockClosedIcon className='w-3 h-3 mr-1' /> <LockClosedIcon className='w-3 h-3 mr-1' />
{t('common.provider.encrypted.front')} {t('common.provider.encrypted.front')}
<Link <Link
......
...@@ -5,14 +5,20 @@ import { useTranslation } from 'react-i18next' ...@@ -5,14 +5,20 @@ import { useTranslation } from 'react-i18next'
import Indicator from '../../../indicator' import Indicator from '../../../indicator'
import OpenaiProvider from '../openai-provider' import OpenaiProvider from '../openai-provider'
import AzureProvider from '../azure-provider' import AzureProvider from '../azure-provider'
import AnthropicProvider from '../anthropic-provider'
import type { ValidatedStatusState } from '../provider-input/useValidateToken' import type { ValidatedStatusState } from '../provider-input/useValidateToken'
import { ValidatedStatus } from '../provider-input/useValidateToken' import { ValidatedStatus } from '../provider-input/useValidateToken'
import s from './index.module.css' import s from './index.module.css'
import type { Provider, ProviderAzureToken } from '@/models/common' import type { Provider, ProviderAnthropicToken, ProviderAzureToken } from '@/models/common'
import { ProviderName } from '@/models/common' import { ProviderName } from '@/models/common'
import { updateProviderAIKey } from '@/service/common' import { updateProviderAIKey } from '@/service/common'
import { ToastContext } from '@/app/components/base/toast' import { ToastContext } from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
const providerNameMap: Record<string, string> = {
openai: 'OpenAI',
azure_openai: 'Azure OpenAI Service',
}
type IProviderItemProps = { type IProviderItemProps = {
icon: string icon: string
name: string name: string
...@@ -20,6 +26,7 @@ type IProviderItemProps = { ...@@ -20,6 +26,7 @@ type IProviderItemProps = {
activeId: string activeId: string
onActive: (v: string) => void onActive: (v: string) => void
onSave: () => void onSave: () => void
providedOpenaiProvider?: Provider
} }
const ProviderItem = ({ const ProviderItem = ({
activeId, activeId,
...@@ -28,15 +35,18 @@ const ProviderItem = ({ ...@@ -28,15 +35,18 @@ const ProviderItem = ({
provider, provider,
onActive, onActive,
onSave, onSave,
providedOpenaiProvider,
}: IProviderItemProps) => { }: IProviderItemProps) => {
const { t } = useTranslation() const { t } = useTranslation()
const [validatedStatus, setValidatedStatus] = useState<ValidatedStatusState>() const [validatedStatus, setValidatedStatus] = useState<ValidatedStatusState>()
const [loading, setLoading] = useState(false) const [loading, setLoading] = useState(false)
const { notify } = useContext(ToastContext) const { notify } = useContext(ToastContext)
const [token, setToken] = useState<ProviderAzureToken | string>( const [token, setToken] = useState<ProviderAzureToken | string | ProviderAnthropicToken>(
provider.provider_name === 'azure_openai' provider.provider_name === 'azure_openai'
? { openai_api_base: '', openai_api_key: '' } ? { openai_api_base: '', openai_api_key: '' }
: '', : provider.provider_name === 'anthropic'
? { anthropic_api_key: '' }
: '',
) )
const id = `${provider.provider_name}-${provider.provider_type}` const id = `${provider.provider_name}-${provider.provider_type}`
const isOpen = id === activeId const isOpen = id === activeId
...@@ -54,6 +64,8 @@ const ProviderItem = ({ ...@@ -54,6 +64,8 @@ const ProviderItem = ({
} }
if (provider.provider_name === ProviderName.OPENAI) if (provider.provider_name === ProviderName.OPENAI)
return provider.token return provider.token
if (provider.provider_name === ProviderName.ANTHROPIC)
return provider.token?.anthropic_api_key
} }
const handleUpdateToken = async () => { const handleUpdateToken = async () => {
if (loading) if (loading)
...@@ -81,7 +93,7 @@ const ProviderItem = ({ ...@@ -81,7 +93,7 @@ const ProviderItem = ({
<div className={cn(s[`icon-${icon}`], 'mr-3 w-6 h-6 rounded-md')} /> <div className={cn(s[`icon-${icon}`], 'mr-3 w-6 h-6 rounded-md')} />
<div className='grow text-sm font-medium text-gray-800'>{name}</div> <div className='grow text-sm font-medium text-gray-800'>{name}</div>
{ {
providerTokenHasSetted() && !comingSoon && !isOpen && ( providerTokenHasSetted() && !comingSoon && !isOpen && provider.provider_name !== ProviderName.ANTHROPIC && (
<div className='flex items-center mr-4'> <div className='flex items-center mr-4'>
{!isValid && <div className='text-xs text-[#D92D20]'>{t('common.provider.invalidApiKey')}</div>} {!isValid && <div className='text-xs text-[#D92D20]'>{t('common.provider.invalidApiKey')}</div>}
<Indicator color={!isValid ? 'red' : 'green'} className='ml-2' /> <Indicator color={!isValid ? 'red' : 'green'} className='ml-2' />
...@@ -89,7 +101,27 @@ const ProviderItem = ({ ...@@ -89,7 +101,27 @@ const ProviderItem = ({
) )
} }
{ {
!comingSoon && !isOpen && ( (providerTokenHasSetted() && !comingSoon && !isOpen && provider.provider_name === ProviderName.ANTHROPIC) && (
<div className='flex items-center mr-4'>
{
providedOpenaiProvider?.is_valid
? !isValid
? <div className='text-xs text-[#D92D20]'>{t('common.provider.invalidApiKey')}</div>
: null
: <div className='text-xs text-[#DC6803]'>{t('common.provider.anthropic.notEnabled')}</div>
}
<Indicator color={
providedOpenaiProvider?.is_valid
? isValid
? 'green'
: 'red'
: 'yellow'
} className='ml-2' />
</div>
)
}
{
!comingSoon && !isOpen && provider.provider_name !== ProviderName.ANTHROPIC && (
<div className=' <div className='
px-3 h-[28px] bg-white border border-gray-200 rounded-md cursor-pointer px-3 h-[28px] bg-white border border-gray-200 rounded-md cursor-pointer
text-xs font-medium text-gray-700 flex items-center text-xs font-medium text-gray-700 flex items-center
...@@ -98,6 +130,34 @@ const ProviderItem = ({ ...@@ -98,6 +130,34 @@ const ProviderItem = ({
</div> </div>
) )
} }
{
(!comingSoon && !isOpen && provider.provider_name === ProviderName.ANTHROPIC)
? providedOpenaiProvider?.is_enabled
? (
<div className='
px-3 h-[28px] bg-white border border-gray-200 rounded-md cursor-pointer
text-xs font-medium text-gray-700 flex items-center
' onClick={() => providedOpenaiProvider.is_valid && onActive(id)}>
{providerTokenHasSetted() ? t('common.provider.editKey') : t('common.provider.addKey')}
</div>
)
: (
<Tooltip
htmlContent={<div className='w-[320px]'>
{t('common.provider.anthropic.enableTip')}
</div>}
position='bottom'
selector='anthropic-provider-enable-top-tooltip'>
<div className='
px-3 h-[28px] bg-white border border-gray-200 rounded-md cursor-not-allowed
text-xs font-medium text-gray-700 flex items-center opacity-50
'>
{t('common.provider.addKey')}
</div>
</Tooltip>
)
: null
}
{ {
comingSoon && !isOpen && ( comingSoon && !isOpen && (
<div className=' <div className='
...@@ -147,6 +207,29 @@ const ProviderItem = ({ ...@@ -147,6 +207,29 @@ const ProviderItem = ({
/> />
) )
} }
{
provider.provider_name === ProviderName.ANTHROPIC && isOpen && (
<AnthropicProvider
provider={provider}
onValidatedStatus={v => setValidatedStatus(v)}
onTokenChange={v => setToken(v)}
/>
)
}
{
provider.provider_name === ProviderName.ANTHROPIC && !isOpen && providerTokenHasSetted() && providedOpenaiProvider?.is_valid && (
<div className='px-4 py-3 text-[13px] font-medium text-gray-700'>
{t('common.provider.anthropic.using')} {providerNameMap[providedOpenaiProvider.provider_name as string]}
</div>
)
}
{
provider.provider_name === ProviderName.ANTHROPIC && !isOpen && providerTokenHasSetted() && !providedOpenaiProvider?.is_valid && (
<div className='px-4 py-3 bg-[#FFFAEB] text-[13px] font-medium text-gray-700'>
{t('common.provider.anthropic.enableTip')}
</div>
)
}
</div> </div>
) )
} }
......
...@@ -620,7 +620,7 @@ const Main: FC<IMainProps> = ({ ...@@ -620,7 +620,7 @@ const Main: FC<IMainProps> = ({
{ {
hasSetInputs && ( hasSetInputs && (
<div className={cn(doShowSuggestion ? 'pb-[140px]' : (isResponsing ? 'pb-[113px]' : 'pb-[66px]'), 'relative grow h-[200px] pc:w-[794px] max-w-full mobile:w-full mx-auto mb-3.5 overflow-hidden')}> <div className={cn(doShowSuggestion ? 'pb-[140px]' : (isResponsing ? 'pb-[113px]' : 'pb-[76px]'), 'relative grow h-[200px] pc:w-[794px] max-w-full mobile:w-full mx-auto mb-3.5 overflow-hidden')}>
<div className='h-full overflow-y-auto' ref={chatListDomRef}> <div className='h-full overflow-y-auto' ref={chatListDomRef}>
<Chat <Chat
chatList={chatList} chatList={chatList}
......
This diff is collapsed.
This diff is collapsed.
...@@ -552,6 +552,10 @@ const Main: FC<IMainProps> = ({ ...@@ -552,6 +552,10 @@ const Main: FC<IMainProps> = ({
) )
} }
const difyIcon = (
<div className={s.difyHeader}></div>
)
if (appUnavailable) if (appUnavailable)
return <AppUnavailable isUnknwonReason={isUnknwonReason} /> return <AppUnavailable isUnknwonReason={isUnknwonReason} />
...@@ -562,7 +566,8 @@ const Main: FC<IMainProps> = ({ ...@@ -562,7 +566,8 @@ const Main: FC<IMainProps> = ({
<div> <div>
<Header <Header
title={siteInfo.title} title={siteInfo.title}
icon={siteInfo.icon || ''} icon=''
customerIcon={difyIcon}
icon_background={siteInfo.icon_background} icon_background={siteInfo.icon_background}
isEmbedScene={true} isEmbedScene={true}
isMobile={isMobile} isMobile={isMobile}
...@@ -604,7 +609,7 @@ const Main: FC<IMainProps> = ({ ...@@ -604,7 +609,7 @@ const Main: FC<IMainProps> = ({
{ {
hasSetInputs && ( hasSetInputs && (
<div className={cn(doShowSuggestion ? 'pb-[140px]' : (isResponsing ? 'pb-[113px]' : 'pb-[66px]'), 'relative grow h-[200px] pc:w-[794px] max-w-full mobile:w-full mx-auto mb-3.5 overflow-hidden')}> <div className={cn(doShowSuggestion ? 'pb-[140px]' : (isResponsing ? 'pb-[113px]' : 'pb-[76px]'), 'relative grow h-[200px] pc:w-[794px] max-w-full mobile:w-full mx-auto mb-3.5 overflow-hidden')}>
<div className='h-full overflow-y-auto' ref={chatListDomRef}> <div className='h-full overflow-y-auto' ref={chatListDomRef}>
<Chat <Chat
chatList={chatList} chatList={chatList}
...@@ -624,6 +629,7 @@ const Main: FC<IMainProps> = ({ ...@@ -624,6 +629,7 @@ const Main: FC<IMainProps> = ({
suggestionList={suggestQuestions} suggestionList={suggestQuestions}
displayScene='web' displayScene='web'
isShowSpeechToText={speechToTextConfig?.enabled} isShowSpeechToText={speechToTextConfig?.enabled}
answerIconClassName={s.difyIcon}
/> />
</div> </div>
</div>) </div>)
......
.installedApp { .installedApp {
height: calc(100vh - 74px); height: calc(100vh - 74px);
}
.difyIcon {
background-image: url(./icons/dify.svg);
}
.difyHeader {
width: 24px;
height: 24px;
background: url(./icons/dify-header.svg) center center no-repeat;
background-size: contain;
} }
\ No newline at end of file
...@@ -307,7 +307,7 @@ const Welcome: FC<IWelcomeProps> = ({ ...@@ -307,7 +307,7 @@ const Welcome: FC<IWelcomeProps> = ({
} }
return ( return (
<div className='relative mobile:min-h-[48px] tablet:min-h-[64px]'> <div className='relative tablet:min-h-[64px]'>
{/* {hasSetInputs && renderHeader()} */} {/* {hasSetInputs && renderHeader()} */}
<div className='mx-auto pc:w-[794px] max-w-full mobile:w-full px-3.5'> <div className='mx-auto pc:w-[794px] max-w-full mobile:w-full px-3.5'>
{/* Has't set inputs */} {/* Has't set inputs */}
......
...@@ -3,6 +3,7 @@ import React from 'react' ...@@ -3,6 +3,7 @@ import React from 'react'
import AppIcon from '@/app/components/base/app-icon' import AppIcon from '@/app/components/base/app-icon'
export type IHeaderProps = { export type IHeaderProps = {
title: string title: string
customerIcon?: React.ReactNode
icon: string icon: string
icon_background: string icon_background: string
isMobile?: boolean isMobile?: boolean
...@@ -11,6 +12,7 @@ export type IHeaderProps = { ...@@ -11,6 +12,7 @@ export type IHeaderProps = {
const Header: FC<IHeaderProps> = ({ const Header: FC<IHeaderProps> = ({
title, title,
isMobile, isMobile,
customerIcon,
icon, icon,
icon_background, icon_background,
isEmbedScene = false, isEmbedScene = false,
...@@ -25,7 +27,7 @@ const Header: FC<IHeaderProps> = ({ ...@@ -25,7 +27,7 @@ const Header: FC<IHeaderProps> = ({
> >
<div></div> <div></div>
<div className="flex items-center space-x-2"> <div className="flex items-center space-x-2">
<AppIcon size="small" icon={icon} background={icon_background} /> {customerIcon || <AppIcon size="small" icon={icon} background={icon_background} />}
<div <div
className={`text-sm text-gray-800 font-bold ${ className={`text-sm text-gray-800 font-bold ${
isEmbedScene ? 'text-white' : '' isEmbedScene ? 'text-white' : ''
......
import I18nServer from './components/i18n-server' import I18nServer from './components/i18n-server'
import BrowerInitor from './components/browser-initor'
import SentryInitor from './components/sentry-initor' import SentryInitor from './components/sentry-initor'
import { getLocaleOnServer } from '@/i18n/server' import { getLocaleOnServer } from '@/i18n/server'
...@@ -25,10 +26,12 @@ const LocaleLayout = ({ ...@@ -25,10 +26,12 @@ const LocaleLayout = ({
data-public-edition={process.env.NEXT_PUBLIC_EDITION} data-public-edition={process.env.NEXT_PUBLIC_EDITION}
data-public-sentry-dsn={process.env.NEXT_PUBLIC_SENTRY_DSN} data-public-sentry-dsn={process.env.NEXT_PUBLIC_SENTRY_DSN}
> >
<SentryInitor> <BrowerInitor>
{/* @ts-expect-error Async Server Component */} <SentryInitor>
<I18nServer locale={locale}>{children}</I18nServer> {/* @ts-expect-error Async Server Component */}
</SentryInitor> <I18nServer locale={locale}>{children}</I18nServer>
</SentryInitor>
</BrowerInitor>
</body> </body>
</html> </html>
) )
......
...@@ -54,7 +54,7 @@ const translation = { ...@@ -54,7 +54,7 @@ const translation = {
maxTokenTip: maxTokenTip:
'Max tokens depending on the model. Prompt and completion share this limit. One token is roughly 1 English character.', 'Max tokens depending on the model. Prompt and completion share this limit. One token is roughly 1 English character.',
maxTokenSettingTip: 'Your max token setting is high, potentially limiting space for prompts, queries, and data. Consider setting it below 2/3.', maxTokenSettingTip: 'Your max token setting is high, potentially limiting space for prompts, queries, and data. Consider setting it below 2/3.',
setToCurrentModelMaxTokenTip: 'Max token is updated to the maximum token of the current model 4,000.', setToCurrentModelMaxTokenTip: 'Max token is updated to the maximum token of the current model {{maxToken}}.',
}, },
tone: { tone: {
Creative: 'Creative', Creative: 'Creative',
...@@ -180,6 +180,22 @@ const translation = { ...@@ -180,6 +180,22 @@ const translation = {
useYourModel: 'Currently using own Model Provider.', useYourModel: 'Currently using own Model Provider.',
close: 'Close', close: 'Close',
}, },
anthropicHosted: {
anthropicHosted: 'Anthropic Claude',
onTrial: 'ON TRIAL',
exhausted: 'QUOTA EXHAUSTED',
desc: 'Powerful model, which excels at a wide range of tasks from sophisticated dialogue and creative content generation to detailed instruction.',
callTimes: 'Call times',
usedUp: 'Trial quota used up. Add own Model Provider.',
useYourModel: 'Currently using own Model Provider.',
close: 'Close',
},
anthropic: {
using: 'The embedding capability is using',
enableTip: 'To enable the Anthropic model, you need to bind to OpenAI or Azure OpenAI Service first.',
notEnabled: 'Not enabled',
keyFrom: 'Get your API key from Anthropic',
},
encrypted: { encrypted: {
front: 'Your API KEY will be encrypted and stored using', front: 'Your API KEY will be encrypted and stored using',
back: ' technology.', back: ' technology.',
......
...@@ -54,7 +54,7 @@ const translation = { ...@@ -54,7 +54,7 @@ const translation = {
maxTokenTip: maxTokenTip:
'生成的最大令牌数取决于模型。提示和完成共享令牌数限制。一个令牌约等于 1 个英文或 半个中文字符。', '生成的最大令牌数取决于模型。提示和完成共享令牌数限制。一个令牌约等于 1 个英文或 半个中文字符。',
maxTokenSettingTip: '您设置的最大 tokens 数较大,可能会导致 prompt、用户问题、数据集内容没有 token 空间进行处理,建议设置到 2/3 以下。', maxTokenSettingTip: '您设置的最大 tokens 数较大,可能会导致 prompt、用户问题、数据集内容没有 token 空间进行处理,建议设置到 2/3 以下。',
setToCurrentModelMaxTokenTip: '最大令牌数更新为当前模型最大的令牌数 4,000。', setToCurrentModelMaxTokenTip: '最大令牌数更新为当前模型最大的令牌数 {{maxToken}}。',
}, },
tone: { tone: {
Creative: '创意', Creative: '创意',
...@@ -180,6 +180,22 @@ const translation = { ...@@ -180,6 +180,22 @@ const translation = {
useYourModel: '当前正在使用你自己的模型供应商。', useYourModel: '当前正在使用你自己的模型供应商。',
close: '关闭', close: '关闭',
}, },
anthropicHosted: {
anthropicHosted: 'Anthropic Claude',
onTrial: '体验',
exhausted: '超出限额',
desc: '功能强大的模型,擅长执行从复杂对话和创意内容生成到详细指导的各种任务。',
callTimes: '调用次数',
usedUp: '试用额度已用完,请在下方添加自己的模型供应商',
useYourModel: '当前正在使用你自己的模型供应商。',
close: '关闭',
},
anthropic: {
using: '嵌入能力正在使用',
enableTip: '要启用 Anthropic 模型,您需要先绑定 OpenAI 或 Azure OpenAI 服务。',
notEnabled: '未启用',
keyFrom: '从 Anthropic 获取您的 API 密钥',
},
encrypted: { encrypted: {
front: '密钥将使用 ', front: '密钥将使用 ',
back: ' 技术进行加密和存储。', back: ' 技术进行加密和存储。',
......
...@@ -59,14 +59,19 @@ export type Member = Pick<UserProfileResponse, 'id' | 'name' | 'email' | 'last_l ...@@ -59,14 +59,19 @@ export type Member = Pick<UserProfileResponse, 'id' | 'name' | 'email' | 'last_l
export enum ProviderName { export enum ProviderName {
OPENAI = 'openai', OPENAI = 'openai',
AZURE_OPENAI = 'azure_openai', AZURE_OPENAI = 'azure_openai',
ANTHROPIC = 'anthropic',
} }
export type ProviderAzureToken = { export type ProviderAzureToken = {
openai_api_base?: string openai_api_base?: string
openai_api_key?: string openai_api_key?: string
} }
export type ProviderAnthropicToken = {
anthropic_api_key?: string
}
export type ProviderTokenType = { export type ProviderTokenType = {
[ProviderName.OPENAI]: string [ProviderName.OPENAI]: string
[ProviderName.AZURE_OPENAI]: ProviderAzureToken [ProviderName.AZURE_OPENAI]: ProviderAzureToken
[ProviderName.ANTHROPIC]: ProviderAnthropicToken
} }
export type Provider = { export type Provider = {
[Name in ProviderName]: { [Name in ProviderName]: {
......
{ {
"name": "dify-web", "name": "dify-web",
"version": "0.3.8", "version": "0.3.9",
"private": true, "private": true,
"scripts": { "scripts": {
"dev": "next dev", "dev": "next dev",
......
...@@ -54,7 +54,7 @@ async function embedChatbot () { ...@@ -54,7 +54,7 @@ async function embedChatbot () {
iframe.title = "dify chatbot bubble window" iframe.title = "dify chatbot bubble window"
iframe.id = 'dify-chatbot-bubble-window' iframe.id = 'dify-chatbot-bubble-window'
iframe.src = `https://${isDev ? 'dev.' : ''}udify.app/chatbot/${difyChatbotConfig.token}`; iframe.src = `https://${isDev ? 'dev.' : ''}udify.app/chatbot/${difyChatbotConfig.token}`;
iframe.style.cssText = 'border: none; position: fixed; flex-direction: column; justify-content: space-between; box-shadow: rgba(150, 150, 150, 0.2) 0px 10px 30px 0px, rgba(150, 150, 150, 0.2) 0px 0px 0px 1px; bottom: 5rem; right: 1rem; width: 24rem; height: 40rem; border-radius: 0.75rem; display: flex; z-index: 2147483647; overflow: hidden; left: unset;' iframe.style.cssText = 'border: none; position: fixed; flex-direction: column; justify-content: space-between; box-shadow: rgba(150, 150, 150, 0.2) 0px 10px 30px 0px, rgba(150, 150, 150, 0.2) 0px 0px 0px 1px; bottom: 5rem; right: 1rem; width: 24rem; height: 40rem; border-radius: 0.75rem; display: flex; z-index: 2147483647; overflow: hidden; left: unset; background-color: #F3F4F6;'
document.body.appendChild(iframe); document.body.appendChild(iframe);
} }
......
...@@ -27,4 +27,4 @@ async function embedChatbot(){const t=window.difyChatbotConfig;if(t&&t.token){co ...@@ -27,4 +27,4 @@ async function embedChatbot(){const t=window.difyChatbotConfig;if(t&&t.token){co
stroke-linecap="round" stroke-linecap="round"
stroke-linejoin="round" stroke-linejoin="round"
/> />
</svg>`;if(!document.getElementById("dify-chatbot-bubble-button")){var e=document.createElement("div");e.id="dify-chatbot-bubble-button",e.style.cssText="position: fixed; bottom: 1rem; right: 1rem; width: 50px; height: 50px; border-radius: 25px; background-color: #155EEF; box-shadow: rgba(0, 0, 0, 0.2) 0px 4px 8px 0px; cursor: pointer; z-index: 2147483647; transition: all 0.2s ease-in-out 0s; left: unset; transform: scale(1); :hover {transform: scale(1.1);}";const d=document.createElement("div");d.style.cssText="display: flex; align-items: center; justify-content: center; width: 100%; height: 100%; z-index: 2147483647;",d.innerHTML=n,e.appendChild(d),document.body.appendChild(e),e.addEventListener("click",function(){var e=document.getElementById("dify-chatbot-bubble-window");e?"none"===e.style.display?(e.style.display="block",d.innerHTML=i):(e.style.display="none",d.innerHTML=n):((e=document.createElement("iframe")).allow="fullscreen;microphone",e.title="dify chatbot bubble window",e.id="dify-chatbot-bubble-window",e.src=`https://${o?"dev.":""}udify.app/chatbot/`+t.token,e.style.cssText="border: none; position: fixed; flex-direction: column; justify-content: space-between; box-shadow: rgba(150, 150, 150, 0.2) 0px 10px 30px 0px, rgba(150, 150, 150, 0.2) 0px 0px 0px 1px; bottom: 5rem; right: 1rem; width: 24rem; height: 40rem; border-radius: 0.75rem; display: flex; z-index: 2147483647; overflow: hidden; left: unset;",document.body.appendChild(e),d.innerHTML=i)})}}else console.error("difyChatbotConfig is empty or token is not provided")}document.body.onload=embedChatbot; </svg>`;if(!document.getElementById("dify-chatbot-bubble-button")){var e=document.createElement("div");e.id="dify-chatbot-bubble-button",e.style.cssText="position: fixed; bottom: 1rem; right: 1rem; width: 50px; height: 50px; border-radius: 25px; background-color: #155EEF; box-shadow: rgba(0, 0, 0, 0.2) 0px 4px 8px 0px; cursor: pointer; z-index: 2147483647; transition: all 0.2s ease-in-out 0s; left: unset; transform: scale(1); :hover {transform: scale(1.1);}";const d=document.createElement("div");d.style.cssText="display: flex; align-items: center; justify-content: center; width: 100%; height: 100%; z-index: 2147483647;",d.innerHTML=n,e.appendChild(d),document.body.appendChild(e),e.addEventListener("click",function(){var e=document.getElementById("dify-chatbot-bubble-window");e?"none"===e.style.display?(e.style.display="block",d.innerHTML=i):(e.style.display="none",d.innerHTML=n):((e=document.createElement("iframe")).allow="fullscreen;microphone",e.title="dify chatbot bubble window",e.id="dify-chatbot-bubble-window",e.src=`https://${o?"dev.":""}udify.app/chatbot/`+t.token,e.style.cssText="border: none; position: fixed; flex-direction: column; justify-content: space-between; box-shadow: rgba(150, 150, 150, 0.2) 0px 10px 30px 0px, rgba(150, 150, 150, 0.2) 0px 0px 0px 1px; bottom: 5rem; right: 1rem; width: 24rem; height: 40rem; border-radius: 0.75rem; display: flex; z-index: 2147483647; overflow: hidden; left: unset; background-color: #F3F4F6;",document.body.appendChild(e),d.innerHTML=i)})}}else console.error("difyChatbotConfig is empty or token is not provided")}document.body.onload=embedChatbot;
\ No newline at end of file \ No newline at end of file
...@@ -3,7 +3,7 @@ import { del, get, patch, post, put } from './base' ...@@ -3,7 +3,7 @@ import { del, get, patch, post, put } from './base'
import type { import type {
AccountIntegrate, CommonResponse, DataSourceNotion, AccountIntegrate, CommonResponse, DataSourceNotion,
IWorkspace, LangGeniusVersionResponse, Member, IWorkspace, LangGeniusVersionResponse, Member,
OauthResponse, Provider, ProviderAzureToken, TenantInfoResponse, OauthResponse, Provider, ProviderAnthropicToken, ProviderAzureToken, TenantInfoResponse,
UserProfileOriginResponse, UserProfileOriginResponse,
} from '@/models/common' } from '@/models/common'
import type { import type {
...@@ -58,7 +58,7 @@ export const fetchProviders: Fetcher<Provider[] | null, { url: string; params: R ...@@ -58,7 +58,7 @@ export const fetchProviders: Fetcher<Provider[] | null, { url: string; params: R
export const validateProviderKey: Fetcher<ValidateOpenAIKeyResponse, { url: string; body: { token: string } }> = ({ url, body }) => { export const validateProviderKey: Fetcher<ValidateOpenAIKeyResponse, { url: string; body: { token: string } }> = ({ url, body }) => {
return post(url, { body }) as Promise<ValidateOpenAIKeyResponse> return post(url, { body }) as Promise<ValidateOpenAIKeyResponse>
} }
export const updateProviderAIKey: Fetcher<UpdateOpenAIKeyResponse, { url: string; body: { token: string | ProviderAzureToken } }> = ({ url, body }) => { export const updateProviderAIKey: Fetcher<UpdateOpenAIKeyResponse, { url: string; body: { token: string | ProviderAzureToken | ProviderAnthropicToken } }> = ({ url, body }) => {
return post(url, { body }) as Promise<UpdateOpenAIKeyResponse> return post(url, { body }) as Promise<UpdateOpenAIKeyResponse>
} }
......
export enum ProviderType {
openai = 'openai',
anthropic = 'anthropic',
}
export enum AppType { export enum AppType {
'chat' = 'chat', 'chat' = 'chat',
'completion' = 'completion', 'completion' = 'completion',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment