Unverified Commit e8210ef7 authored by Yeuoly's avatar Yeuoly

Merge branch 'main' into feat/agent-image

parents c5d2181b 3b357f51
...@@ -13,6 +13,16 @@ class NotSetupError(BaseHTTPException): ...@@ -13,6 +13,16 @@ class NotSetupError(BaseHTTPException):
"Please proceed with the initialization and installation process first." "Please proceed with the initialization and installation process first."
code = 401 code = 401
class NotInitValidateError(BaseHTTPException):
error_code = 'not_init_validated'
description = "Init validation has not been completed yet. " \
"Please proceed with the init validation process first."
code = 401
class InitValidateFailedError(BaseHTTPException):
error_code = 'init_validate_failed'
description = "Init validation failed. Please check the password and try again."
code = 401
class AccountNotLinkTenantError(BaseHTTPException): class AccountNotLinkTenantError(BaseHTTPException):
error_code = 'account_not_link_tenant' error_code = 'account_not_link_tenant'
......
import os
from flask import current_app, session
from flask_restful import Resource, reqparse
from libs.helper import str_len
from models.model import DifySetup
from services.account_service import TenantService
from . import api
from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted
class InitValidateAPI(Resource):
def get(self):
init_status = get_init_validate_status()
if init_status:
return { 'status': 'finished' }
return {'status': 'not_started' }
@only_edition_self_hosted
def post(self):
# is tenant created
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
parser = reqparse.RequestParser()
parser.add_argument('password', type=str_len(30),
required=True, location='json')
input_password = parser.parse_args()['password']
if input_password != os.environ.get('INIT_PASSWORD'):
session['is_init_validated'] = False
raise InitValidateFailedError()
session['is_init_validated'] = True
return {'result': 'success'}, 201
def get_init_validate_status():
if current_app.config['EDITION'] == 'SELF_HOSTED':
if os.environ.get('INIT_PASSWORD'):
return session.get('is_init_validated') or DifySetup.query.first()
return True
api.add_resource(InitValidateAPI, '/init')
...@@ -10,7 +10,8 @@ from models.model import DifySetup ...@@ -10,7 +10,8 @@ from models.model import DifySetup
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from . import api from . import api
from .error import AlreadySetupError, NotSetupError from .error import AlreadySetupError, NotSetupError, NotInitValidateError
from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted from .wraps import only_edition_self_hosted
...@@ -24,7 +25,7 @@ class SetupApi(Resource): ...@@ -24,7 +25,7 @@ class SetupApi(Resource):
'step': 'finished', 'step': 'finished',
'setup_at': setup_status.setup_at.isoformat() 'setup_at': setup_status.setup_at.isoformat()
} }
return {'step': 'not_start'} return {'step': 'not_started'}
return {'step': 'finished'} return {'step': 'finished'}
@only_edition_self_hosted @only_edition_self_hosted
...@@ -37,6 +38,9 @@ class SetupApi(Resource): ...@@ -37,6 +38,9 @@ class SetupApi(Resource):
tenant_count = TenantService.get_tenant_count() tenant_count = TenantService.get_tenant_count()
if tenant_count > 0: if tenant_count > 0:
raise AlreadySetupError() raise AlreadySetupError()
if not get_init_validate_status():
raise NotInitValidateError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('email', type=email, parser.add_argument('email', type=email,
...@@ -71,7 +75,10 @@ def setup_required(view): ...@@ -71,7 +75,10 @@ def setup_required(view):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
# check setup # check setup
if not get_setup_status(): if not get_init_validate_status():
raise NotInitValidateError()
elif not get_setup_status():
raise NotSetupError() raise NotSetupError()
return view(*args, **kwargs) return view(*args, **kwargs)
......
...@@ -199,7 +199,7 @@ class AssistantApplicationRunner(AppRunner): ...@@ -199,7 +199,7 @@ class AssistantApplicationRunner(AppRunner):
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features): if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
# start agent runner # start agent runner
......
...@@ -97,7 +97,6 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): ...@@ -97,7 +97,6 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
tool_input='', tool_input='',
messages_ids=message_file_ids messages_ids=message_file_ids
) )
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
# recale llm max tokens # recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages) self.recale_llm_max_tokens(self.model_config, prompt_messages)
...@@ -124,7 +123,11 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): ...@@ -124,7 +123,11 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
current_llm_usage = None current_llm_usage = None
if self.stream_tool_call: if self.stream_tool_call:
is_first_chunk = True
for chunk in chunks: for chunk in chunks:
if is_first_chunk:
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
is_first_chunk = False
# check if there is any tool call # check if there is any tool call
if self.check_tool_calls(chunk): if self.check_tool_calls(chunk):
function_call_state = True function_call_state = True
...@@ -183,6 +186,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): ...@@ -183,6 +186,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
if not result.message.content: if not result.message.content:
result.message.content = '' result.message.content = ''
self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
yield LLMResultChunk( yield LLMResultChunk(
model=model_instance.model, model=model_instance.model,
prompt_messages=result.prompt_messages, prompt_messages=result.prompt_messages,
......
...@@ -168,7 +168,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): ...@@ -168,7 +168,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
return result return result
def _handle_generate_stream_response(self, model: str, credentials: dict, responses: list[Generator], def _handle_generate_stream_response(self, model: str, credentials: dict, responses: Generator,
prompt_messages: list[PromptMessage]) -> Generator: prompt_messages: list[PromptMessage]) -> Generator:
""" """
Handle llm stream response Handle llm stream response
...@@ -182,7 +182,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): ...@@ -182,7 +182,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
for index, response in enumerate(responses): for index, response in enumerate(responses):
resp_finish_reason = response.output.finish_reason resp_finish_reason = response.output.finish_reason
resp_content = response.output.text resp_content = response.output.text
useage = response.usage usage = response.usage
if resp_finish_reason is None and (resp_content is None or resp_content == ''): if resp_finish_reason is None and (resp_content is None or resp_content == ''):
continue continue
...@@ -194,7 +194,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): ...@@ -194,7 +194,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
if resp_finish_reason is not None: if resp_finish_reason is not None:
# transform usage # transform usage
usage = self._calc_response_usage(model, credentials, useage.input_tokens, useage.output_tokens) usage = self._calc_response_usage(model, credentials, usage.input_tokens, usage.output_tokens)
yield LLMResultChunk( yield LLMResultChunk(
model=model, model=model,
......
...@@ -70,6 +70,10 @@ class StableDiffusionTool(BuiltinTool): ...@@ -70,6 +70,10 @@ class StableDiffusionTool(BuiltinTool):
base_url = self.runtime.credentials.get('base_url', None) base_url = self.runtime.credentials.get('base_url', None)
if not base_url: if not base_url:
return self.create_text_message('Please input base_url') return self.create_text_message('Please input base_url')
if 'model' in tool_parameters:
self.runtime.credentials['model'] = tool_parameters['model']
model = self.runtime.credentials.get('model', None) model = self.runtime.credentials.get('model', None)
if not model: if not model:
return self.create_text_message('Please input model') return self.create_text_message('Please input model')
......
...@@ -25,6 +25,18 @@ parameters: ...@@ -25,6 +25,18 @@ parameters:
pt_BR: Image prompt, you can check the official documentation of Stable Diffusion pt_BR: Image prompt, you can check the official documentation of Stable Diffusion
llm_description: Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English. llm_description: Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.
form: llm form: llm
- name: model
type: string
required: false
label:
en_US: Model Name
zh_Hans: 模型名称
pt_BR: Model Name
human_description:
en_US: Model Name
zh_Hans: 模型名称
pt_BR: Model Name
form: form
- name: lora - name: lora
type: string type: string
required: false required: false
......
...@@ -231,25 +231,26 @@ class BuiltinToolProviderController(ToolProviderController): ...@@ -231,25 +231,26 @@ class BuiltinToolProviderController(ToolProviderController):
if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \ if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT: credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
if not isinstance(credentials[credential_name], str): if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string') raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be string')
elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT: elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
if not isinstance(credentials[credential_name], str): if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string') raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be string')
options = credential_schema.options options = credential_schema.options
if not isinstance(options, list): if not isinstance(options, list):
raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list') raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} options should be list')
if credentials[credential_name] not in [x.value for x in options]: if credentials[credential_name] not in [x.value for x in options]:
raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}') raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} should be one of {options}')
credentials_need_to_validate.pop(credential_name) if credentials[credential_name]:
credentials_need_to_validate.pop(credential_name)
for credential_name in credentials_need_to_validate: for credential_name in credentials_need_to_validate:
credential_schema = credentials_need_to_validate[credential_name] credential_schema = credentials_need_to_validate[credential_name]
if credential_schema.required: if credential_schema.required:
raise ToolProviderCredentialValidationError(f'credential {credential_name} is required') raise ToolProviderCredentialValidationError(f'credential {credential_schema.label.en_US} is required')
# the credential is not set currently, set the default value if needed # the credential is not set currently, set the default value if needed
if credential_schema.default is not None: if credential_schema.default is not None:
......
...@@ -66,6 +66,7 @@ class AccountService: ...@@ -66,6 +66,7 @@ class AccountService:
account.current_tenant_id = tenant_account_join.tenant_id account.current_tenant_id = tenant_account_join.tenant_id
else: else:
_create_tenant_for_account(account) _create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
else: else:
account.current_tenant_id = workspace_id account.current_tenant_id = workspace_id
else: else:
...@@ -75,6 +76,7 @@ class AccountService: ...@@ -75,6 +76,7 @@ class AccountService:
account.current_tenant_id = tenant_account_join.tenant_id account.current_tenant_id = tenant_account_join.tenant_id
else: else:
_create_tenant_for_account(account) _create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
current_time = datetime.utcnow() current_time = datetime.utcnow()
...@@ -286,6 +288,7 @@ class TenantService: ...@@ -286,6 +288,7 @@ class TenantService:
# Set the current tenant for the account # Set the current tenant for the account
account.current_tenant_id = tenant_account_join.tenant_id account.current_tenant_id = tenant_account_join.tenant_id
session['workspace_id'] = account.current_tenant.id
@staticmethod @staticmethod
def get_tenant_members(tenant: Tenant) -> List[Account]: def get_tenant_members(tenant: Tenant) -> List[Account]:
......
...@@ -15,6 +15,9 @@ services: ...@@ -15,6 +15,9 @@ services:
# different from api or web app domain. # different from api or web app domain.
# example: http://cloud.dify.ai # example: http://cloud.dify.ai
CONSOLE_WEB_URL: '' CONSOLE_WEB_URL: ''
# Password for admin user initialization.
# If left unset, admin user will not be prompted for a password when creating the initial admin account.
INIT_PASSWORD: ''
# The base URL of console application api server, refers to the Console base URL of WEB service if console domain is # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
# different from api or web app domain. # different from api or web app domain.
# example: http://cloud.dify.ai # example: http://cloud.dify.ai
......
'use client'
import { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useRouter } from 'next/navigation'
import Toast from '../components/base/toast'
import Loading from '../components/base/loading'
import Button from '@/app/components/base/button'
import { fetchInitValidateStatus, initValidate } from '@/service/common'
import type { InitValidateStatusResponse } from '@/models/common'
const InitPasswordPopup = () => {
const [password, setPassword] = useState('')
const [loading, setLoading] = useState(true)
const [validated, setValidated] = useState(false)
const router = useRouter()
const { t } = useTranslation()
const handleValidation = async () => {
setLoading(true)
try {
const response = await initValidate({ body: { password } })
if (response.result === 'success') {
setValidated(true)
router.push('/install') // or render setup form
}
else {
throw new Error('Validation failed')
}
}
catch (e: any) {
Toast.notify({
type: 'error',
message: e.message,
duration: 5000,
})
setLoading(false)
}
}
useEffect(() => {
fetchInitValidateStatus().then((res: InitValidateStatusResponse) => {
if (res.status === 'finished')
window.location.href = '/install'
else
setLoading(false)
})
}, [])
return (
loading
? <Loading />
: <div>
{!validated && (
<div className="block mx-12 min-w-28">
<div className="mb-4">
<label htmlFor="password" className="block text-sm font-medium text-gray-700">
{t('login.adminInitPassword')}
</label>
<div className="mt-1 relative rounded-md shadow-sm">
<input
id="password"
type="password"
value={password}
onChange={e => setPassword(e.target.value)}
className="appearance-none block w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm placeholder-gray-400 focus:outline-none focus:ring-indigo-500 focus:border-indigo-500 sm:text-sm"
/>
</div>
</div>
<div className="flex flex-row flex-wrap justify-stretch p-0">
<Button type="primary" onClick={handleValidation} className="basis-full min-w-28">
{t('login.validate')}
</Button>
</div>
</div>
)}
</div>
)
}
export default InitPasswordPopup
import React from 'react'
import classNames from 'classnames'
import style from '../signin/page.module.css'
import InitPasswordPopup from './InitPasswordPopup'
const Install = () => {
return (
<div className={classNames(
style.background,
'flex w-full min-h-screen',
'p-4 lg:p-8',
'gap-x-20',
'justify-center lg:justify-start',
)}>
<div className="block m-auto w-96">
<InitPasswordPopup />
</div>
</div>
)
}
export default Install
...@@ -9,8 +9,8 @@ import Loading from '../components/base/loading' ...@@ -9,8 +9,8 @@ import Loading from '../components/base/loading'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
// import I18n from '@/context/i18n' // import I18n from '@/context/i18n'
import { fetchSetupStatus, setup } from '@/service/common' import { fetchInitValidateStatus, fetchSetupStatus, setup } from '@/service/common'
import type { SetupStatusResponse } from '@/models/common' import type { InitValidateStatusResponse, SetupStatusResponse } from '@/models/common'
const validEmailReg = /^[\w\.-]+@([\w-]+\.)+[\w-]{2,}$/ const validEmailReg = /^[\w\.-]+@([\w-]+\.)+[\w-]{2,}$/
const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/
...@@ -70,10 +70,16 @@ const InstallForm = () => { ...@@ -70,10 +70,16 @@ const InstallForm = () => {
useEffect(() => { useEffect(() => {
fetchSetupStatus().then((res: SetupStatusResponse) => { fetchSetupStatus().then((res: SetupStatusResponse) => {
if (res.step === 'finished') if (res.step === 'finished') {
window.location.href = '/signin' window.location.href = '/signin'
else }
setLoading(false) else {
fetchInitValidateStatus().then((res: InitValidateStatusResponse) => {
if (res.status === 'not_started')
window.location.href = '/init'
})
}
setLoading(false)
}) })
}, []) }, [])
......
...@@ -9,7 +9,7 @@ const translation = { ...@@ -9,7 +9,7 @@ const translation = {
namePlaceholder: 'Your username', namePlaceholder: 'Your username',
forget: 'Forgot your password?', forget: 'Forgot your password?',
signBtn: 'Sign in', signBtn: 'Sign in',
installBtn: 'Setting', installBtn: 'Set up',
setAdminAccount: 'Setting up an admin account', setAdminAccount: 'Setting up an admin account',
setAdminAccountDesc: 'Maximum privileges for admin account, which can be used to create applications and manage LLM providers, etc.', setAdminAccountDesc: 'Maximum privileges for admin account, which can be used to create applications and manage LLM providers, etc.',
createAndSignIn: 'Create and sign in', createAndSignIn: 'Create and sign in',
...@@ -32,7 +32,7 @@ const translation = { ...@@ -32,7 +32,7 @@ const translation = {
tosDesc: 'By signing up, you agree to our', tosDesc: 'By signing up, you agree to our',
donthave: 'Don\'t have?', donthave: 'Don\'t have?',
invalidInvitationCode: 'Invalid invitation code', invalidInvitationCode: 'Invalid invitation code',
accountAlreadyInited: 'Account already inited', accountAlreadyInited: 'Account already initialized',
error: { error: {
emailEmpty: 'Email address is required', emailEmpty: 'Email address is required',
emailInValid: 'Please enter a valid email address', emailInValid: 'Please enter a valid email address',
...@@ -51,7 +51,9 @@ const translation = { ...@@ -51,7 +51,9 @@ const translation = {
explore: 'Explore Dify', explore: 'Explore Dify',
activatedTipStart: 'You have joined the', activatedTipStart: 'You have joined the',
activatedTipEnd: 'team', activatedTipEnd: 'team',
activated: 'Sign In Now', activated: 'Sign in now',
adminInitPassword: 'Admin initialization password',
validate: 'Validate',
} }
export default translation export default translation
...@@ -52,6 +52,8 @@ const translation = { ...@@ -52,6 +52,8 @@ const translation = {
activatedTipStart: '您已加入', activatedTipStart: '您已加入',
activatedTipEnd: '团队', activatedTipEnd: '团队',
activated: '现在登录', activated: '现在登录',
adminInitPassword: '管理员初始化密码',
validate: '验证',
} }
export default translation export default translation
...@@ -13,6 +13,10 @@ export type SetupStatusResponse = { ...@@ -13,6 +13,10 @@ export type SetupStatusResponse = {
setup_at?: Date setup_at?: Date
} }
export type InitValidateStatusResponse = {
status: 'finished' | 'not_started'
}
export type UserProfileResponse = { export type UserProfileResponse = {
id: string id: string
name: string name: string
......
...@@ -256,7 +256,11 @@ const baseFetch = <T>( ...@@ -256,7 +256,11 @@ const baseFetch = <T>(
} }
const loginUrl = `${globalThis.location.origin}/signin` const loginUrl = `${globalThis.location.origin}/signin`
bodyJson.then((data: ResponseError) => { bodyJson.then((data: ResponseError) => {
if (data.code === 'not_setup' && IS_CE_EDITION) if (data.code === 'init_validate_failed' && IS_CE_EDITION)
Toast.notify({ type: 'error', message: data.message, duration: 4000 })
else if (data.code === 'not_init_validated' && IS_CE_EDITION)
globalThis.location.href = `${globalThis.location.origin}/init`
else if (data.code === 'not_setup' && IS_CE_EDITION)
globalThis.location.href = `${globalThis.location.origin}/install` globalThis.location.href = `${globalThis.location.origin}/install`
else if (location.pathname !== '/signin' || !IS_CE_EDITION) else if (location.pathname !== '/signin' || !IS_CE_EDITION)
globalThis.location.href = loginUrl globalThis.location.href = loginUrl
......
...@@ -9,6 +9,7 @@ import type { ...@@ -9,6 +9,7 @@ import type {
FileUploadConfigResponse, FileUploadConfigResponse,
ICurrentWorkspace, ICurrentWorkspace,
IWorkspace, IWorkspace,
InitValidateStatusResponse,
InvitationResponse, InvitationResponse,
LangGeniusVersionResponse, LangGeniusVersionResponse,
Member, Member,
...@@ -42,6 +43,14 @@ export const setup: Fetcher<CommonResponse, { body: Record<string, any> }> = ({ ...@@ -42,6 +43,14 @@ export const setup: Fetcher<CommonResponse, { body: Record<string, any> }> = ({
return post<CommonResponse>('/setup', { body }) return post<CommonResponse>('/setup', { body })
} }
export const initValidate: Fetcher<CommonResponse, { body: Record<string, any> }> = ({ body }) => {
return post<CommonResponse>('/init', { body })
}
export const fetchInitValidateStatus = () => {
return get<InitValidateStatusResponse>('/init')
}
export const fetchSetupStatus = () => { export const fetchSetupStatus = () => {
return get<SetupStatusResponse>('/setup') return get<SetupStatusResponse>('/setup')
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment