Unverified Commit 6768fd4d authored by zxhlyh's avatar zxhlyh Committed by GitHub

fix: some RAG retrieval bugs (#1577)

Co-authored-by: 's avatarJoel <iamjoel007@gmail.com>
parent d0456d0f
...@@ -100,7 +100,11 @@ const Popup: FC<PopupProps> = ({ ...@@ -100,7 +100,11 @@ const Popup: FC<PopupProps> = ({
data={source.index_node_hash.substring(0, 7)} data={source.index_node_hash.substring(0, 7)}
icon={<BezierCurve03 className='mr-1 w-3 h-3' />} icon={<BezierCurve03 className='mr-1 w-3 h-3' />}
/> />
<ProgressTooltip data={Number(source.score.toFixed(2))} /> {
source.score && (
<ProgressTooltip data={Number(source.score.toFixed(2))} />
)
}
</div> </div>
) )
} }
......
...@@ -59,6 +59,7 @@ const SettingsModal: FC<SettingsModalProps> = ({ ...@@ -59,6 +59,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
const { const {
rerankDefaultModel, rerankDefaultModel,
isRerankDefaultModelVaild, isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext() } = useProviderContext()
const handleValueChange = (type: string, value: string) => { const handleValueChange = (type: string, value: string) => {
...@@ -78,6 +79,7 @@ const SettingsModal: FC<SettingsModalProps> = ({ ...@@ -78,6 +79,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
!isReRankModelSelected({ !isReRankModelSelected({
rerankDefaultModel, rerankDefaultModel,
isRerankDefaultModelVaild, isRerankDefaultModelVaild,
rerankModelList,
retrievalConfig, retrievalConfig,
indexMethod, indexMethod,
}) })
...@@ -270,7 +272,7 @@ const SettingsModal: FC<SettingsModalProps> = ({ ...@@ -270,7 +272,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
)} )}
<div <div
className='absolute z-10 bottom-0 w-full flex justify-end py-4 px-6 border-t bg-white ' className='absolute z-[5] bottom-0 w-full flex justify-end py-4 px-6 border-t bg-white '
style={{ style={{
borderColor: 'rgba(0, 0, 0, 0.05)', borderColor: 'rgba(0, 0, 0, 0.05)',
}} }}
......
...@@ -5,18 +5,29 @@ export const isReRankModelSelected = ({ ...@@ -5,18 +5,29 @@ export const isReRankModelSelected = ({
rerankDefaultModel, rerankDefaultModel,
isRerankDefaultModelVaild, isRerankDefaultModelVaild,
retrievalConfig, retrievalConfig,
rerankModelList,
indexMethod, indexMethod,
}: { }: {
rerankDefaultModel?: BackendModel rerankDefaultModel?: BackendModel
isRerankDefaultModelVaild: boolean isRerankDefaultModelVaild: boolean
retrievalConfig: RetrievalConfig retrievalConfig: RetrievalConfig
rerankModelList: BackendModel[]
indexMethod?: string indexMethod?: string
}) => { }) => {
const rerankModel = (retrievalConfig.reranking_model?.reranking_model_name ? retrievalConfig.reranking_model : undefined) || (isRerankDefaultModelVaild ? rerankDefaultModel : undefined) const rerankModelSelected = (() => {
if (retrievalConfig.reranking_model?.reranking_model_name)
return !!rerankModelList.find(({ model_name }) => model_name === retrievalConfig.reranking_model?.reranking_model_name)
if (isRerankDefaultModelVaild)
return !!rerankDefaultModel
return false
})()
if ( if (
indexMethod === 'high_quality' indexMethod === 'high_quality'
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.fullText) && (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid)
&& !rerankModel && !rerankModelSelected
) )
return false return false
...@@ -35,7 +46,7 @@ export const ensureRerankModelSelected = ({ ...@@ -35,7 +46,7 @@ export const ensureRerankModelSelected = ({
const rerankModel = retrievalConfig.reranking_model?.reranking_model_name ? retrievalConfig.reranking_model : undefined const rerankModel = retrievalConfig.reranking_model?.reranking_model_name ? retrievalConfig.reranking_model : undefined
if ( if (
indexMethod === 'high_quality' indexMethod === 'high_quality'
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.fullText) && (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid)
&& !rerankModel && !rerankModel
) { ) {
return { return {
......
...@@ -16,11 +16,23 @@ type Props = { ...@@ -16,11 +16,23 @@ type Props = {
} }
const RetrievalMethodConfig: FC<Props> = ({ const RetrievalMethodConfig: FC<Props> = ({
value, value: passValue,
onChange, onChange,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const { supportRetrievalMethods } = useProviderContext() const { supportRetrievalMethods, rerankDefaultModel } = useProviderContext()
const value = (() => {
if (!passValue.reranking_model.reranking_model_name) {
return {
...passValue,
reranking_model: {
reranking_provider_name: rerankDefaultModel?.model_provider.provider_name || '',
reranking_model_name: rerankDefaultModel?.model_name || '',
},
}
}
return passValue
})()
return ( return (
<div className='space-y-2'> <div className='space-y-2'>
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && ( {supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
......
...@@ -263,6 +263,7 @@ const StepTwo = ({ ...@@ -263,6 +263,7 @@ const StepTwo = ({
const { const {
rerankDefaultModel, rerankDefaultModel,
isRerankDefaultModelVaild, isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext() } = useProviderContext()
const getCreationParams = () => { const getCreationParams = () => {
let params let params
...@@ -282,6 +283,7 @@ const StepTwo = ({ ...@@ -282,6 +283,7 @@ const StepTwo = ({
!isReRankModelSelected({ !isReRankModelSelected({
rerankDefaultModel, rerankDefaultModel,
isRerankDefaultModelVaild, isRerankDefaultModelVaild,
rerankModelList,
// eslint-disable-next-line @typescript-eslint/no-use-before-define // eslint-disable-next-line @typescript-eslint/no-use-before-define
retrievalConfig, retrievalConfig,
indexMethod: indexMethod as string, indexMethod: indexMethod as string,
...@@ -359,6 +361,9 @@ const StepTwo = ({ ...@@ -359,6 +361,9 @@ const StepTwo = ({
try { try {
let res let res
const params = getCreationParams() const params = getCreationParams()
if (!params)
return false
setIsCreating(true) setIsCreating(true)
if (!datasetId) { if (!datasetId) {
res = await createFirstDocument({ res = await createFirstDocument({
......
...@@ -3,11 +3,14 @@ import type { FC } from 'react' ...@@ -3,11 +3,14 @@ import type { FC } from 'react'
import React, { useRef, useState } from 'react' import React, { useRef, useState } from 'react'
import { useClickAway } from 'ahooks' import { useClickAway } from 'ahooks'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import Toast from '../../base/toast'
import { XClose } from '@/app/components/base/icons/src/vender/line/general' import { XClose } from '@/app/components/base/icons/src/vender/line/general'
import type { RetrievalConfig } from '@/types/app' import type { RetrievalConfig } from '@/types/app'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import { useProviderContext } from '@/context/provider-context'
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
type Props = { type Props = {
indexMethod: string indexMethod: string
...@@ -33,6 +36,32 @@ const ModifyRetrievalModal: FC<Props> = ({ ...@@ -33,6 +36,32 @@ const ModifyRetrievalModal: FC<Props> = ({
onHide() onHide()
}, ref) }, ref)
const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()
const handleSave = () => {
if (
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
retrievalConfig,
indexMethod,
})
) {
Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
return
}
onSave(ensureRerankModelSelected({
rerankDefaultModel: rerankDefaultModel!,
retrievalConfig,
indexMethod,
}))
}
if (!isShow) if (!isShow)
return null return null
...@@ -87,7 +116,7 @@ const ModifyRetrievalModal: FC<Props> = ({ ...@@ -87,7 +116,7 @@ const ModifyRetrievalModal: FC<Props> = ({
}} }}
> >
<Button className='mr-2 flex-shrink-0' onClick={onHide}>{t('common.operation.cancel')}</Button> <Button className='mr-2 flex-shrink-0' onClick={onHide}>{t('common.operation.cancel')}</Button>
<Button type='primary' className='flex-shrink-0' onClick={() => onSave(retrievalConfig)} >{t('common.operation.save')}</Button> <Button type='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button>
</div> </div>
</div> </div>
) )
......
...@@ -59,6 +59,7 @@ const Form = () => { ...@@ -59,6 +59,7 @@ const Form = () => {
const { const {
rerankDefaultModel, rerankDefaultModel,
isRerankDefaultModelVaild, isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext() } = useProviderContext()
const handleSave = async () => { const handleSave = async () => {
...@@ -72,6 +73,7 @@ const Form = () => { ...@@ -72,6 +73,7 @@ const Form = () => {
!isReRankModelSelected({ !isReRankModelSelected({
rerankDefaultModel, rerankDefaultModel,
isRerankDefaultModelVaild, isRerankDefaultModelVaild,
rerankModelList,
retrievalConfig, retrievalConfig,
indexMethod, indexMethod,
}) })
......
...@@ -16,12 +16,16 @@ const config: ProviderConfig = { ...@@ -16,12 +16,16 @@ const config: ProviderConfig = {
'en': <CohereText className='w-[120px] h-6' />, 'en': <CohereText className='w-[120px] h-6' />,
'zh-Hans': <CohereText className='w-[120px] h-6' />, 'zh-Hans': <CohereText className='w-[120px] h-6' />,
}, },
hit: {
'en': 'Rerank Model Supported',
'zh-Hans': '支持 Rerank 模型',
},
}, },
modal: { modal: {
key: ProviderEnum.cohere, key: ProviderEnum.cohere,
title: { title: {
'en': 'cohere', 'en': 'Rerank Model',
'zh-Hans': 'cohere', 'zh-Hans': 'Rerank 模型',
}, },
icon: <Cohere className='w-6 h-6' />, icon: <Cohere className='w-6 h-6' />,
link: { link: {
......
...@@ -26,6 +26,7 @@ import { ModelType } from '@/app/components/header/account-setting/model-page/de ...@@ -26,6 +26,7 @@ import { ModelType } from '@/app/components/header/account-setting/model-page/de
import { useEventEmitterContextContext } from '@/context/event-emitter' import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useProviderContext } from '@/context/provider-context' import { useProviderContext } from '@/context/provider-context'
import I18n from '@/context/i18n' import I18n from '@/context/i18n'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
const MODEL_CARD_LIST = [ const MODEL_CARD_LIST = [
config.openai, config.openai,
...@@ -42,6 +43,10 @@ const ModelPage = () => { ...@@ -42,6 +43,10 @@ const ModelPage = () => {
const { locale } = useContext(I18n) const { locale } = useContext(I18n)
const { const {
updateModelList, updateModelList,
textGenerationDefaultModel,
embeddingsDefaultModel,
speech2textDefaultModel,
rerankDefaultModel,
} = useProviderContext() } = useProviderContext()
const { data: providers, mutate: mutateProviders } = useSWR('/workspaces/current/model-providers', fetchModelProviders) const { data: providers, mutate: mutateProviders } = useSWR('/workspaces/current/model-providers', fetchModelProviders)
const [showModal, setShowModal] = useState(false) const [showModal, setShowModal] = useState(false)
...@@ -196,11 +201,22 @@ const ModelPage = () => { ...@@ -196,11 +201,22 @@ const ModelPage = () => {
} }
} }
const defaultModelNotConfigured = !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel
return ( return (
<div className='relative pt-1 -mt-2'> <div className='relative pt-1 -mt-2'>
<div className='flex items-center justify-between mb-2 h-8'> <div className={`flex items-center justify-between mb-2 h-8 ${defaultModelNotConfigured && 'px-3 bg-[#FFFAEB] rounded-lg border border-[#FEF0C7]'}`}>
<div className='text-sm font-medium text-gray-800'>{t('common.modelProvider.models')}</div> {
<SystemModel /> defaultModelNotConfigured
? (
<div className='flex items-center text-xs font-medium text-gray-700'>
<AlertTriangle className='mr-1 w-3 h-3 text-[#F79009]' />
{t('common.modelProvider.notConfigured')}
</div>
)
: <div className='text-sm font-medium text-gray-800'>{t('common.modelProvider.models')}</div>
}
<SystemModel onUpdate={() => mutateProviders()} />
</div> </div>
<div className='grid grid-cols-2 gap-4 mb-6'> <div className='grid grid-cols-2 gap-4 mb-6'>
{ {
......
...@@ -2,7 +2,6 @@ import { useCallback, useState } from 'react' ...@@ -2,7 +2,6 @@ import { useCallback, useState } from 'react'
import type { FC } from 'react' import type { FC } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector' import { useContext } from 'use-context-selector'
import { Portal } from '@headlessui/react'
import type { FormValue, ProviderConfigModal } from '../declarations' import type { FormValue, ProviderConfigModal } from '../declarations'
import { ConfigurableProviders } from '../utils' import { ConfigurableProviders } from '../utils'
import Form from './Form' import Form from './Form'
...@@ -12,6 +11,10 @@ import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' ...@@ -12,6 +11,10 @@ import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security'
import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general' import { LinkExternal02 } from '@/app/components/base/icons/src/vender/line/general'
import { AlertCircle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import { AlertCircle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import { useEventEmitterContextContext } from '@/context/event-emitter' import { useEventEmitterContextContext } from '@/context/event-emitter'
import {
PortalToFollowElem,
PortalToFollowElemContent,
} from '@/app/components/base/portal-to-follow-elem'
type ModelModalProps = { type ModelModalProps = {
isShow: boolean isShow: boolean
...@@ -90,75 +93,77 @@ const ModelModal: FC<ModelModalProps> = ({ ...@@ -90,75 +93,77 @@ const ModelModal: FC<ModelModalProps> = ({
return null return null
return ( return (
<Portal> <PortalToFollowElem open>
<div className='fixed inset-0 flex items-center justify-center bg-black/[.25]'> <PortalToFollowElemContent className='w-full h-full z-[60]'>
<div className='w-[640px] max-h-[calc(100vh-120px)] bg-white shadow-xl rounded-2xl overflow-y-auto'> <div className='fixed inset-0 flex items-center justify-center bg-black/[.25]'>
<div className='px-8 pt-8'> <div className='w-[640px] max-h-[calc(100vh-120px)] bg-white shadow-xl rounded-2xl overflow-y-auto'>
<div className='flex justify-between items-center mb-2'> <div className='px-8 pt-8'>
<div className='text-xl font-semibold text-gray-900'>{renderTitlePrefix()}</div> <div className='flex justify-between items-center mb-2'>
{modelModal?.icon} <div className='text-xl font-semibold text-gray-900'>{renderTitlePrefix()}</div>
</div> {modelModal?.icon}
<Form </div>
modelModal={modelModal} <Form
fields={modelModal?.fields || []} modelModal={modelModal}
initValue={modelModal?.defaultValue} fields={modelModal?.fields || []}
onChange={newValue => setValue(newValue)} initValue={modelModal?.defaultValue}
onValidatedError={handleValidatedError} onChange={newValue => setValue(newValue)}
mode={mode} onValidatedError={handleValidatedError}
cleared={cleared} mode={mode}
onClearedChange={setCleared} cleared={cleared}
onValidating={handleValidating} onClearedChange={setCleared}
/> onValidating={handleValidating}
<div className='flex justify-between items-center py-6'> />
<a <div className='flex justify-between items-center py-6'>
href={modelModal?.link.href} <a
target='_blank' href={modelModal?.link.href}
className='inline-flex items-center text-xs text-primary-600' target='_blank'
> className='inline-flex items-center text-xs text-primary-600'
{modelModal?.link.label[locale]}
<LinkExternal02 className='ml-1 w-3 h-3' />
</a>
<div>
<Button className='mr-2 !h-9 !text-sm font-medium text-gray-700' onClick={onCancel}>{t('common.operation.cancel')}</Button>
<Button
className='!h-9 !text-sm font-medium'
type='primary'
onClick={handleSave}
disabled={loading || (mode === 'edit' && !cleared) || validating}
> >
{t('common.operation.save')} {modelModal?.link.label[locale]}
</Button> <LinkExternal02 className='ml-1 w-3 h-3' />
</a>
<div>
<Button className='mr-2 !h-9 !text-sm font-medium text-gray-700' onClick={onCancel}>{t('common.operation.cancel')}</Button>
<Button
className='!h-9 !text-sm font-medium'
type='primary'
onClick={handleSave}
disabled={loading || (mode === 'edit' && !cleared) || validating}
>
{t('common.operation.save')}
</Button>
</div>
</div> </div>
</div> </div>
</div> <div className='border-t-[0.5px] border-t-[rgba(0,0,0,0.05)]'>
<div className='border-t-[0.5px] border-t-[rgba(0,0,0,0.05)]'> {
{ errorMessage
errorMessage ? (
? ( <div className='flex px-[10px] py-3 bg-[#FEF3F2] text-xs text-[#D92D20]'>
<div className='flex px-[10px] py-3 bg-[#FEF3F2] text-xs text-[#D92D20]'> <AlertCircle className='mt-[1px] mr-2 w-[14px] h-[14px]' />
<AlertCircle className='mt-[1px] mr-2 w-[14px] h-[14px]' /> {errorMessage}
{errorMessage} </div>
</div> )
) : (
: ( <div className='flex justify-center items-center py-3 bg-gray-50 text-xs text-gray-500'>
<div className='flex justify-center items-center py-3 bg-gray-50 text-xs text-gray-500'> <Lock01 className='mr-1 w-3 h-3 text-gray-500' />
<Lock01 className='mr-1 w-3 h-3 text-gray-500' /> {t('common.modelProvider.encrypted.front')}
{t('common.modelProvider.encrypted.front')} <a
<a className='text-primary-600 mx-1'
className='text-primary-600 mx-1' target={'_blank'}
target={'_blank'} href='https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html'
href='https://pycryptodome.readthedocs.io/en/latest/src/cipher/oaep.html' >
> PKCS1_OAEP
PKCS1_OAEP </a>
</a> {t('common.modelProvider.encrypted.back')}
{t('common.modelProvider.encrypted.back')} </div>
</div> )
) }
} </div>
</div> </div>
</div> </div>
</div> </PortalToFollowElemContent>
</Portal> </PortalToFollowElem>
) )
} }
......
import type { FC } from 'react' import type { FC } from 'react'
import { Fragment, useState } from 'react' import React, { Fragment, useEffect, useState } from 'react'
import useSWR from 'swr'
import { Popover, Transition } from '@headlessui/react' import { Popover, Transition } from '@headlessui/react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import _ from 'lodash-es' import _ from 'lodash-es'
import cn from 'classnames' import cn from 'classnames'
import ModelModal from '../model-modal'
import cohereConfig from '../configs/cohere'
import s from './style.module.css' import s from './style.module.css'
import type { BackendModel, ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations' import type { BackendModel, FormValue, ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
import { ModelType } from '@/app/components/header/account-setting/model-page/declarations' import { ModelType } from '@/app/components/header/account-setting/model-page/declarations'
import { ChevronDown } from '@/app/components/base/icons/src/vender/line/arrows' import { ChevronDown } from '@/app/components/base/icons/src/vender/line/arrows'
import { Check, SearchLg } from '@/app/components/base/icons/src/vender/line/general' import { Check, LinkExternal01, SearchLg } from '@/app/components/base/icons/src/vender/line/general'
import { XCircle } from '@/app/components/base/icons/src/vender/solid/general' import { XCircle } from '@/app/components/base/icons/src/vender/solid/general'
import { AlertCircle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback' import { AlertCircle } from '@/app/components/base/icons/src/vender/line/alertsAndFeedback'
import Tooltip from '@/app/components/base/tooltip' import Tooltip from '@/app/components/base/tooltip'
...@@ -20,6 +23,9 @@ import ModelModeTypeLabel from '@/app/components/app/configuration/config-model/ ...@@ -20,6 +23,9 @@ import ModelModeTypeLabel from '@/app/components/app/configuration/config-model/
import type { ModelModeType } from '@/types/app' import type { ModelModeType } from '@/types/app'
import { CubeOutline } from '@/app/components/base/icons/src/vender/line/shapes' import { CubeOutline } from '@/app/components/base/icons/src/vender/line/shapes'
import { useModalContext } from '@/context/modal-context' import { useModalContext } from '@/context/modal-context'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { fetchDefaultModal, setModelProvider } from '@/service/common'
import { useToastContext } from '@/app/components/base/toast'
type Props = { type Props = {
value: { value: {
...@@ -35,6 +41,7 @@ type Props = { ...@@ -35,6 +41,7 @@ type Props = {
readonly?: boolean readonly?: boolean
triggerIconSmall?: boolean triggerIconSmall?: boolean
whenEmptyGoToSetting?: boolean whenEmptyGoToSetting?: boolean
onUpdate?: () => void
} }
type ModelOption = { type ModelOption = {
...@@ -59,6 +66,7 @@ const ModelSelector: FC<Props> = ({ ...@@ -59,6 +66,7 @@ const ModelSelector: FC<Props> = ({
readonly, readonly,
triggerIconSmall, triggerIconSmall,
whenEmptyGoToSetting, whenEmptyGoToSetting,
onUpdate,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const { setShowAccountSettingModal } = useModalContext() const { setShowAccountSettingModal } = useModalContext()
...@@ -68,6 +76,7 @@ const ModelSelector: FC<Props> = ({ ...@@ -68,6 +76,7 @@ const ModelSelector: FC<Props> = ({
speech2textModelList, speech2textModelList,
rerankModelList, rerankModelList,
agentThoughtModelList, agentThoughtModelList,
updateModelList,
} = useProviderContext() } = useProviderContext()
const [search, setSearch] = useState('') const [search, setSearch] = useState('')
const modelList = supportAgentThought const modelList = supportAgentThought
...@@ -98,7 +107,7 @@ const ModelSelector: FC<Props> = ({ ...@@ -98,7 +107,7 @@ const ModelSelector: FC<Props> = ({
}) })
: modelList : modelList
const hasRemoved = value && !modelList.find(({ model_name, model_provider }) => model_name === value.modelName && model_provider.provider_name === value.providerName) const hasRemoved = (value && value.modelName && value.providerName) && !modelList.find(({ model_name, model_provider }) => model_name === value.modelName && model_provider.provider_name === value.providerName)
const modelOptions: ModelOption[] = (() => { const modelOptions: ModelOption[] = (() => {
const providers = _.uniq(filteredModelList.map(item => item.model_provider.provider_name)) const providers = _.uniq(filteredModelList.map(item => item.model_provider.provider_name))
...@@ -121,6 +130,45 @@ const ModelSelector: FC<Props> = ({ ...@@ -121,6 +130,45 @@ const ModelSelector: FC<Props> = ({
}) })
return res return res
})() })()
const { eventEmitter } = useEventEmitterContextContext()
const [showRerankModal, setShowRerankModal] = useState(false)
const [shouldFetchRerankDefaultModel, setShouldFetchRerankDefaultModel] = useState(false)
const { notify } = useToastContext()
const { data: rerankDefaultModel } = useSWR(shouldFetchRerankDefaultModel ? '/workspaces/current/default-model?model_type=reranking' : null, fetchDefaultModal)
const handleOpenRerankModal = (e: React.MouseEvent<HTMLDivElement>) => {
e.stopPropagation()
setShowRerankModal(true)
}
const handleRerankModalSave = async (originValue?: FormValue) => {
if (originValue) {
try {
eventEmitter?.emit('provider-save')
const res = await setModelProvider({
url: `/workspaces/current/model-providers/${cohereConfig.modal.key}`,
body: {
config: originValue,
},
})
if (res.result === 'success') {
notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') })
updateModelList(ModelType.reranking)
setShowRerankModal(false)
setShouldFetchRerankDefaultModel(true)
if (onUpdate)
onUpdate()
}
eventEmitter?.emit('')
}
catch (e) {
eventEmitter?.emit('')
}
}
}
useEffect(() => {
if (rerankDefaultModel && whenEmptyGoToSetting)
onChange(rerankDefaultModel)
}, [rerankDefaultModel])
return ( return (
<div className=''> <div className=''>
...@@ -130,7 +178,7 @@ const ModelSelector: FC<Props> = ({ ...@@ -130,7 +178,7 @@ const ModelSelector: FC<Props> = ({
({ open }) => ( ({ open }) => (
<> <>
{ {
value (value && value.modelName && value.providerName)
? ( ? (
<> <>
<ModelIcon <ModelIcon
...@@ -146,9 +194,19 @@ const ModelSelector: FC<Props> = ({ ...@@ -146,9 +194,19 @@ const ModelSelector: FC<Props> = ({
</div> </div>
</> </>
) )
: ( : whenEmptyGoToSetting
<div className='grow text-left text-sm text-gray-800 opacity-60'>{t('common.modelProvider.selectModel')}</div> ? (
) <div className='grow flex items-center h-9 justify-between' onClick={handleOpenRerankModal}>
<div className='flex items-center text-[13px] font-medium text-primary-500'>
<CubeOutline className='mr-1.5 w-4 h-4' />
{t('common.modelProvider.selector.rerankTip')}
</div>
<LinkExternal01 className='w-3 h-3 text-gray-500' />
</div>
)
: (
<div className='grow text-left text-sm text-gray-800 opacity-60'>{t('common.modelProvider.selectModel')}</div>
)
} }
{ {
hasRemoved && ( hasRemoved && (
...@@ -162,7 +220,16 @@ const ModelSelector: FC<Props> = ({ ...@@ -162,7 +220,16 @@ const ModelSelector: FC<Props> = ({
</Tooltip> </Tooltip>
) )
} }
{!readonly && <ChevronDown className={`w-4 h-4 text-gray-700 ${open ? 'opacity-100' : 'opacity-60'}`} />} {
!readonly && !whenEmptyGoToSetting && (
<ChevronDown className={`w-4 h-4 text-gray-700 ${open ? 'opacity-100' : 'opacity-60'}`} />
)
}
{
whenEmptyGoToSetting && (value && value.modelName && value.providerName) && (
<ChevronDown className={`w-4 h-4 text-gray-700 ${open ? 'opacity-100' : 'opacity-60'}`} />
)
}
</> </>
) )
} }
...@@ -246,21 +313,6 @@ const ModelSelector: FC<Props> = ({ ...@@ -246,21 +313,6 @@ const ModelSelector: FC<Props> = ({
return null return null
}) })
} }
{
whenEmptyGoToSetting && modelList.length === 0 && (
<div className='pt-6'>
<div className='flex items-center justify-center mx-auto mb-2 w-12 h-12 rounded-[10px] border border-[#EAECF5]'>
<CubeOutline className='w-6 h-6 text-gray-500' />
</div>
<div className='mb-1 text-center text-[13px] font-medium text-gray-500'>
{t('common.modelProvider.selector.emptyTip')}
</div>
<div className='mb-6 text-center text-xs text-primary-500'>
<span onClick={() => setShowAccountSettingModal({ payload: 'provider' })}>{t('common.modelProvider.selector.emptySetting')}</span>
</div>
</div>
)
}
{modelList.length !== 0 && (search && filteredModelList.length === 0) && ( {modelList.length !== 0 && (search && filteredModelList.length === 0) && (
<div className='px-3 pt-1.5 h-[30px] text-center text-xs text-gray-500'>{t('common.modelProvider.noModelFound', { model: search })}</div> <div className='px-3 pt-1.5 h-[30px] text-center text-xs text-gray-500'>{t('common.modelProvider.noModelFound', { model: search })}</div>
)} )}
...@@ -281,6 +333,13 @@ const ModelSelector: FC<Props> = ({ ...@@ -281,6 +333,13 @@ const ModelSelector: FC<Props> = ({
</Transition> </Transition>
)} )}
</Popover> </Popover>
<ModelModal
isShow={showRerankModal}
modelModal={cohereConfig.modal}
onCancel={() => setShowRerankModal(false)}
onSave={handleRerankModalSave}
mode={'add'}
/>
</div> </div>
) )
} }
......
import type { FC } from 'react'
import { useState } from 'react' import { useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import ModelSelector from '../model-selector' import ModelSelector from '../model-selector'
...@@ -17,7 +18,12 @@ import { ModelType } from '@/app/components/header/account-setting/model-page/de ...@@ -17,7 +18,12 @@ import { ModelType } from '@/app/components/header/account-setting/model-page/de
import { useToastContext } from '@/app/components/base/toast' import { useToastContext } from '@/app/components/base/toast'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
const SystemModel = () => { type SystemModelProps = {
onUpdate: () => void
}
const SystemModel: FC<SystemModelProps> = ({
onUpdate,
}) => {
const { t } = useTranslation() const { t } = useTranslation()
const { const {
textGenerationDefaultModel, textGenerationDefaultModel,
...@@ -91,7 +97,7 @@ const SystemModel = () => { ...@@ -91,7 +97,7 @@ const SystemModel = () => {
> >
<PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}> <PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
<div className={` <div className={`
flex items-center px-2 h-6 text-xs text-gray-700 cursor-pointer rounded-md border-[0.5px] border-gray-200 shadow-xs flex items-center px-2 h-6 text-xs text-gray-700 cursor-pointer bg-white rounded-md border-[0.5px] border-gray-200 shadow-xs
hover:bg-gray-100 hover:shadow-none hover:bg-gray-100 hover:shadow-none
${open && 'bg-gray-100 shadow-none'} ${open && 'bg-gray-100 shadow-none'}
`}> `}>
...@@ -158,6 +164,8 @@ const SystemModel = () => { ...@@ -158,6 +164,8 @@ const SystemModel = () => {
value={selectedModel[ModelType.reranking]} value={selectedModel[ModelType.reranking]}
modelType={ModelType.reranking} modelType={ModelType.reranking}
onChange={v => handleChangeDefaultModel(ModelType.reranking, v)} onChange={v => handleChangeDefaultModel(ModelType.reranking, v)}
whenEmptyGoToSetting
onUpdate={onUpdate}
/> />
</div> </div>
</div> </div>
......
...@@ -305,7 +305,7 @@ const translation = { ...@@ -305,7 +305,7 @@ const translation = {
}, },
result: 'Output Text', result: 'Output Text',
datasetConfig: { datasetConfig: {
settingTitle: 'Retrieve Settings', settingTitle: 'Retrieval settings',
retrieveOneWay: { retrieveOneWay: {
title: 'N-to-1 retrieval', title: 'N-to-1 retrieval',
description: 'Based on user intent and dataset descriptions, the Agent autonomously selects the best dataset for querying. Best for applications with distinct, limited datasets.', description: 'Based on user intent and dataset descriptions, the Agent autonomously selects the best dataset for querying. Best for applications with distinct, limited datasets.',
......
...@@ -223,6 +223,7 @@ const translation = { ...@@ -223,6 +223,7 @@ const translation = {
}, },
}, },
modelProvider: { modelProvider: {
notConfigured: 'The system model has not yet been fully configured, and some functions may be unavailable.',
systemModelSettings: 'System Model Settings', systemModelSettings: 'System Model Settings',
systemModelSettingsLink: 'Why is it necessary to set up a system model?', systemModelSettingsLink: 'Why is it necessary to set up a system model?',
selectModel: 'Select your model', selectModel: 'Select your model',
...@@ -252,6 +253,7 @@ const translation = { ...@@ -252,6 +253,7 @@ const translation = {
tip: 'This model has been removed. Please add a model or select another model.', tip: 'This model has been removed. Please add a model or select another model.',
emptyTip: 'No available models', emptyTip: 'No available models',
emptySetting: 'Please go to settings to configure', emptySetting: 'Please go to settings to configure',
rerankTip: 'Please set up the Rerank model',
}, },
card: { card: {
quota: 'QUOTA', quota: 'QUOTA',
......
...@@ -223,6 +223,7 @@ const translation = { ...@@ -223,6 +223,7 @@ const translation = {
}, },
}, },
modelProvider: { modelProvider: {
notConfigured: '系统模型尚未完全配置,部分功能可能无法使用。',
systemModelSettings: '系统模型设置', systemModelSettings: '系统模型设置',
systemModelSettingsLink: '为什么需要设置系统模型?', systemModelSettingsLink: '为什么需要设置系统模型?',
selectModel: '选择您的模型', selectModel: '选择您的模型',
...@@ -252,6 +253,7 @@ const translation = { ...@@ -252,6 +253,7 @@ const translation = {
tip: '该模型已被删除。请添模型或选择其他模型。', tip: '该模型已被删除。请添模型或选择其他模型。',
emptyTip: '无可用模型', emptyTip: '无可用模型',
emptySetting: '请前往设置进行配置', emptySetting: '请前往设置进行配置',
rerankTip: '请设置 Rerank 模型',
}, },
card: { card: {
quota: '额度', quota: '额度',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment