Unverified Commit c5d2181b authored by Yeuoly's avatar Yeuoly

Merge branch 'main' into feat/agent-image

parents 469a66b5 5929e840
...@@ -195,8 +195,8 @@ class ToolApiProviderUpdateApi(Resource): ...@@ -195,8 +195,8 @@ class ToolApiProviderUpdateApi(Resource):
parser.add_argument('schema', type=str, required=True, nullable=False, location='json') parser.add_argument('schema', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider', type=str, required=True, nullable=False, location='json') parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json') parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('icon', type=str, required=True, nullable=False, location='json') parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
parser.add_argument('privacy_policy', type=str, required=True, nullable=False, location='json') parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json')
args = parser.parse_args() args = parser.parse_args()
......
...@@ -35,6 +35,10 @@ class SparkLLMClient: ...@@ -35,6 +35,10 @@ class SparkLLMClient:
'spark-3': { 'spark-3': {
'version': 'v3.1', 'version': 'v3.1',
'chat_domain': 'generalv3' 'chat_domain': 'generalv3'
},
'spark-3.5': {
'version': 'v3.5',
'chat_domain': 'generalv3.5'
} }
} }
......
model: spark-3.5
label:
en_US: Spark V3.5
model_type: llm
model_properties:
mode: chat
parameter_rules:
- name: temperature
use_template: temperature
default: 0.5
help:
zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question.
- name: max_tokens
use_template: max_tokens
default: 2048
min: 1
max: 8192
help:
zh_Hans: 模型回答的tokens的最大长度。
          en_US: Maximum length of tokens the model can generate in its response.
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
default: 4
min: 1
max: 6
help:
          zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
en_US: Randomly select one from k candidates (non-equal probability).
required: false
...@@ -38,7 +38,7 @@ from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT ...@@ -38,7 +38,7 @@ from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0) ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
ZHIPUAI_DEFAULT_MAX_RETRIES = 3 ZHIPUAI_DEFAULT_MAX_RETRIES = 3
ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10) ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=5, max_keepalive_connections=5)
class HttpClient: class HttpClient:
......
...@@ -35,6 +35,10 @@ class SparkLLMClient: ...@@ -35,6 +35,10 @@ class SparkLLMClient:
'spark-v3': { 'spark-v3': {
'version': 'v3.1', 'version': 'v3.1',
'chat_domain': 'generalv3' 'chat_domain': 'generalv3'
},
'spark-v3.5': {
'version': 'v3.5',
'chat_domain': 'generalv3.5'
} }
} }
......
...@@ -5,6 +5,7 @@ from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import S ...@@ -5,6 +5,7 @@ from core.tools.provider.builtin.stablediffusion.tools.stable_diffusion import S
from typing import Any, Dict from typing import Any, Dict
class StableDiffusionProvider(BuiltinToolProviderController): class StableDiffusionProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: Dict[str, Any]) -> None: def _validate_credentials(self, credentials: Dict[str, Any]) -> None:
try: try:
...@@ -12,15 +13,6 @@ class StableDiffusionProvider(BuiltinToolProviderController): ...@@ -12,15 +13,6 @@ class StableDiffusionProvider(BuiltinToolProviderController):
meta={ meta={
"credentials": credentials, "credentials": credentials,
} }
).invoke( ).validate_models()
user_id='',
tool_parameters={
"prompt": "cat",
"lora": "",
"steps": 1,
"width": 512,
"height": 512,
},
)
except Exception as e: except Exception as e:
raise ToolProviderCredentialValidationError(str(e)) raise ToolProviderCredentialValidationError(str(e))
\ No newline at end of file
...@@ -4,7 +4,7 @@ from core.tools.entities.common_entities import I18nObject ...@@ -4,7 +4,7 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
from httpx import post from httpx import post, get
from os.path import join from os.path import join
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from PIL import Image from PIL import Image
...@@ -59,6 +59,7 @@ DRAW_TEXT_OPTIONS = { ...@@ -59,6 +59,7 @@ DRAW_TEXT_OPTIONS = {
"alwayson_scripts": {} "alwayson_scripts": {}
} }
class StableDiffusionTool(BuiltinTool): class StableDiffusionTool(BuiltinTool):
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \ def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
...@@ -136,7 +137,31 @@ class StableDiffusionTool(BuiltinTool): ...@@ -136,7 +137,31 @@ class StableDiffusionTool(BuiltinTool):
width=width, width=width,
height=height, height=height,
steps=steps) steps=steps)
def validate_models(self) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
validate models
"""
try:
base_url = self.runtime.credentials.get('base_url', None)
if not base_url:
raise ToolProviderCredentialValidationError('Please input base_url')
model = self.runtime.credentials.get('model', None)
if not model:
raise ToolProviderCredentialValidationError('Please input model')
response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120)
if response.status_code != 200:
raise ToolProviderCredentialValidationError('Failed to get models')
else:
models = [d['model_name'] for d in response.json()]
if len([d for d in models if d == model]) > 0:
return self.create_text_message(json.dumps(models))
else:
raise ToolProviderCredentialValidationError(f'model {model} does not exist')
except Exception as e:
raise ToolProviderCredentialValidationError(f'Failed to get models, {e}')
def img2img(self, base_url: str, lora: str, image_binary: bytes, def img2img(self, base_url: str, lora: str, image_binary: bytes,
prompt: str, negative_prompt: str, prompt: str, negative_prompt: str,
width: int, height: int, steps: int) \ width: int, height: int, steps: int) \
...@@ -211,10 +236,9 @@ class StableDiffusionTool(BuiltinTool): ...@@ -211,10 +236,9 @@ class StableDiffusionTool(BuiltinTool):
except Exception as e: except Exception as e:
return self.create_text_message('Failed to generate image') return self.create_text_message('Failed to generate image')
def get_runtime_parameters(self) -> List[ToolParameter]: def get_runtime_parameters(self) -> List[ToolParameter]:
parameters = [ parameters = [
ToolParameter(name='prompt', ToolParameter(name='prompt',
label=I18nObject(en_US='Prompt', zh_Hans='Prompt'), label=I18nObject(en_US='Prompt', zh_Hans='Prompt'),
human_description=I18nObject( human_description=I18nObject(
en_US='Image prompt, you can check the official documentation of Stable Diffusion', en_US='Image prompt, you can check the official documentation of Stable Diffusion',
...@@ -227,7 +251,7 @@ class StableDiffusionTool(BuiltinTool): ...@@ -227,7 +251,7 @@ class StableDiffusionTool(BuiltinTool):
] ]
if len(self.list_default_image_variables()) != 0: if len(self.list_default_image_variables()) != 0:
parameters.append( parameters.append(
ToolParameter(name='image_id', ToolParameter(name='image_id',
label=I18nObject(en_US='image_id', zh_Hans='image_id'), label=I18nObject(en_US='image_id', zh_Hans='image_id'),
human_description=I18nObject( human_description=I18nObject(
en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.', en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.',
......
...@@ -66,7 +66,6 @@ class AccountService: ...@@ -66,7 +66,6 @@ class AccountService:
account.current_tenant_id = tenant_account_join.tenant_id account.current_tenant_id = tenant_account_join.tenant_id
else: else:
_create_tenant_for_account(account) _create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
else: else:
account.current_tenant_id = workspace_id account.current_tenant_id = workspace_id
else: else:
...@@ -76,7 +75,6 @@ class AccountService: ...@@ -76,7 +75,6 @@ class AccountService:
account.current_tenant_id = tenant_account_join.tenant_id account.current_tenant_id = tenant_account_join.tenant_id
else: else:
_create_tenant_for_account(account) _create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
current_time = datetime.utcnow() current_time = datetime.utcnow()
...@@ -288,7 +286,6 @@ class TenantService: ...@@ -288,7 +286,6 @@ class TenantService:
# Set the current tenant for the account # Set the current tenant for the account
account.current_tenant_id = tenant_account_join.tenant_id account.current_tenant_id = tenant_account_join.tenant_id
session['workspace_id'] = account.current_tenant.id
@staticmethod @staticmethod
def get_tenant_members(tenant: Tenant) -> List[Account]: def get_tenant_members(tenant: Tenant) -> List[Account]:
......
...@@ -390,7 +390,7 @@ class ToolManageService: ...@@ -390,7 +390,7 @@ class ToolManageService:
@staticmethod @staticmethod
def update_api_tool_provider( def update_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: str, credentials: dict, user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
schema_type: str, schema: str, privacy_policy: str schema_type: str, schema: str, privacy_policy: str
): ):
""" """
...@@ -415,7 +415,7 @@ class ToolManageService: ...@@ -415,7 +415,7 @@ class ToolManageService:
# update db provider # update db provider
provider.name = provider_name provider.name = provider_name
provider.icon = icon provider.icon = json.dumps(icon)
provider.schema = schema provider.schema = schema
provider.description = extra_info.get('description', '') provider.description = extra_info.get('description', '')
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
......
...@@ -46,7 +46,7 @@ import { fetchDatasets } from '@/service/datasets' ...@@ -46,7 +46,7 @@ import { fetchDatasets } from '@/service/datasets'
import { useProviderContext } from '@/context/provider-context' import { useProviderContext } from '@/context/provider-context'
import { AgentStrategy, AppType, ModelModeType, RETRIEVE_TYPE, Resolution, TransferMethod } from '@/types/app' import { AgentStrategy, AppType, ModelModeType, RETRIEVE_TYPE, Resolution, TransferMethod } from '@/types/app'
import { PromptMode } from '@/models/debug' import { PromptMode } from '@/models/debug'
import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG, supportFunctionCallModels } from '@/config' import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config'
import SelectDataSet from '@/app/components/app/configuration/dataset-config/select-dataset' import SelectDataSet from '@/app/components/app/configuration/dataset-config/select-dataset'
import { useModalContext } from '@/context/modal-context' import { useModalContext } from '@/context/modal-context'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
...@@ -157,6 +157,7 @@ const Configuration: FC = () => { ...@@ -157,6 +157,7 @@ const Configuration: FC = () => {
dataSets: [], dataSets: [],
agentConfig: DEFAULT_AGENT_SETTING, agentConfig: DEFAULT_AGENT_SETTING,
}) })
const isChatApp = mode === AppType.chat const isChatApp = mode === AppType.chat
const isAgent = modelConfig.agentConfig?.enabled const isAgent = modelConfig.agentConfig?.enabled
const setIsAgent = (value: boolean) => { const setIsAgent = (value: boolean) => {
...@@ -166,7 +167,7 @@ const Configuration: FC = () => { ...@@ -166,7 +167,7 @@ const Configuration: FC = () => {
doSetModelConfig(newModelConfig) doSetModelConfig(newModelConfig)
} }
const isOpenAI = modelConfig.provider === 'openai' const isOpenAI = modelConfig.provider === 'openai'
const isFunctionCall = (isOpenAI && modelConfig.mode === ModelModeType.chat) || supportFunctionCallModels.includes(modelConfig.model_id)
const [collectionList, setCollectionList] = useState<Collection[]>([]) const [collectionList, setCollectionList] = useState<Collection[]>([])
useEffect(() => { useEffect(() => {
...@@ -262,6 +263,13 @@ const Configuration: FC = () => { ...@@ -262,6 +263,13 @@ const Configuration: FC = () => {
}, },
) )
const isFunctionCall = (() => {
const features = currModel?.features
if (!features)
return false
return features.includes(ModelFeatureEnum.toolCall) || features.includes(ModelFeatureEnum.multiToolCall)
})()
// Fill old app data missing model mode. // Fill old app data missing model mode.
useEffect(() => { useEffect(() => {
if (hasFetchedDetail && !modelModeType) { if (hasFetchedDetail && !modelModeType) {
......
...@@ -153,20 +153,6 @@ export const useTextGenerationCurrentProviderAndModelAndModelList = (defaultMode ...@@ -153,20 +153,6 @@ export const useTextGenerationCurrentProviderAndModelAndModelList = (defaultMode
} }
} }
export const useAgentThoughtCurrentProviderAndModelAndModelList = (defaultModel?: DefaultModel) => {
const { agentThoughtModelList } = useProviderContext()
const {
currentProvider,
currentModel,
} = useCurrentProviderAndModel(agentThoughtModelList, defaultModel)
return {
currentProvider,
currentModel,
agentThoughtModelList,
}
}
export const useModelListAndDefaultModel = (type: ModelTypeIndex) => { export const useModelListAndDefaultModel = (type: ModelTypeIndex) => {
const { data: modelList } = useModelList(type) const { data: modelList } = useModelList(type)
const { data: defaultModel } = useDefaultModel(type) const { data: defaultModel } = useDefaultModel(type)
......
...@@ -126,6 +126,19 @@ const EditCustomCollectionModal: FC<Props> = ({ ...@@ -126,6 +126,19 @@ const EditCustomCollectionModal: FC<Props> = ({
}) })
} }
const getPath = (url: string) => {
if (!url)
return ''
try {
const path = new URL(url).pathname
return path || ''
}
catch (e) {
return url
}
}
return ( return (
<> <>
<Drawer <Drawer
...@@ -202,7 +215,7 @@ const EditCustomCollectionModal: FC<Props> = ({ ...@@ -202,7 +215,7 @@ const EditCustomCollectionModal: FC<Props> = ({
<td className="p-2 pl-3">{item.operation_id}</td> <td className="p-2 pl-3">{item.operation_id}</td>
<td className="p-2 pl-3 text-gray-500 w-[236px]">{item.summary}</td> <td className="p-2 pl-3 text-gray-500 w-[236px]">{item.summary}</td>
<td className="p-2 pl-3">{item.method}</td> <td className="p-2 pl-3">{item.method}</td>
<td className="p-2 pl-3">{item.server_url ? new URL(item.server_url).pathname : ''}</td> <td className="p-2 pl-3">{getPath(item.server_url)}</td>
<td className="p-2 pl-3 w-[62px]"> <td className="p-2 pl-3 w-[62px]">
<Button <Button
className='!h-6 !px-2 text-xs font-medium text-gray-700 whitespace-nowrap' className='!h-6 !px-2 text-xs font-medium text-gray-700 whitespace-nowrap'
......
...@@ -139,8 +139,6 @@ export const DEFAULT_AGENT_SETTING = { ...@@ -139,8 +139,6 @@ export const DEFAULT_AGENT_SETTING = {
tools: [], tools: [],
} }
export const supportFunctionCallModels = ['glm-3-turbo', 'glm-4']
export const DEFAULT_AGENT_PROMPT = { export const DEFAULT_AGENT_PROMPT = {
chat: `Respond to the human as helpfully and accurately as possible. chat: `Respond to the human as helpfully and accurately as possible.
......
...@@ -2,14 +2,13 @@ ...@@ -2,14 +2,13 @@
import { createContext, useContext } from 'use-context-selector' import { createContext, useContext } from 'use-context-selector'
import useSWR from 'swr' import useSWR from 'swr'
import { useEffect, useMemo, useState } from 'react' import { useEffect, useState } from 'react'
import { import {
fetchModelList, fetchModelList,
fetchModelProviders, fetchModelProviders,
fetchSupportRetrievalMethods, fetchSupportRetrievalMethods,
} from '@/service/common' } from '@/service/common'
import { import {
ModelFeatureEnum,
ModelStatusEnum, ModelStatusEnum,
ModelTypeEnum, ModelTypeEnum,
} from '@/app/components/header/account-setting/model-provider-page/declarations' } from '@/app/components/header/account-setting/model-provider-page/declarations'
...@@ -23,7 +22,6 @@ import { defaultPlan } from '@/app/components/billing/config' ...@@ -23,7 +22,6 @@ import { defaultPlan } from '@/app/components/billing/config'
const ProviderContext = createContext<{ const ProviderContext = createContext<{
modelProviders: ModelProvider[] modelProviders: ModelProvider[]
textGenerationModelList: Model[] textGenerationModelList: Model[]
agentThoughtModelList: Model[]
supportRetrievalMethods: RETRIEVE_METHOD[] supportRetrievalMethods: RETRIEVE_METHOD[]
hasSettedApiKey: boolean hasSettedApiKey: boolean
plan: { plan: {
...@@ -38,7 +36,6 @@ const ProviderContext = createContext<{ ...@@ -38,7 +36,6 @@ const ProviderContext = createContext<{
}>({ }>({
modelProviders: [], modelProviders: [],
textGenerationModelList: [], textGenerationModelList: [],
agentThoughtModelList: [],
supportRetrievalMethods: [], supportRetrievalMethods: [],
hasSettedApiKey: true, hasSettedApiKey: true,
plan: { plan: {
...@@ -75,26 +72,6 @@ export const ProviderContextProvider = ({ ...@@ -75,26 +72,6 @@ export const ProviderContextProvider = ({
const { data: textGenerationModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelTypeEnum.textGeneration}`, fetchModelList) const { data: textGenerationModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelTypeEnum.textGeneration}`, fetchModelList)
const { data: supportRetrievalMethods } = useSWR('/datasets/retrieval-setting', fetchSupportRetrievalMethods) const { data: supportRetrievalMethods } = useSWR('/datasets/retrieval-setting', fetchSupportRetrievalMethods)
const agentThoughtModelList = useMemo(() => {
const result: Model[] = []
if (textGenerationModelList?.data) {
textGenerationModelList?.data.forEach((item) => {
const agentThoughtModels = item.models.filter(model => model.features?.includes(ModelFeatureEnum.agentThought))
if (agentThoughtModels.length) {
result.push({
...item,
models: agentThoughtModels,
})
}
})
return result
}
return []
}, [textGenerationModelList])
const [plan, setPlan] = useState(defaultPlan) const [plan, setPlan] = useState(defaultPlan)
const [isFetchedPlan, setIsFetchedPlan] = useState(false) const [isFetchedPlan, setIsFetchedPlan] = useState(false)
const [enableBilling, setEnableBilling] = useState(true) const [enableBilling, setEnableBilling] = useState(true)
...@@ -118,7 +95,6 @@ export const ProviderContextProvider = ({ ...@@ -118,7 +95,6 @@ export const ProviderContextProvider = ({
<ProviderContext.Provider value={{ <ProviderContext.Provider value={{
modelProviders: providersData?.data || [], modelProviders: providersData?.data || [],
textGenerationModelList: textGenerationModelList?.data || [], textGenerationModelList: textGenerationModelList?.data || [],
agentThoughtModelList,
hasSettedApiKey: !!textGenerationModelList?.data.some(model => model.status === ModelStatusEnum.active), hasSettedApiKey: !!textGenerationModelList?.data.some(model => model.status === ModelStatusEnum.active),
supportRetrievalMethods: supportRetrievalMethods?.retrieval_method || [], supportRetrievalMethods: supportRetrievalMethods?.retrieval_method || [],
plan, plan,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment