Commit ef861e07 authored by takatost's avatar takatost

Merge branch 'feat/workflow-backend' into deploy/dev

# Conflicts:
#	api/controllers/console/app/app.py
#	api/controllers/console/app/workflow.py
#	api/core/agent/base_agent_runner.py
#	api/core/app/app_config/easy_ui_based_app/dataset/manager.py
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
#	api/core/app/apps/agent_chat/app_config_manager.py
#	api/core/app/apps/agent_chat/app_generator.py
#	api/core/app/apps/base_app_queue_manager.py
#	api/core/app/apps/chat/app_generator.py
#	api/core/app/apps/completion/app_config_manager.py
#	api/core/app/apps/completion/app_generator.py
#	api/core/app/apps/completion/app_runner.py
#	api/core/app/apps/message_based_app_generator.py
#	api/core/app/apps/message_based_app_queue_manager.py
#	api/core/app/apps/workflow/app_generator.py
#	api/core/app/apps/workflow/app_queue_manager.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/apps/workflow/workflow_event_trigger_callback.py
#	api/core/app/entities/queue_entities.py
#	api/core/workflow/callbacks/base_workflow_callback.py
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/entities/workflow_entities.py
#	api/core/workflow/nodes/base_node.py
#	api/core/workflow/nodes/code/code_node.py
#	api/core/workflow/nodes/direct_answer/direct_answer_node.py
#	api/core/workflow/nodes/end/end_node.py
#	api/core/workflow/nodes/http_request/http_request_node.py
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/nodes/start/start_node.py
#	api/core/workflow/nodes/template_transform/template_transform_node.py
#	api/core/workflow/nodes/tool/tool_node.py
#	api/core/workflow/workflow_engine_manager.py
#	api/models/model.py
#	api/models/workflow.py
#	api/services/completion_service.py
#	api/services/workflow_service.py
parents f68b6c1f bbc76cb8
name: Run Pytest
on:
  pull_request:
    branches:
      - main
      - deploy/dev
jobs:
  test:
    runs-on: ubuntu-latest
    env:
      MOCK_SWITCH: true
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
          cache: 'pip'
          cache-dependency-path: ./api/requirements.txt
      - name: Install dependencies
        run: pip install -r ./api/requirements.txt
      - name: Run pytest
        run: pytest api/tests/integration_tests/workflow
\ No newline at end of file
......@@ -132,3 +132,7 @@ SSRF_PROXY_HTTP_URL=
SSRF_PROXY_HTTPS_URL=
BATCH_UPLOAD_LIMIT=10
# CODE EXECUTION CONFIGURATION
CODE_EXECUTION_ENDPOINT=
CODE_EXECUTION_API_KEY=
......@@ -27,6 +27,7 @@ DEFAULTS = {
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
'DEPLOY_ENV': 'PRODUCTION',
'SQLALCHEMY_POOL_SIZE': 30,
'SQLALCHEMY_MAX_OVERFLOW': 10,
'SQLALCHEMY_POOL_RECYCLE': 3600,
'SQLALCHEMY_ECHO': 'False',
'SENTRY_TRACES_SAMPLE_RATE': 1.0,
......@@ -59,7 +60,9 @@ DEFAULTS = {
'CAN_REPLACE_LOGO': 'False',
'ETL_TYPE': 'dify',
'KEYWORD_STORE': 'jieba',
'BATCH_UPLOAD_LIMIT': 20
'BATCH_UPLOAD_LIMIT': 20,
'CODE_EXECUTION_ENDPOINT': '',
'CODE_EXECUTION_API_KEY': ''
}
......@@ -146,6 +149,7 @@ class Config:
self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}{db_extras}"
self.SQLALCHEMY_ENGINE_OPTIONS = {
'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
'max_overflow': int(get_env('SQLALCHEMY_MAX_OVERFLOW')),
'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
}
......@@ -293,6 +297,9 @@ class Config:
self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')
self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
......
import json
from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, BadRequest
......@@ -6,6 +8,8 @@ from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.agent.entities import AgentToolEntity
from extensions.ext_database import db
from fields.app_fields import (
app_detail_fields,
app_detail_fields_with_site,
......@@ -13,6 +17,9 @@ from fields.app_fields import (
)
from libs.login import login_required
from services.app_service import AppService
from models.model import App, AppModelConfig, AppMode
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.tool_manager import ToolManager
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow']
......@@ -102,6 +109,40 @@ class AppApi(Resource):
@marshal_with(app_detail_fields_with_site)
def get(self, app_model):
"""Get app detail"""
# get original app model config
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
model_config: AppModelConfig = app_model.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}
# override tool parameters
tool['tool_parameters'] = masked_parameter
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
db.session.commit()
return app_model
@setup_required
......
import json
from flask import request
from flask_login import current_user
......@@ -7,6 +8,9 @@ from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.login import login_required
......@@ -34,6 +38,83 @@ class ModelConfigResource(Resource):
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
# get original app model config
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
AppModelConfig.id == app_model.app_model_config_id
).first()
agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
parameter_map = {}
masked_parameter_map = {}
tool_map = {}
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
parameters = {}
masked_parameter = {}
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters
tool_map[key] = tool_runtime
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
if key in tool_map:
tool_runtime = tool_map[key]
else:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
manager.delete_tool_parameters_cache()
# override parameters if it equals to masked parameters
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key]
# encrypt parameters
if agent_tool_entity.tool_parameters:
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
# update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode)
db.session.add(new_app_model_config)
db.session.flush()
......
......@@ -147,9 +147,12 @@ class WorkflowTaskStopApi(Resource):
"""
Stop workflow task
"""
# TODO
workflow_service = WorkflowService()
workflow_service.stop_workflow_task(app_model=app_model, task_id=task_id, account=current_user)
workflow_service.stop_workflow_task(
task_id=task_id,
user=current_user,
invoke_from=InvokeFrom.DEBUGGER
)
return {
"result": "success"
......
......@@ -2,7 +2,6 @@ import json
import logging
import uuid
from datetime import datetime
from mimetypes import guess_extension
from typing import Optional, Union, cast
from core.agent.entities import AgentEntity, AgentToolEntity
......@@ -39,7 +38,6 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from models.model import Message, MessageAgentThought, MessageFile
......@@ -113,6 +111,7 @@ class BaseAgentRunner(AppRunner):
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
MessageAgentThought.message_id == self.message.id,
).count()
db.session.close()
# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
......@@ -154,9 +153,9 @@ class BaseAgentRunner(AppRunner):
"""
convert tool to prompt message tool
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
tenant_id=self.app_config.tenant_id,
tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
agent_tool=tool,
agent_callback=self.agent_callback
)
tool_entity.load_variables(self.variables_pool)
......@@ -171,33 +170,11 @@ class BaseAgentRunner(AppRunner):
}
)
runtime_parameters = {}
parameters = tool_entity.parameters or []
user_parameters = tool_entity.get_runtime_parameters() or []
# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
found = False
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
found = True
break
if found:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
else:
# add new parameter
parameters.append(parameter)
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = 'string'
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
......@@ -213,59 +190,16 @@ class BaseAgentRunner(AppRunner):
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# get tool parameter from form
tool_parameter_config = tool.tool_parameters.get(parameter.name)
if not tool_parameter_config:
# get default value
tool_parameter_config = parameter.default
if not tool_parameter_config and parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
if parameter.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter.options))
if tool_parameter_config not in options:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
# convert tool parameter config to correct type
try:
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
# check if tool parameter is integer
if isinstance(tool_parameter_config, int):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, float):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, str):
if '.' in tool_parameter_config:
tool_parameter_config = float(tool_parameter_config)
else:
tool_parameter_config = int(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
tool_parameter_config = bool(tool_parameter_config)
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
tool_parameter_config = str(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType:
tool_parameter_config = str(tool_parameter_config)
except Exception as e:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
# save tool parameter to tool entity memory
runtime_parameters[parameter.name] = tool_parameter_config
elif parameter.form == ToolParameter.ToolParameterForm.LLM:
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
if parameter.required:
message_tool.parameters['required'].append(parameter.name)
if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
if parameter.required:
message_tool.parameters['required'].append(parameter.name)
return message_tool, tool_entity
......@@ -305,6 +239,9 @@ class BaseAgentRunner(AppRunner):
tool_runtime_parameters = tool.get_runtime_parameters() or []
for parameter in tool_runtime_parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = 'string'
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
......@@ -320,18 +257,17 @@ class BaseAgentRunner(AppRunner):
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParameter.ToolParameterForm.LLM:
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool
......@@ -404,12 +340,15 @@ class BaseAgentRunner(AppRunner):
created_by=self.user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)
result.append((
message_file,
message.save_as
))
db.session.commit()
db.session.close()
return result
......@@ -447,6 +386,8 @@ class BaseAgentRunner(AppRunner):
db.session.add(thought)
db.session.commit()
db.session.refresh(thought)
db.session.close()
self.agent_thought_count += 1
......@@ -464,6 +405,10 @@ class BaseAgentRunner(AppRunner):
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).filter(
MessageAgentThought.id == agent_thought.id
).first()
if thought is not None:
agent_thought.thought = thought
......@@ -514,81 +459,20 @@ class BaseAgentRunner(AppRunner):
agent_thought.tool_labels_str = json.dumps(labels)
db.session.commit()
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
"""
Transform tool message into agent thought
"""
result = []
for message in messages:
if message.type == ToolInvokeMessage.MessageType.TEXT:
result.append(message)
elif message.type == ToolInvokeMessage.MessageType.LINK:
result.append(message)
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
# try to download image
try:
file = ToolFileManager.create_file_by_url(user_id=self.user_id, tenant_id=self.tenant_id,
conversation_id=self.message.conversation_id,
file_url=message.message)
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
except Exception as e:
logger.exception(e)
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
meta=message.meta.copy() if message.meta is not None else {},
save_as=message.save_as,
))
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
mimetype = message.meta.get('mime_type', 'octet/stream')
# if message is str, encode it to bytes
if isinstance(message.message, str):
message.message = message.message.encode('utf-8')
file = ToolFileManager.create_file_by_raw(user_id=self.user_id, tenant_id=self.tenant_id,
conversation_id=self.message.conversation_id,
file_binary=message.message,
mimetype=mimetype)
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
# check if file is image
if 'image' in mimetype:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
else:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
else:
result.append(message)
return result
db.session.close()
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
"""
convert tool variables to db variables
"""
db_variables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
).first()
db_variables.updated_at = datetime.utcnow()
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
db.session.close()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
......@@ -644,4 +528,6 @@ class BaseAgentRunner(AppRunner):
if message.answer:
result.append(AssistantPromptMessage(content=message.answer))
db.session.close()
return result
\ No newline at end of file
......@@ -25,6 +25,7 @@ from core.tools.errors import (
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from models.model import Conversation, Message
......@@ -280,7 +281,12 @@ class CotAgentRunner(BaseAgentRunner):
tool_parameters=tool_call_args
)
# transform tool response to llm friendly response
tool_response = self.transform_tool_invoke_messages(tool_response)
tool_response = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=tool_response,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=self.message.conversation_id
)
# extract binary data from tool invoke message
binary_files = self.extract_tool_response_binary(tool_response)
# create message file
......
......@@ -23,6 +23,7 @@ from core.tools.errors import (
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from models.model import Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
......@@ -270,7 +271,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_parameters=tool_call_args,
)
# transform tool invoke message to get LLM friendly message
tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message)
tool_invoke_message = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=tool_invoke_message,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=self.message.conversation_id
)
# extract binary data from tool invoke message
binary_files = self.extract_tool_response_binary(tool_invoke_message)
# create message file
......
......@@ -123,7 +123,8 @@ class DatasetConfigManager:
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get("datasets")
need_manual_query_datasets = (config.get("dataset_configs")
and config["dataset_configs"].get("datasets", {}).get("datasets"))
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion
......
## Guidelines for Database Connection Management in App Runner and Task Pipeline
App Runner contains long-running tasks such as LLM generation and external requests, while Flask-SQLAlchemy's connection-pooling strategy allocates one connection (transaction) per request. That connection stays occupied even while non-DB work is running, so under high concurrency many long-running requests can exhaust the pool and new connections can no longer be acquired.
Therefore, database operations in App Runner and Task Pipeline must close connections immediately after use, and it is better to pass IDs rather than Model objects across task boundaries to avoid `DetachedInstanceError` (see the sketch after example 3 below).
Examples:
1. Creating a new record:
```python
app = App(id=1)
db.session.add(app)
db.session.commit()
db.session.refresh(app)  # Load server-generated defaults (e.g. created_at) into the app object; the cached values remain readable after close
# Handle short (non-long-running) tasks here, or copy the App fields you need into local variables first.
db.session.close()
return app.id
```
2. Fetching a record from the table:
```python
app = db.session.query(App).filter(App.id == app_id).first()
created_at = app.created_at
db.session.close()
# Handle tasks (including long-running ones).
```
3. Updating a table field:
```python
app = db.session.query(App).filter(App.id == app_id).first()
app.updated_at = datetime.utcnow()
db.session.commit()
db.session.close()
return app_id
```
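4. Passing an ID (not a Model object) across a long-running task — a minimal sketch of the rule above; `run_long_task` and `process_app` are illustrative placeholders, not part of this commit:
```python
import time
from datetime import datetime

from extensions.ext_database import db
from models.model import App


def run_long_task(app_name: str) -> None:
    # Stand-in for LLM generation or an external request.
    time.sleep(30)


def process_app(app_id: str) -> str:
    # Read the fields we need, then release the connection immediately.
    app = db.session.query(App).filter(App.id == app_id).first()
    app_name = app.name
    db.session.close()

    run_long_task(app_name)  # no DB connection is held during the long task

    # Re-query by ID afterwards instead of reusing the detached object.
    app = db.session.query(App).filter(App.id == app_id).first()
    app.updated_at = datetime.utcnow()
    db.session.commit()
    db.session.close()
    return app_id
```
Re-querying by primary key after the long task is cheap and avoids touching an instance whose session has already been closed.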
......@@ -11,7 +11,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
......@@ -123,11 +123,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
worker_thread.start()
# return response or stream generator
return self._handle_response(
return self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=stream
)
......@@ -159,7 +161,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message
)
except ConversationTaskStoppedException:
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
......@@ -175,37 +177,42 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
def _handle_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
stream: bool = False) -> Union[dict, Generator]:
db.session.close()
def _handle_advanced_chat_response(self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False) -> Union[dict, Generator]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param conversation: conversation
:param message: message
:param user: account or end user
:param stream: is stream
:return:
"""
# init generate task pipeline
generate_task_pipeline = AdvancedChatAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
conversation=conversation,
message=message
message=message,
user=user,
stream=stream
)
try:
return generate_task_pipeline.process(stream=stream)
return generate_task_pipeline.process()
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise ConversationTaskStoppedException()
raise GenerateTaskStoppedException()
else:
logger.exception(e)
raise e
finally:
db.session.remove()
import logging
import time
from typing import cast
from typing import Optional, cast
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
......@@ -13,11 +13,11 @@ from core.app.entities.app_invoke_entities import (
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.moderation.base import ModerationException
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import WorkflowRunTriggeredFrom
from models.model import App, Conversation, Message
from models.workflow import Workflow
logger = logging.getLogger(__name__)
......@@ -46,7 +46,7 @@ class AdvancedChatAppRunner(AppRunner):
if not app_record:
raise ValueError("App not found")
workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
......@@ -74,19 +74,16 @@ class AdvancedChatAppRunner(AppRunner):
):
return
# fetch user
if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]:
user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first()
else:
user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
db.session.close()
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow=workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
user=user,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
user_inputs=inputs,
system_inputs={
SystemVariable.QUERY: query,
......@@ -99,6 +96,20 @@ class AdvancedChatAppRunner(AppRunner):
)]
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id
).first()
# return workflow
return workflow
def handle_input_moderation(self, queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
......
......@@ -4,9 +4,10 @@ import time
from collections.abc import Generator
from typing import Optional, Union
from pydantic import BaseModel
from pydantic import BaseModel, Extra
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.workflow_based_generate_task_pipeline import WorkflowBasedGenerateTaskPipeline
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
......@@ -16,25 +17,35 @@ from core.app.entities.queue_entities import (
QueueErrorEvent,
QueueMessageFileEvent,
QueueMessageReplaceEvent,
QueueNodeFinishedEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFinishedEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Conversation, Message, MessageFile
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus
from models.account import Account
from models.model import Conversation, EndUser, Message, MessageFile
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowRun,
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
)
from services.annotation_service import AppAnnotationService
logger = logging.getLogger(__name__)
......@@ -44,44 +55,83 @@ class TaskState(BaseModel):
"""
TaskState entity
"""
class NodeExecutionInfo(BaseModel):
"""
NodeExecutionInfo entity
"""
workflow_node_execution_id: str
start_at: float
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
answer: str = ""
metadata: dict = {}
usage: LLMUsage
workflow_run_id: Optional[str] = None
start_at: Optional[float] = None
total_tokens: int = 0
total_steps: int = 0
running_node_execution_infos: dict[str, NodeExecutionInfo] = {}
latest_node_execution_info: Optional[NodeExecutionInfo] = None
class Config:
"""Configuration for this pydantic object."""
class AdvancedChatAppGenerateTaskPipeline:
extra = Extra.forbid
arbitrary_types_allowed = True
class AdvancedChatAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
"""
AdvancedChatAppGenerateTaskPipeline is a class that generates streaming output and manages state for the application.
"""
def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message) -> None:
message: Message,
user: Union[Account, EndUser],
stream: bool) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param conversation: conversation
:param message: message
:param user: user
:param stream: stream
"""
self._application_generate_entity = application_generate_entity
self._workflow = workflow
self._queue_manager = queue_manager
self._conversation = conversation
self._message = message
self._user = user
self._task_state = TaskState(
usage=LLMUsage.empty_usage()
)
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream
def process(self, stream: bool) -> Union[dict, Generator]:
def process(self) -> Union[dict, Generator]:
"""
Process generate task pipeline.
:return:
"""
if stream:
db.session.refresh(self._workflow)
db.session.refresh(self._user)
db.session.close()
if self._stream:
return self._process_stream_response()
else:
return self._process_blocking_response()
......@@ -112,22 +162,16 @@ class AdvancedChatAppGenerateTaskPipeline:
self._task_state.answer = annotation.content
elif isinstance(event, QueueWorkflowStartedEvent):
self._task_state.workflow_run_id = event.workflow_run_id
elif isinstance(event, QueueNodeFinishedEvent):
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
if workflow_node_execution.node_type == NodeType.LLM.value:
outputs = workflow_node_execution.outputs_dict
usage_dict = outputs.get('usage', {})
self._task_state.metadata['usage'] = usage_dict
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
if isinstance(event, QueueWorkflowFinishedEvent):
workflow_run = self._get_workflow_run(event.workflow_run_id)
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs
self._task_state.answer = outputs.get('text', '')
else:
raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
self._on_workflow_start()
elif isinstance(event, QueueNodeStartedEvent):
self._on_node_start(event)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
self._on_node_finished(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._on_workflow_finished(event)
if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value:
raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
# response moderation
if self._output_moderation_handler:
......@@ -173,8 +217,8 @@ class AdvancedChatAppGenerateTaskPipeline:
yield self._yield_response(data)
break
elif isinstance(event, QueueWorkflowStartedEvent):
workflow_run = self._get_workflow_run(event.workflow_run_id)
self._task_state.workflow_run_id = workflow_run.id
workflow_run = self._on_workflow_start()
response = {
'event': 'workflow_started',
'task_id': self._application_generate_entity.task_id,
......@@ -188,7 +232,8 @@ class AdvancedChatAppGenerateTaskPipeline:
yield self._yield_response(response)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
workflow_node_execution = self._on_node_start(event)
response = {
'event': 'node_started',
'task_id': self._application_generate_entity.task_id,
......@@ -204,8 +249,9 @@ class AdvancedChatAppGenerateTaskPipeline:
}
yield self._yield_response(response)
elif isinstance(event, QueueNodeFinishedEvent):
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._on_node_finished(event)
if workflow_node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED.value:
if workflow_node_execution.node_type == NodeType.LLM.value:
outputs = workflow_node_execution.outputs_dict
......@@ -234,16 +280,10 @@ class AdvancedChatAppGenerateTaskPipeline:
}
yield self._yield_response(response)
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
if isinstance(event, QueueStopEvent):
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id)
else:
workflow_run = self._get_workflow_run(event.workflow_run_id)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._on_workflow_finished(event)
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs_dict
self._task_state.answer = outputs.get('text', '')
else:
if workflow_run.status != WorkflowRunStatus.SUCCEEDED.value:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
data = self._error_to_stream_response_data(self._handle_error(err_event))
yield self._yield_response(data)
......@@ -252,7 +292,7 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_run_response = {
'event': 'workflow_finished',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': event.workflow_run_id,
'workflow_run_id': workflow_run.id,
'data': {
'id': workflow_run.id,
'workflow_id': workflow_run.workflow_id,
......@@ -390,35 +430,137 @@ class AdvancedChatAppGenerateTaskPipeline:
else:
continue
def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Get workflow run.
:param workflow_run_id: workflow run id
:return:
"""
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
if workflow_run:
# Because the workflow_run will be modified in the sub-thread,
# and the first query in the main thread will cache the entity,
# you need to expire the entity after the query
db.session.expire(workflow_run)
def _on_workflow_start(self) -> WorkflowRun:
self._task_state.start_at = time.perf_counter()
workflow_run = self._init_workflow_run(
workflow=self._workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN,
user=self._user,
user_inputs=self._application_generate_entity.inputs,
system_inputs={
SystemVariable.QUERY: self._message.query,
SystemVariable.FILES: self._application_generate_entity.files,
SystemVariable.CONVERSATION: self._conversation.id,
}
)
self._task_state.workflow_run_id = workflow_run.id
db.session.close()
return workflow_run
def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
"""
Get workflow node execution.
:param workflow_node_execution_id: workflow node execution id
:return:
"""
workflow_node_execution = (db.session.query(WorkflowNodeExecution)
.filter(WorkflowNodeExecution.id == workflow_node_execution_id).first())
if workflow_node_execution:
# Because the workflow_node_execution will be modified in the sub-thread,
# and the first query in the main thread will cache the entity,
# you need to expire the entity after the query
db.session.expire(workflow_node_execution)
def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_node_execution_from_workflow_run(
workflow_run=workflow_run,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_data.title,
node_run_index=event.node_run_index,
predecessor_node_id=event.predecessor_node_id
)
latest_node_execution_info = TaskState.NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
start_at=time.perf_counter()
)
self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._task_state.total_steps += 1
db.session.close()
return workflow_node_execution
def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
if isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs,
execution_metadata=event.execution_metadata
)
if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
self._task_state.total_tokens += (
int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
if workflow_node_execution.node_type == NodeType.LLM.value:
outputs = workflow_node_execution.outputs_dict
usage_dict = outputs.get('usage', {})
self._task_state.metadata['usage'] = usage_dict
else:
workflow_node_execution = self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
error=event.error
)
# remove running node execution info
del self._task_state.running_node_execution_infos[event.node_id]
db.session.close()
return workflow_node_execution
def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
-> WorkflowRun:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
if isinstance(event, QueueStopEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.STOPPED,
error='Workflow stopped.'
)
elif isinstance(event, QueueWorkflowFailedEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.FAILED,
error=event.error
)
else:
if self._task_state.latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
outputs = workflow_node_execution.outputs
else:
outputs = None
workflow_run = self._workflow_run_success(
workflow_run=workflow_run,
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
outputs=outputs
)
self._task_state.workflow_run_id = workflow_run.id
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs_dict
self._task_state.answer = outputs.get('text', '')
db.session.close()
return workflow_run
def _save_message(self) -> None:
"""
Save message.
......
from typing import Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
QueueNodeFinishedEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFinishedEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun
from models.workflow import Workflow
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
......@@ -17,39 +22,91 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
self._queue_manager = queue_manager
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id),
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run finished
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id),
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id),
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str) -> None:
"""
Workflow node execute finished
Workflow node execute failed
"""
self._queue_manager.publish(
QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id),
QueueNodeFailedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
error=error
),
PublishFrom.APPLICATION_MANAGER
)
......
......@@ -52,7 +52,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
else:
config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS:
if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
......
......@@ -11,7 +11,7 @@ from core.app.app_config.easy_ui_based_app.model_config.converter import ModelCo
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
......@@ -177,7 +177,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message
)
except ConversationTaskStoppedException:
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
......@@ -193,4 +193,4 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()
......@@ -201,6 +201,10 @@ class AgentChatAppRunner(AppRunner):
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
message = db.session.query(Message).filter(Message.id == message.id).first()
db.session.close()
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = CotAgentRunner(
......
......@@ -11,11 +11,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueErrorEvent,
QueueMessage,
QueueMessageEndEvent,
QueuePingEvent,
QueueStopEvent,
QueueWorkflowFinishedEvent,
)
from extensions.ext_redis import redis_client
......@@ -103,22 +100,16 @@ class AppQueueManager:
:return:
"""
self._check_for_sqlalchemy_models(event.dict())
message = self.construct_queue_message(event)
self._q.put(message)
if isinstance(event, QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowFinishedEvent):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise ConversationTaskStoppedException()
self._publish(event, pub_from)
@abstractmethod
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
raise NotImplementedError
@classmethod
......@@ -182,5 +173,5 @@ class AppQueueManager:
"that cause thread safety issues is not allowed.")
class ConversationTaskStoppedException(Exception):
class GenerateTaskStoppedException(Exception):
pass
......@@ -9,7 +9,7 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
......@@ -177,7 +177,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message
)
except ConversationTaskStoppedException:
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
......@@ -193,4 +193,4 @@ class ChatAppGenerator(MessageBasedAppGenerator):
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()
......@@ -202,6 +202,8 @@ class ChatAppRunner(AppRunner):
model=application_generate_entity.model_config.model
)
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=application_generate_entity.model_config.parameters,
......
......@@ -37,7 +37,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
else:
config_from = EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG
if override_config_dict != EasyUIBasedAppModelConfigFrom.ARGS:
if config_from != EasyUIBasedAppModelConfigFrom.ARGS:
app_model_config_dict = app_model_config.to_dict()
config_dict = app_model_config_dict.copy()
else:
......
......@@ -9,7 +9,7 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
......@@ -166,7 +166,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
queue_manager=queue_manager,
message=message
)
except ConversationTaskStoppedException:
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
......@@ -182,7 +182,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()
def generate_more_like_this(self, app_model: App,
message_id: str,
......
......@@ -160,6 +160,8 @@ class CompletionAppRunner(AppRunner):
model=application_generate_entity.model_config.model
)
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=application_generate_entity.model_config.parameters,
......
......@@ -100,6 +100,10 @@ class EasyUIBasedGenerateTaskPipeline:
Process generate task pipeline.
:return:
"""
db.session.refresh(self._conversation)
db.session.refresh(self._message)
db.session.close()
if stream:
return self._process_stream_response()
else:
......@@ -314,6 +318,7 @@ class EasyUIBasedGenerateTaskPipeline:
.first()
)
db.session.refresh(agent_thought)
db.session.close()
if agent_thought:
response = {
......@@ -341,6 +346,8 @@ class EasyUIBasedGenerateTaskPipeline:
.filter(MessageFile.id == event.message_file_id)
.first()
)
db.session.close()
# get extension
if '.' in message_file.url:
extension = f'.{message_file.url.split(".")[-1]}'
......@@ -429,6 +436,7 @@ class EasyUIBasedGenerateTaskPipeline:
usage = llm_result.usage
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
self._message.message_tokens = usage.prompt_tokens
......
......@@ -7,7 +7,7 @@ from sqlalchemy import and_
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException
from core.app.apps.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
......@@ -60,12 +60,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return generate_task_pipeline.process(stream=stream)
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise ConversationTaskStoppedException()
raise GenerateTaskStoppedException()
else:
logger.exception(e)
raise e
finally:
db.session.remove()
def _get_conversation_by_user(self, app_model: App, conversation_id: str,
user: Union[Account, EndUser]) -> Conversation:
......@@ -176,6 +174,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.add(conversation)
db.session.commit()
db.session.refresh(conversation)
message = Message(
app_id=app_config.app_id,
......@@ -203,6 +202,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.add(message)
db.session.commit()
db.session.refresh(message)
for file in application_generate_entity.files:
message_file = MessageFile(
......
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueErrorEvent,
QueueMessage,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowSucceededEvent,
)
......@@ -28,3 +33,31 @@ class MessageBasedAppQueueManager(AppQueueManager):
app_mode=self._app_mode,
event=event
)
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
message = MessageQueueMessage(
task_id=self._task_id,
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event
)
self._q.put(message)
if isinstance(event, QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException()
......@@ -9,7 +9,7 @@ from pydantic import ValidationError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, ConversationTaskStoppedException, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner
......@@ -95,7 +95,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
# return response or stream generator
return self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream
)
......@@ -117,7 +119,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
queue_manager=queue_manager
)
except ConversationTaskStoppedException:
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
......@@ -136,19 +138,25 @@ class WorkflowAppGenerator(BaseAppGenerator):
db.session.remove()
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool = False) -> Union[dict, Generator]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:return:
"""
# init generate task pipeline
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream
)
......@@ -156,9 +164,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
return generate_task_pipeline.process()
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise ConversationTaskStoppedException()
raise GenerateTaskStoppedException()
else:
logger.exception(e)
raise e
finally:
db.session.remove()
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueMessage,
QueueErrorEvent,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
......@@ -16,9 +20,27 @@ class WorkflowAppQueueManager(AppQueueManager):
self._app_mode = app_mode
def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage:
return WorkflowQueueMessage(
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
message = WorkflowQueueMessage(
task_id=self._task_id,
app_mode=self._app_mode,
event=event
)
self._q.put(message)
if isinstance(event, QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException()
import logging
import time
from typing import cast
from typing import Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
......@@ -14,11 +14,11 @@ from core.app.entities.queue_entities import QueueStopEvent, QueueTextChunkEvent
from core.moderation.base import ModerationException
from core.moderation.input_moderation import InputModeration
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
from models.workflow import WorkflowRunTriggeredFrom
from models.model import App
from models.workflow import Workflow
logger = logging.getLogger(__name__)
......@@ -43,7 +43,7 @@ class WorkflowAppRunner:
if not app_record:
raise ValueError("App not found")
workflow = WorkflowEngineManager().get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
......@@ -59,19 +59,16 @@ class WorkflowAppRunner:
):
return
# fetch user
if application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE]:
user = db.session.query(Account).filter(Account.id == application_generate_entity.user_id).first()
else:
user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
db.session.close()
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow=workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN,
user=user,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
user_inputs=inputs,
system_inputs={
SystemVariable.FILES: files
......@@ -82,6 +79,20 @@ class WorkflowAppRunner:
)]
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id
).first()
# return workflow
return workflow
def handle_input_moderation(self, queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: WorkflowAppGenerateEntity,
......
......@@ -4,28 +4,43 @@ import time
from collections.abc import Generator
from typing import Optional, Union
from pydantic import BaseModel
from pydantic import BaseModel, Extra
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.workflow_based_generate_task_pipeline import WorkflowBasedGenerateTaskPipeline
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueErrorEvent,
QueueMessageReplaceEvent,
QueueNodeFinishedEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFinishedEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
from core.workflow.entities.node_entities import NodeRunMetadataKey, SystemVariable
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecution, WorkflowRun, WorkflowRunStatus
from models.account import Account
from models.model import EndUser
from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
)
logger = logging.getLogger(__name__)
......@@ -34,26 +49,59 @@ class TaskState(BaseModel):
"""
TaskState entity
"""
class NodeExecutionInfo(BaseModel):
"""
NodeExecutionInfo entity
"""
workflow_node_execution_id: str
start_at: float
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
answer: str = ""
metadata: dict = {}
workflow_run_id: Optional[str] = None
start_at: Optional[float] = None
total_tokens: int = 0
total_steps: int = 0
running_node_execution_infos: dict[str, NodeExecutionInfo] = {}
latest_node_execution_info: Optional[NodeExecutionInfo] = None
class Config:
"""Configuration for this pydantic object."""
class WorkflowAppGenerateTaskPipeline:
extra = Extra.forbid
arbitrary_types_allowed = True
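The pipeline threads all of its bookkeeping through this pydantic state object. A minimal sketch of how the fields are used across a run (names come from the model above; the execution id is a placeholder):
import time
state = TaskState()
state.start_at = time.perf_counter()
# one entry per currently running node, keyed by node_id
info = TaskState.NodeExecutionInfo(
    workflow_node_execution_id='exec-id-placeholder',
    start_at=time.perf_counter()
)
state.running_node_execution_infos['node-1'] = info
state.latest_node_execution_info = info
state.total_steps += 1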
class WorkflowAppGenerateTaskPipeline(WorkflowBasedGenerateTaskPipeline):
"""
WorkflowAppGenerateTaskPipeline generates stream output and manages state for workflow applications.
"""
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: user
:param stream: is stream
"""
self._application_generate_entity = application_generate_entity
self._workflow = workflow
self._queue_manager = queue_manager
self._user = user
self._task_state = TaskState()
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
......@@ -64,6 +112,10 @@ class WorkflowAppGenerateTaskPipeline:
Process generate task pipeline.
:return:
"""
db.session.refresh(self._workflow)
db.session.refresh(self._user)
db.session.close()
if self._stream:
return self._process_stream_response()
else:
......@@ -79,17 +131,14 @@ class WorkflowAppGenerateTaskPipeline:
if isinstance(event, QueueErrorEvent):
raise self._handle_error(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
if isinstance(event, QueueStopEvent):
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id)
else:
workflow_run = self._get_workflow_run(event.workflow_run_id)
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs_dict
self._task_state.answer = outputs.get('text', '')
else:
raise self._handle_error(QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')))
elif isinstance(event, QueueWorkflowStartedEvent):
self._on_workflow_start()
elif isinstance(event, QueueNodeStartedEvent):
self._on_node_start(event)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
self._on_node_finished(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._on_workflow_finished(event)
# response moderation
if self._output_moderation_handler:
......@@ -100,10 +149,12 @@ class WorkflowAppGenerateTaskPipeline:
public_event=False
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
response = {
'event': 'workflow_finished',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': event.workflow_run_id,
'workflow_run_id': workflow_run.id,
'data': {
'id': workflow_run.id,
'workflow_id': workflow_run.workflow_id,
......@@ -135,8 +186,8 @@ class WorkflowAppGenerateTaskPipeline:
yield self._yield_response(data)
break
elif isinstance(event, QueueWorkflowStartedEvent):
self._task_state.workflow_run_id = event.workflow_run_id
workflow_run = self._get_workflow_run(event.workflow_run_id)
workflow_run = self._on_workflow_start()
response = {
'event': 'workflow_started',
'task_id': self._application_generate_entity.task_id,
......@@ -150,7 +201,8 @@ class WorkflowAppGenerateTaskPipeline:
yield self._yield_response(response)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
workflow_node_execution = self._on_node_start(event)
response = {
'event': 'node_started',
'task_id': self._application_generate_entity.task_id,
......@@ -166,8 +218,9 @@ class WorkflowAppGenerateTaskPipeline:
}
yield self._yield_response(response)
elif isinstance(event, QueueNodeFinishedEvent):
workflow_node_execution = self._get_workflow_node_execution(event.workflow_node_execution_id)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._on_node_finished(event)
response = {
'event': 'node_finished',
'task_id': self._application_generate_entity.task_id,
......@@ -190,20 +243,8 @@ class WorkflowAppGenerateTaskPipeline:
}
yield self._yield_response(response)
elif isinstance(event, QueueStopEvent | QueueWorkflowFinishedEvent):
if isinstance(event, QueueStopEvent):
workflow_run = self._get_workflow_run(self._task_state.workflow_run_id)
else:
workflow_run = self._get_workflow_run(event.workflow_run_id)
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs_dict
self._task_state.answer = outputs.get('text', '')
else:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
data = self._error_to_stream_response_data(self._handle_error(err_event))
yield self._yield_response(data)
break
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._on_workflow_finished(event)
# response moderation
if self._output_moderation_handler:
......@@ -228,12 +269,12 @@ class WorkflowAppGenerateTaskPipeline:
yield self._yield_response(replace_response)
# save workflow app log
self._save_workflow_app_log()
self._save_workflow_app_log(workflow_run)
workflow_run_response = {
'event': 'workflow_finished',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': event.workflow_run_id,
'workflow_run_id': workflow_run.id,
'data': {
'id': workflow_run.id,
'workflow_id': workflow_run.workflow_id,
......@@ -244,7 +285,7 @@ class WorkflowAppGenerateTaskPipeline:
'total_tokens': workflow_run.total_tokens,
'total_steps': workflow_run.total_steps,
'created_at': int(workflow_run.created_at.timestamp()),
'finished_at': int(workflow_run.finished_at.timestamp())
'finished_at': int(workflow_run.finished_at.timestamp()) if workflow_run.finished_at else None
}
}
......@@ -291,41 +332,158 @@ class WorkflowAppGenerateTaskPipeline:
else:
continue
def _get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Get workflow run.
:param workflow_run_id: workflow run id
:return:
"""
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
if workflow_run:
# Because the workflow_run is modified in the sub-thread and the first
# query in the main thread caches the entity, the entity must be
# expired after the query.
db.session.expire(workflow_run)
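The comment above points at a general SQLAlchemy behavior: the session's identity map caches an entity after the first query, so later reads in the same session can return stale attributes when another thread commits changes. A self-contained illustration of the expire pattern (hypothetical Item model, not from this codebase):
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base
Base = declarative_base()
class Item(Base):
    __tablename__ = 'items'
    id = Column(Integer, primary_key=True)
    status = Column(String, default='running')
engine = create_engine('sqlite://')
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Item(id=1))
    session.commit()
    item = session.query(Item).filter(Item.id == 1).first()
    # Without expire(), this session would keep serving the cached 'running'
    # value even after another session updates the row; expire() forces a
    # re-SELECT on the next attribute access.
    session.expire(item)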
def _on_workflow_start(self) -> WorkflowRun:
self._task_state.start_at = time.perf_counter()
workflow_run = self._init_workflow_run(
workflow=self._workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN,
user=self._user,
user_inputs=self._application_generate_entity.inputs,
system_inputs={
SystemVariable.FILES: self._application_generate_entity.files
}
)
self._task_state.workflow_run_id = workflow_run.id
db.session.close()
return workflow_run
def _get_workflow_node_execution(self, workflow_node_execution_id: str) -> WorkflowNodeExecution:
"""
Get workflow node execution.
:param workflow_node_execution_id: workflow node execution id
:return:
"""
workflow_node_execution = (db.session.query(WorkflowNodeExecution)
.filter(WorkflowNodeExecution.id == workflow_node_execution_id).first())
if workflow_node_execution:
# Because the workflow_node_execution is modified in the sub-thread and
# the first query in the main thread caches the entity, the entity must
# be expired after the query.
db.session.expire(workflow_node_execution)
def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_node_execution_from_workflow_run(
workflow_run=workflow_run,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_data.title,
node_run_index=event.node_run_index,
predecessor_node_id=event.predecessor_node_id
)
latest_node_execution_info = TaskState.NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
start_at=time.perf_counter()
)
self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._task_state.total_steps += 1
db.session.close()
return workflow_node_execution
def _save_workflow_app_log(self) -> None:
def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
if isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs,
execution_metadata=event.execution_metadata
)
if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
self._task_state.total_tokens += (
int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
else:
workflow_node_execution = self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
error=event.error
)
# remove running node execution info
del self._task_state.running_node_execution_infos[event.node_id]
db.session.close()
return workflow_node_execution
def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
-> WorkflowRun:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
if isinstance(event, QueueStopEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.STOPPED,
error='Workflow stopped.'
)
elif isinstance(event, QueueWorkflowFailedEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.FAILED,
error=event.error
)
else:
if self._task_state.latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
outputs = workflow_node_execution.outputs
else:
outputs = None
workflow_run = self._workflow_run_success(
workflow_run=workflow_run,
start_at=self._task_state.start_at,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
outputs=outputs
)
self._task_state.workflow_run_id = workflow_run.id
if workflow_run.status == WorkflowRunStatus.SUCCEEDED.value:
outputs = workflow_run.outputs_dict
self._task_state.answer = outputs.get('text', '')
db.session.close()
return workflow_run
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
"""
Save workflow app log.
:return:
"""
pass # todo
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
elif invoke_from == InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
elif invoke_from == InvokeFrom.WEB_APP:
created_from = WorkflowAppLogCreatedFrom.WEB_APP
else:
# not save log for debugging
return
workflow_app_log = WorkflowAppLog(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
workflow_run_id=workflow_run.id,
created_from=created_from.value,
created_by_role=('account' if isinstance(self._user, Account) else 'end_user'),
created_by=self._user.id,
)
db.session.add(workflow_app_log)
db.session.commit()
db.session.close()
def _handle_chunk(self, text: str) -> dict:
"""
......@@ -398,7 +556,6 @@ class WorkflowAppGenerateTaskPipeline:
return {
'event': 'error',
'task_id': self._application_generate_entity.task_id,
'workflow_run_id': self._task_state.workflow_run_id,
**data
}
......
from typing import Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
QueueNodeFinishedEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFinishedEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowRun
from models.workflow import Workflow
class WorkflowEventTriggerCallback(BaseWorkflowCallback):
......@@ -17,39 +22,91 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
self._queue_manager = queue_manager
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(workflow_run_id=workflow_run.id),
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run finished
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFinishedEvent(workflow_run_id=workflow_run.id),
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(workflow_node_execution_id=workflow_node_execution.id),
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str) -> None:
"""
Workflow node execute finished
Workflow node execute failed
"""
self._queue_manager.publish(
QueueNodeFinishedEvent(workflow_node_execution_id=workflow_node_execution.id),
QueueNodeFailedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
error=error
),
PublishFrom.APPLICATION_MANAGER
)
......
import json
import time
from datetime import datetime
from typing import Optional, Union
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeType
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
from models.workflow import (
CreatedByRole,
Workflow,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
)
class WorkflowBasedGenerateTaskPipeline:
def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
user_inputs: dict,
system_inputs: Optional[dict] = None) -> WorkflowRun:
"""
Init workflow run
:param workflow: Workflow instance
:param triggered_from: triggered from
:param user: account or end user
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:return:
"""
max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
.filter(WorkflowRun.tenant_id == workflow.tenant_id) \
.filter(WorkflowRun.app_id == workflow.app_id) \
.scalar() or 0
new_sequence_number = max_sequence + 1
# init workflow run
workflow_run = WorkflowRun(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
sequence_number=new_sequence_number,
workflow_id=workflow.id,
type=workflow.type,
triggered_from=triggered_from.value,
version=workflow.version,
graph=workflow.graph,
inputs=json.dumps({**user_inputs, **(jsonable_encoder(system_inputs) if system_inputs else {})}),
status=WorkflowRunStatus.RUNNING.value,
created_by_role=(CreatedByRole.ACCOUNT.value
if isinstance(user, Account) else CreatedByRole.END_USER.value),
created_by=user.id
)
db.session.add(workflow_run)
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
return workflow_run
def _workflow_run_success(self, workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Optional[dict] = None) -> WorkflowRun:
"""
Workflow run success
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:return:
"""
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
workflow_run.outputs = outputs
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.utcnow()
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
return workflow_run
def _workflow_run_failed(self, workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
status: WorkflowRunStatus,
error: str) -> WorkflowRun:
"""
Workflow run failed
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param status: status
:param error: error message
:return:
"""
workflow_run.status = status.value
workflow_run.error = error
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.utcnow()
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
return workflow_run
def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
node_id: str,
node_type: NodeType,
node_title: str,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution:
"""
Init workflow node execution from workflow run
:param workflow_run: workflow run
:param node_id: node id
:param node_type: node type
:param node_title: node title
:param node_run_index: run index
:param predecessor_node_id: predecessor node id if exists
:return:
"""
# init workflow node execution
workflow_node_execution = WorkflowNodeExecution(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
workflow_run_id=workflow_run.id,
predecessor_node_id=predecessor_node_id,
index=node_run_index,
node_id=node_id,
node_type=node_type.value,
title=node_title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by
)
db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
return workflow_node_execution
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution:
"""
Workflow node execution success
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param inputs: inputs
:param process_data: process data
:param outputs: outputs
:param execution_metadata: execution metadata
:return:
"""
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
if execution_metadata else None
workflow_node_execution.finished_at = datetime.utcnow()
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
return workflow_node_execution
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
error: str) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param error: error message
:return:
"""
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.finished_at = datetime.utcnow()
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
return workflow_node_execution
from enum import Enum
from typing import Any
from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
class QueueEvent(Enum):
......@@ -16,9 +18,11 @@ class QueueEvent(Enum):
MESSAGE_REPLACE = "message_replace"
MESSAGE_END = "message_end"
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_FINISHED = "workflow_finished"
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_SUCCEEDED = "node_succeeded"
NODE_FAILED = "node_failed"
RETRIEVER_RESOURCES = "retriever_resources"
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
......@@ -96,15 +100,21 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
QueueWorkflowStartedEvent entity
"""
event = QueueEvent.WORKFLOW_STARTED
workflow_run_id: str
class QueueWorkflowFinishedEvent(AppQueueEvent):
class QueueWorkflowSucceededEvent(AppQueueEvent):
"""
QueueWorkflowFinishedEvent entity
QueueWorkflowSucceededEvent entity
"""
event = QueueEvent.WORKFLOW_FINISHED
workflow_run_id: str
event = QueueEvent.WORKFLOW_SUCCEEDED
class QueueWorkflowFailedEvent(AppQueueEvent):
"""
QueueWorkflowFailedEvent entity
"""
event = QueueEvent.WORKFLOW_FAILED
error: str
class QueueNodeStartedEvent(AppQueueEvent):
......@@ -112,17 +122,45 @@ class QueueNodeStartedEvent(AppQueueEvent):
QueueNodeStartedEvent entity
"""
event = QueueEvent.NODE_STARTED
workflow_node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
node_run_index: int = 1
predecessor_node_id: Optional[str] = None
class QueueNodeFinishedEvent(AppQueueEvent):
class QueueNodeSucceededEvent(AppQueueEvent):
"""
QueueNodeFinishedEvent entity
QueueNodeSucceededEvent entity
"""
event = QueueEvent.NODE_FINISHED
workflow_node_execution_id: str
event = QueueEvent.NODE_SUCCEEDED
node_id: str
node_type: NodeType
node_data: BaseNodeData
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
execution_metadata: Optional[dict] = None
error: Optional[str] = None
class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity
"""
event = QueueEvent.NODE_FAILED
node_id: str
node_type: NodeType
node_data: BaseNodeData
error: str
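With the finished events split into success/failure variants, consumers dispatch on the concrete classes. A small sketch in the same style as the queue managers above (Python 3.10 union-type isinstance):
def describe(event: AppQueueEvent) -> str:
    if isinstance(event, QueueWorkflowSucceededEvent):
        return 'workflow succeeded'
    if isinstance(event, QueueWorkflowFailedEvent | QueueNodeFailedEvent):
        # both failure events carry an error message
        return f'failed: {event.error}'
    return event.event.value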
class QueueAgentThoughtEvent(AppQueueEvent):
"""
QueueAgentThoughtEvent entity
......
from os import environ
from typing import Literal, Optional
from httpx import post
from pydantic import BaseModel
from yarl import URL
from core.helper.code_executor.jinja2_transformer import Jinja2TemplateTransformer
from core.helper.code_executor.python_transformer import PythonTemplateTransformer
# Code Executor
CODE_EXECUTION_ENDPOINT = environ.get('CODE_EXECUTION_ENDPOINT', '')
CODE_EXECUTION_API_KEY = environ.get('CODE_EXECUTION_API_KEY', '')
class CodeExecutionException(Exception):
pass
class CodeExecutionResponse(BaseModel):
class Data(BaseModel):
stdout: Optional[str]
error: Optional[str]
code: int
message: str
data: Data
class CodeExecutor:
@classmethod
def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
"""
Execute code
:param language: code language
:param code: code
:param inputs: inputs
:return:
"""
template_transformer = None
if language == 'python3':
template_transformer = PythonTemplateTransformer
elif language == 'jinja2':
template_transformer = Jinja2TemplateTransformer
else:
raise CodeExecutionException('Unsupported language')
runner = template_transformer.transform_caller(code, inputs)
url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run'
headers = {
'X-Api-Key': CODE_EXECUTION_API_KEY
}
data = {
'language': language if language != 'jinja2' else 'python3',
'code': runner,
}
try:
response = post(str(url), json=data, headers=headers)
if response.status_code == 503:
raise CodeExecutionException('Code execution service is unavailable')
elif response.status_code != 200:
raise CodeExecutionException('Failed to execute code')
except CodeExecutionException:
raise
except Exception:
raise CodeExecutionException('Failed to execute code')
try:
response = response.json()
except Exception:
raise CodeExecutionException('Failed to parse response')
response = CodeExecutionResponse(**response)
if response.code != 0:
raise CodeExecutionException(response.message)
if response.data.error:
raise CodeExecutionException(response.data.error)
return template_transformer.transform_response(response.data.stdout)
\ No newline at end of file
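A hedged usage sketch: it assumes a sandbox service is reachable at CODE_EXECUTION_ENDPOINT with a matching CODE_EXECUTION_API_KEY, and that the submitted code follows the main-function contract the transformers below wrap:
code = '''
def main(a, b):
    return {'sum': a + b}
'''
# Raises CodeExecutionException on transport, service, or user-code errors.
result = CodeExecutor.execute_code(language='python3', code=code, inputs={'a': 1, 'b': 2})
# result == {'sum': 3}, parsed back out of the <<RESULT>> markers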
import json
import re
from core.helper.code_executor.template_transformer import TemplateTransformer
PYTHON_RUNNER = """
import jinja2
template = jinja2.Template('''{{code}}''')
def main(**inputs):
return template.render(**inputs)
# execute main function, and return the result
output = main(**{{inputs}})
result = f'''<<RESULT>>{output}<<RESULT>>'''
print(result)
"""
class Jinja2TemplateTransformer(TemplateTransformer):
@classmethod
def transform_caller(cls, code: str, inputs: dict) -> str:
"""
Transform code to python runner
:param code: code
:param inputs: inputs
:return:
"""
# transform jinja2 template to python code
runner = PYTHON_RUNNER.replace('{{code}}', code)
runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4))
return runner
@classmethod
def transform_response(cls, response: str) -> dict:
"""
Transform response to dict
:param response: response
:return:
"""
# extract result
result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
if not result:
raise ValueError('Failed to parse result')
result = result.group(1)
return {
'result': result
}
\ No newline at end of file
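A quick local check of the transformer contract (no sandbox involved; exec of the generated runner stands in for it, and requires jinja2 to be installed):
import io
from contextlib import redirect_stdout
runner = Jinja2TemplateTransformer.transform_caller(
    code='Hello {{ name }}!',
    inputs={'name': 'Dify'}
)
buf = io.StringIO()
with redirect_stdout(buf):
    exec(runner)  # the sandbox would run this script and capture stdout
print(Jinja2TemplateTransformer.transform_response(buf.getvalue()))
# {'result': 'Hello Dify!'}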
import json
import re
from core.helper.code_executor.template_transformer import TemplateTransformer
PYTHON_RUNNER = """# declare main function here
{{code}}
# execute main function, and return the result
# inputs is a dict, and it
output = main(**{{inputs}})
# convert output to json and print
output = json.dumps(output, indent=4)
result = f'''<<RESULT>>
{output}
<<RESULT>>'''
print(result)
"""
class PythonTemplateTransformer(TemplateTransformer):
@classmethod
def transform_caller(cls, code: str, inputs: dict) -> str:
"""
Transform code to python runner
:param code: code
:param inputs: inputs
:return:
"""
# transform inputs to json string
inputs_str = json.dumps(inputs, indent=4)
# replace code and inputs
runner = PYTHON_RUNNER.replace('{{code}}', code)
runner = runner.replace('{{inputs}}', inputs_str)
return runner
@classmethod
def transform_response(cls, response: str) -> dict:
"""
Transform response to dict
:param response: response
:return:
"""
# extract result
result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
if not result:
raise ValueError('Failed to parse result')
result = result.group(1)
return json.loads(result)
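The same local round trip for the Python transformer (exec stands in for the sandbox; note the runner itself relies on the import json added above, since it serializes the return value):
import io
from contextlib import redirect_stdout
runner = PythonTemplateTransformer.transform_caller(
    code='def main(a, b):\n    return {"sum": a + b}',
    inputs={'a': 1, 'b': 2}
)
buf = io.StringIO()
with redirect_stdout(buf):
    exec(runner)
print(PythonTemplateTransformer.transform_response(buf.getvalue()))
# {'sum': 3}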
from abc import ABC, abstractmethod
class TemplateTransformer(ABC):
@classmethod
@abstractmethod
def transform_caller(cls, code: str, inputs: dict) -> str:
"""
Transform code to python runner
:param code: code
:param inputs: inputs
:return:
"""
pass
@classmethod
@abstractmethod
def transform_response(cls, response: str) -> dict:
"""
Transform response to dict
:param response: response
:return:
"""
pass
\ No newline at end of file
......@@ -38,6 +38,10 @@ def patch(url, *args, **kwargs):
return _patch(url=url, *args, proxies=httpx_proxies, **kwargs)
def delete(url, *args, **kwargs):
if 'follow_redirects' in kwargs:
# translate the httpx-style kwarg to requests-style, preserving False as well
kwargs['allow_redirects'] = kwargs.pop('follow_redirects')
return _delete(url=url, *args, proxies=requests_proxies, **kwargs)
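The shim exists because this module forwards to requests-style calls while callers may pass httpx-style kwargs; follow_redirects (httpx) becomes allow_redirects (requests). A sketch of the effect (hypothetical URL):
# httpx-style call site:
delete('https://example.com/resource/1', follow_redirects=True)
# is forwarded as the requests-style:
# _delete(url='https://example.com/resource/1', allow_redirects=True, proxies=requests_proxies)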
def head(url, *args, **kwargs):
......
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolParameterCacheType(Enum):
PARAMETER = "tool_parameter"
class ToolParameterCache:
def __init__(self,
tenant_id: str,
provider: str,
tool_name: str,
cache_type: ToolParameterCacheType
):
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
def get(self) -> Optional[dict]:
"""
Get cached tool parameters.
:return:
"""
cached_tool_parameter = redis_client.get(self.cache_key)
if cached_tool_parameter:
try:
cached_tool_parameter = cached_tool_parameter.decode('utf-8')
cached_tool_parameter = json.loads(cached_tool_parameter)
except JSONDecodeError:
return None
return cached_tool_parameter
else:
return None
def set(self, parameters: dict) -> None:
"""
Cache tool parameters.
:param parameters: tool parameters
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
def delete(self) -> None:
"""
Delete cached tool parameters.
:return:
"""
redis_client.delete(self.cache_key)
\ No newline at end of file
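A hedged usage sketch (requires a reachable Redis behind redis_client; the provider and tool names are placeholders):
cache = ToolParameterCache(
    tenant_id='tenant-1',
    provider='builtin.wecom',
    tool_name='wecom_group_bot',
    cache_type=ToolParameterCacheType.PARAMETER
)
cache.set({'hook_key': 'decrypted-value'})
assert cache.get() == {'hook_key': 'decrypted-value'}  # entries expire after 86400s
cache.delete()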
......@@ -82,6 +82,8 @@ class HostingConfiguration:
RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM),
RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING),
RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING),
RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING),
]
)
quotas.append(trial_quota)
......
......@@ -524,6 +524,46 @@ EMBEDDING_BASE_MODELS = [
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='text-embedding-3-small',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: 8191,
ModelPropertyKey.MAX_CHUNKS: 32,
},
pricing=PriceConfig(
input=0.00002,
unit=0.001,
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='text-embedding-3-large',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: 8191,
ModelPropertyKey.MAX_CHUNKS: 32,
},
pricing=PriceConfig(
input=0.00013,
unit=0.001,
currency='USD',
)
)
)
]
SPEECH2TEXT_BASE_MODELS = [
......
......@@ -100,6 +100,18 @@ model_credential_schema:
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: text-embedding-3-small
value: text-embedding-3-small
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: text-embedding-3-large
value: text-embedding-3-large
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: whisper-1
value: whisper-1
......
......@@ -119,7 +119,7 @@ parameters: # Parameter list
- The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc.
- `parameters` Parameter list
- `name` Parameter name, unique, no duplication with other parameters
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, drop-down box
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` five types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using the `secret-input` type
- `required` Required or not
- In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
- In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts
......
......@@ -119,7 +119,7 @@ parameters: # Parameter list
- The `identity` field is required; it contains the basic information of the tool, including name, author, label, description, etc.
- `parameters` Parameter list
- `name` Parameter name, unique, must not duplicate other parameter names
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, and drop-down box
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` five types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively; for sensitive information, we recommend using the `secret-input` type
- `required` Required or not
- In `llm` mode, if the parameter is required, the Agent must infer this parameter
- In `form` mode, if the parameter is required, the user must fill in this parameter on the frontend before the conversation starts
......
......@@ -100,6 +100,7 @@ class ToolParameter(BaseModel):
NUMBER = "number"
BOOLEAN = "boolean"
SELECT = "select"
SECRET_INPUT = "secret-input"
class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool
......
......@@ -23,6 +23,8 @@ class AIPPTGenerateTool(BuiltinTool):
_api_base_url = URL('https://co.aippt.cn/api')
_api_token_cache = {}
_api_token_cache_lock = Lock()
_style_cache = {}
_style_cache_lock = Lock()
_task = {}
_task_type_map = {
......@@ -390,20 +392,31 @@ class AIPPTGenerateTool(BuiltinTool):
).digest()
).decode('utf-8')
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
@classmethod
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
"""
Get styles
:param credentials: the credentials
:param user_id: the user id
:return: Tuple[list[dict[id, color]], list[dict[id, style]]]
"""
# check cache
with cls._style_cache_lock:
# clear expired styles
now = time()
for key in list(cls._style_cache.keys()):
if cls._style_cache[key]['expire'] < now:
del cls._style_cache[key]
key = f'{credentials["aippt_access_key"]}#@#{user_id}'
if key in cls._style_cache:
return cls._style_cache[key]['colors'], cls._style_cache[key]['styles']
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id)
'x-api-key': credentials['aippt_access_key'],
'x-token': cls._get_api_token(credentials=credentials, user_id=user_id)
}
response = get(
str(self._api_base_url / 'template_component' / 'suit' / 'select'),
str(cls._api_base_url / 'template_component' / 'suit' / 'select'),
headers=headers
)
......@@ -425,7 +438,26 @@ class AIPPTGenerateTool(BuiltinTool):
'name': item.get('title'),
} for item in response.get('data', {}).get('suit_style') or []]
with cls._style_cache_lock:
cls._style_cache[key] = {
'colors': colors,
'styles': styles,
'expire': now + 60 * 60
}
return colors, styles
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
"""
Get styles
:param user_id: the user id
:return: Tuple[list[dict[id, color]], list[dict[id, style]]]
"""
if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'):
return [], []
return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
def _get_suit(self, style_id: int, colour_id: int) -> int:
"""
......
......@@ -14,7 +14,7 @@ description:
llm: A tool for sending messages to a chat group on Wecom (企业微信).
parameters:
- name: hook_key
type: string
type: secret-input
required: true
label:
en_US: Wecom Group bot webhook key
......
......@@ -266,6 +266,40 @@ class Tool(BaseModel, ABC):
"""
return self.parameters
def get_all_runtime_parameters(self) -> list[ToolParameter]:
"""
get all runtime parameters
:return: all runtime parameters
"""
parameters = self.parameters or []
parameters = parameters.copy()
user_parameters = self.get_runtime_parameters() or []
user_parameters = user_parameters.copy()
# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
found = False
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
found = True
break
if found:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
else:
# add new parameter
parameters.append(parameter)
return parameters
def is_tool_available(self) -> bool:
"""
check if the tool is available
......
......@@ -5,12 +5,18 @@ import mimetypes
from os import listdir, path
from typing import Any, Union
from core.agent.entities import AgentToolEntity
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.model_runtime.entities.message_entities import PromptMessage
from core.provider_manager import ProviderManager
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constant import DEFAULT_PROVIDERS
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeMessage,
ToolParameter,
ToolProviderCredentials,
)
from core.tools.entities.user_entities import UserToolProvider
from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
......@@ -21,8 +27,14 @@ from core.tools.provider.model_tool_provider import ModelToolProviderController
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.utils.configuration import ModelToolConfigurationManager, ToolConfiguration
from core.tools.tool.tool import Tool
from core.tools.utils.configuration import (
ModelToolConfigurationManager,
ToolConfigurationManager,
ToolParameterConfigurationManager,
)
from core.tools.utils.encoder import serialize_base_model_dict
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider
......@@ -172,7 +184,7 @@ class ToolManager:
# decrypt the credentials
credentials = builtin_provider.credentials
controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
......@@ -189,7 +201,7 @@ class ToolManager:
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
# decrypt the credentials
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
......@@ -214,6 +226,110 @@ class ToolManager:
else:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod
def _init_runtime_parameter(parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
"""
init runtime parameter
"""
parameter_value = parameters.get(parameter_rule.name)
if not parameter_value:
# get default value
parameter_value = parameter_rule.default
if not parameter_value and parameter_rule.required:
raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter_rule.options))
if parameter_value not in options:
raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
# convert tool parameter config to correct type
try:
if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER:
# ints and floats pass through; strings are parsed, with '.' marking a float
if isinstance(parameter_value, str):
parameter_value = float(parameter_value) if '.' in parameter_value else int(parameter_value)
elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN:
parameter_value = bool(parameter_value)
elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
# other types (e.g. secret-input) are coerced to string
parameter_value = str(parameter_value)
except Exception:
raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type")
return parameter_value
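A worked example of the conversion rules (the real rule objects are ToolParameter pydantic models; a SimpleNamespace stand-in with the same fields is enough for this static method):
from types import SimpleNamespace
rule = SimpleNamespace(
    name='temperature',
    type=ToolParameter.ToolParameterType.NUMBER,
    default=None,
    required=True,
    options=[]
)
value = ToolManager._init_runtime_parameter(rule, {'temperature': '0.7'})
# '0.7' contains a '.', so it is parsed as float: value == 0.7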
@staticmethod
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
"""
get the agent tool runtime
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name,
tenant_id=tenant_id,
agent_callback=agent_callback
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# save tool parameter to tool entity memory
value = ToolManager._init_runtime_parameter(parameter, agent_tool.tool_parameters)
runtime_parameters[parameter.name] = value
# decrypt runtime parameters
encryption_manager = ToolParameterConfigurationManager(
tenant_id=tenant_id,
tool_runtime=tool_entity,
provider_name=agent_tool.provider_id,
provider_type=agent_tool.provider_type,
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@staticmethod
def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity, agent_callback: DifyAgentCallbackHandler):
"""
get the workflow tool runtime
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=workflow_tool.provider_type,
provider_name=workflow_tool.provider_id,
tool_name=workflow_tool.tool_name,
tenant_id=tenant_id,
agent_callback=agent_callback
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
# save tool parameter to tool entity memory
value = ToolManager._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
runtime_parameters[parameter.name] = value
# decrypt runtime parameters
encryption_manager = ToolParameterConfigurationManager(
tenant_id=tenant_id,
tool_runtime=tool_entity,
provider_name=workflow_tool.provider_id,
provider_type=workflow_tool.provider_type,
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@staticmethod
def get_builtin_provider_icon(provider: str) -> tuple[str, str]:
"""
......@@ -396,7 +512,7 @@ class ToolManager:
controller = ToolManager.get_builtin_provider(provider_name)
# init tool configuration
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
......@@ -463,7 +579,7 @@ class ToolManager:
)
# init tool configuration
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
......@@ -523,7 +639,7 @@ class ToolManager:
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
)
# init tool configuration
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
......
......@@ -5,16 +5,19 @@ from pydantic import BaseModel
from yaml import FullLoader, load
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.entities.tool_entities import (
ModelToolConfiguration,
ModelToolProviderConfiguration,
ToolParameter,
ToolProviderCredentials,
)
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.tool import Tool
class ToolConfiguration(BaseModel):
class ToolConfigurationManager(BaseModel):
tenant_id: str
provider_controller: ToolProviderController
......@@ -101,6 +104,128 @@ class ToolConfiguration(BaseModel):
)
cache.delete()
class ToolParameterConfigurationManager(BaseModel):
"""
Tool parameter configuration manager
"""
tenant_id: str
tool_runtime: Tool
provider_name: str
provider_type: str
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
deep copy parameters
"""
return deepcopy(parameters)
def _merge_parameters(self) -> list[ToolParameter]:
"""
merge parameters
"""
# get tool parameters
tool_parameters = self.tool_runtime.parameters or []
# get tool runtime parameters
runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
# override parameters
current_parameters = tool_parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
return current_parameters
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
mask tool parameters
return a deep copy of parameters with masked values
"""
parameters = self._deep_copy(parameters)
# override parameters
current_parameters = self._merge_parameters()
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if parameter.name in parameters:
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = \
parameters[parameter.name][:2] + \
'*' * (len(parameters[parameter.name]) - 4) +\
parameters[parameter.name][-2:]
else:
parameters[parameter.name] = '*' * len(parameters[parameter.name])
return parameters
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
encrypt tool parameters with tenant id
return a deep copy of parameters with encrypted values
"""
# override parameters
current_parameters = self._merge_parameters()
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted
return parameters
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
decrypt tool parameters with tenant id
return a deep copy of parameters with decrypted values
"""
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER
)
cached_parameters = cache.get()
if cached_parameters:
return cached_parameters
parameters = self._deep_copy(parameters)
# override parameters
current_parameters = self._merge_parameters()
has_secret_input = False
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if parameter.name in parameters:
try:
has_secret_input = True
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except Exception:
# leave the value untouched if it cannot be decrypted
pass
if has_secret_input:
cache.set(parameters)
return parameters
def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER
)
cache.delete()
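# A minimal usage sketch (illustrative only; 'dalle' and the parameter names are assumptions):
#   manager = ToolParameterConfigurationManager(tenant_id=tenant_id, tool_runtime=tool,
#                                               provider_name='dalle', provider_type='builtin')
#   stored = manager.encrypt_tool_parameters({'api_key': 'sk-...'})  # persist encrypted values
#   plain = manager.decrypt_tool_parameters(stored)                  # runtime use (cached)
#   shown = manager.mask_tool_parameters(plain)                      # safe to return to the UI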
class ModelToolConfigurationManager:
"""
Model as tool configuration
......
import logging
from mimetypes import guess_extension
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)
class ToolFileMessageTransformer:
@staticmethod
def transform_tool_invoke_messages(messages: list[ToolInvokeMessage],
user_id: str,
tenant_id: str,
conversation_id: str) -> list[ToolInvokeMessage]:
"""
Transform tool message and handle file download
"""
result = []
for message in messages:
if message.type == ToolInvokeMessage.MessageType.TEXT:
result.append(message)
elif message.type == ToolInvokeMessage.MessageType.LINK:
result.append(message)
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
# try to download image
try:
file = ToolFileManager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_url=message.message
)
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
except Exception as e:
logger.exception(e)
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
meta=message.meta.copy() if message.meta is not None else {},
save_as=message.save_as,
))
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
mimetype = message.meta.get('mime_type', 'application/octet-stream')
# if message is str, encode it to bytes
if isinstance(message.message, str):
message.message = message.message.encode('utf-8')
file = ToolFileManager.create_file_by_raw(
user_id=user_id, tenant_id=tenant_id,
conversation_id=conversation_id,
file_binary=message.message,
mimetype=mimetype
)
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".bin"}'
# check if file is image
if 'image' in mimetype:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
else:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
else:
result.append(message)
return result
\ No newline at end of file
from abc import ABC, abstractmethod
from typing import Optional
from models.workflow import WorkflowNodeExecution, WorkflowRun
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
class BaseWorkflowCallback(ABC):
@abstractmethod
def on_workflow_run_started(self, workflow_run: WorkflowRun) -> None:
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
raise NotImplementedError
@abstractmethod
def on_workflow_run_finished(self, workflow_run: WorkflowRun) -> None:
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run finished
Workflow run succeeded
"""
raise NotImplementedError
@abstractmethod
def on_workflow_node_execute_started(self, workflow_node_execution: WorkflowNodeExecution) -> None:
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
raise NotImplementedError
@abstractmethod
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
raise NotImplementedError
@abstractmethod
def on_workflow_node_execute_finished(self, workflow_node_execution: WorkflowNodeExecution) -> None:
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute finished
Workflow node execute succeeded
"""
raise NotImplementedError
@abstractmethod
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str) -> None:
"""
Workflow node execute failed
"""
raise NotImplementedError
......@@ -38,4 +67,3 @@ class BaseWorkflowCallback(ABC):
Publish text chunk
"""
raise NotImplementedError
......@@ -19,14 +19,17 @@ class ValueType(Enum):
class VariablePool:
variables_mapping: dict
user_inputs: dict
def __init__(self, system_variables: dict[SystemVariable, Any]) -> None:
def __init__(self, system_variables: dict[SystemVariable, Any],
user_inputs: dict) -> None:
# system variables
# for example:
# {
# 'query': 'abc',
# 'files': []
# }
# give every pool its own mapping; a shared class-level dict would leak variables across runs
self.variables_mapping = {}
self.user_inputs = user_inputs
for system_variable, value in system_variables.items():
self.append_variable('sys', [system_variable.value], value)
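# system variables live under the 'sys' namespace, e.g. the selector ['sys', 'query']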
......
from typing import Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecution, WorkflowRun
from core.workflow.nodes.base_node import BaseNode, UserFrom
from models.workflow import Workflow, WorkflowNodeExecution, WorkflowType
class WorkflowNodeAndResult:
node: BaseNode
result: Optional[NodeRunResult] = None
def __init__(self, node: BaseNode, result: Optional[NodeRunResult] = None):
self.node = node
self.result = result
class WorkflowRunState:
workflow_run: WorkflowRun
tenant_id: str
app_id: str
workflow_id: str
workflow_type: WorkflowType
user_id: str
user_from: UserFrom
start_at: float
user_inputs: dict
variable_pool: VariablePool
total_tokens: int = 0
workflow_node_executions: list[WorkflowNodeExecution]
workflow_nodes_and_results: list[WorkflowNodeAndResult]
def __init__(self, workflow_run: WorkflowRun,
def __init__(self, workflow: Workflow,
start_at: float,
user_inputs: dict,
variable_pool: VariablePool) -> None:
self.workflow_run = workflow_run
variable_pool: VariablePool,
user_id: str,
user_from: UserFrom):
self.workflow_id = workflow.id
self.tenant_id = workflow.tenant_id
self.app_id = workflow.app_id
self.workflow_type = WorkflowType.value_of(workflow.type)
self.user_id = user_id
self.user_from = user_from
self.start_at = start_at
self.user_inputs = user_inputs
self.variable_pool = variable_pool
self.total_tokens = 0
# fresh per-run lists; class-level list defaults would be shared between runs
self.workflow_node_executions = []
self.workflow_nodes_and_results = []
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
......@@ -8,18 +9,55 @@ from core.workflow.entities.variable_pool import VariablePool
from models.workflow import WorkflowNodeExecutionStatus
class UserFrom(Enum):
"""
User from
"""
ACCOUNT = "account"
END_USER = "end-user"
@classmethod
def value_of(cls, value: str) -> "UserFrom":
"""
Value of
:param value: value
:return:
"""
for item in cls:
if item.value == value:
return item
raise ValueError(f"Invalid value: {value}")
class BaseNode(ABC):
_node_data_cls: type[BaseNodeData]
_node_type: NodeType
tenant_id: str
app_id: str
workflow_id: str
user_id: str
user_from: UserFrom
node_id: str
node_data: BaseNodeData
node_run_result: Optional[NodeRunResult] = None
callbacks: list[BaseWorkflowCallback]
def __init__(self, config: dict,
def __init__(self, tenant_id: str,
app_id: str,
workflow_id: str,
user_id: str,
user_from: UserFrom,
config: dict,
callbacks: list[BaseWorkflowCallback] = None) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
self.workflow_id = workflow_id
self.user_id = user_id
self.user_from = user_from
self.node_id = config.get("id")
if not self.node_id:
raise ValueError("Node ID is required.")
......@@ -28,31 +66,23 @@ class BaseNode(ABC):
self.callbacks = callbacks or []
@abstractmethod
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> NodeRunResult:
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
raise NotImplementedError
def run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> NodeRunResult:
def run(self, variable_pool: VariablePool) -> NodeRunResult:
"""
Run node entry
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
if variable_pool is None and run_args is None:
raise ValueError("At least one of `variable_pool` or `run_args` must be provided.")
try:
result = self._run(
variable_pool=variable_pool,
run_args=run_args
variable_pool=variable_pool
)
except Exception as e:
# process unhandled exception
......@@ -77,6 +107,26 @@ class BaseNode(ABC):
text=text
)
@classmethod
def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict:
"""
Extract variable selector to variable mapping
:param config: node config
:return:
"""
node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping(node_data)
@classmethod
@abstractmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[tuple[str, ...], str]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
raise NotImplementedError
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
......
from typing import Optional
from typing import Optional, Union, cast
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.code.entities import CodeNodeData
from models.workflow import WorkflowNodeExecutionStatus
MAX_NUMBER = 2 ** 63 - 1
MIN_NUMBER = -2 ** 63
MAX_PRECISION = 20
MAX_DEPTH = 5
MAX_STRING_LENGTH = 1000
MAX_STRING_ARRAY_LENGTH = 30
MAX_NUMBER_ARRAY_LENGTH = 1000
class CodeNode(BaseNode):
_node_data_cls = CodeNodeData
node_type = NodeType.CODE
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
......@@ -62,3 +78,175 @@ class CodeNode(BaseNode):
]
}
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
"""
Run code
:param variable_pool: variable pool
:return:
"""
node_data = self.node_data
node_data: CodeNodeData = cast(self._node_data_cls, node_data)
# Get code language
code_language = node_data.code_language
code = node_data.code
# Get variables
variables = {}
for variable_selector in node_data.variables:
variable = variable_selector.variable
value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
variables[variable] = value
# Run code
try:
result = CodeExecutor.execute_code(
language=code_language,
code=code,
inputs=variables
)
# Transform result
result = self._transform_result(result, node_data.outputs)
except (CodeExecutionException, ValueError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e)
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
outputs=result
)
def _check_string(self, value: str, variable: str) -> str:
"""
Check string
:param value: value
:param variable: variable
:return:
"""
if not isinstance(value, str):
raise ValueError(f"{variable} in output form must be a string")
if len(value) > MAX_STRING_LENGTH:
raise ValueError(f'{variable} in output form must be less than {MAX_STRING_LENGTH} characters')
return value.replace('\x00', '')
def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
"""
Check number
:param value: value
:param variable: variable
:return:
"""
if not isinstance(value, int | float):
raise ValueError(f"{variable} in output form must be a number")
if value > MAX_NUMBER or value < MIN_NUMBER:
raise ValueError(f'{variable} in output form is out of range.')
if isinstance(value, float):
value = round(value, MAX_PRECISION)
return value
def _transform_result(self, result: dict, output_schema: dict[str, CodeNodeData.Output],
prefix: str = '',
depth: int = 1) -> dict:
"""
Transform result
:param result: result
:param output_schema: output schema
:return:
"""
if depth > MAX_DEPTH:
raise ValueError("Depth limit reached, object too deep.")
transformed_result = {}
for output_name, output_config in output_schema.items():
if output_config.type == 'object':
# check if output is object
if not isinstance(result.get(output_name), dict):
raise ValueError(
f'Output {prefix}.{output_name} is not an object, got {type(result.get(output_name))} instead.'
)
transformed_result[output_name] = self._transform_result(
result=result[output_name],
output_schema=output_config.children,
prefix=f'{prefix}.{output_name}' if prefix else output_name,
depth=depth + 1
)
elif output_config.type == 'number':
# check if number available
transformed_result[output_name] = self._check_number(
value=result[output_name],
variable=f'{prefix}.{output_name}' if prefix else output_name
)
elif output_config.type == 'string':
# check if string available
transformed_result[output_name] = self._check_string(
value=result[output_name],
variable=f'{prefix}.{output_name}' if prefix else output_name,
)
elif output_config.type == 'array[number]':
# check if array of number available
if not isinstance(result[output_name], list):
raise ValueError(
f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.'
)
if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH:
raise ValueError(
f'{prefix}.{output_name} in output form must not have more than {MAX_NUMBER_ARRAY_LENGTH} elements'
)
transformed_result[output_name] = [
self._check_number(
value=value,
variable=f'{prefix}.{output_name}' if prefix else output_name
)
for value in result[output_name]
]
elif output_config.type == 'array[string]':
# check if array of string available
if not isinstance(result[output_name], list):
raise ValueError(
f'Output {prefix}.{output_name} is not an array, got {type(result.get(output_name))} instead.'
)
if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH:
raise ValueError(
f'{prefix}.{output_name} in output form must not have more than {MAX_STRING_ARRAY_LENGTH} elements'
)
transformed_result[output_name] = [
self._check_string(
value=value,
variable=f'{prefix}.{output_name}' if prefix else output_name
)
for value in result[output_name]
]
else:
raise ValueError(f'Output type {output_config.type} is not supported.')
return transformed_result
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[tuple[str, ...], str]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
return {
# selectors are lists, and lists are unhashable, so keys are stored as tuples
tuple(variable_selector.value_selector): variable_selector.variable for variable_selector in node_data.variables
}
\ No newline at end of file
from typing import Literal, Union
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class CodeNodeData(BaseNodeData):
"""
Code Node Data.
"""
class Output(BaseModel):
type: Literal['string', 'number', 'object', 'array[string]', 'array[number]']
children: Union[None, dict[str, 'Output']]
variables: list[VariableSelector]
answer: str
code_language: Literal['python3', 'javascript']
code: str
outputs: dict[str, Output]
import time
from typing import Optional, cast
from typing import cast
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.base_node import BaseNode
......@@ -13,20 +14,15 @@ class DirectAnswerNode(BaseNode):
_node_data_cls = DirectAnswerNodeData
node_type = NodeType.DIRECT_ANSWER
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> NodeRunResult:
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
if variable_pool is None and run_args:
raise ValueError("Not support single step debug.")
variable_values = {}
for variable_selector in node_data.variables:
value = variable_pool.get_variable_value(
......@@ -52,3 +48,12 @@ class DirectAnswerNode(BaseNode):
"answer": answer
}
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[tuple[str, ...], str]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
return {}
from typing import Optional, cast
from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.base_node import BaseNode
......@@ -11,50 +12,54 @@ class EndNode(BaseNode):
_node_data_cls = EndNodeData
node_type = NodeType.END
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> NodeRunResult:
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
outputs_config = node_data.outputs
if variable_pool is not None:
outputs = None
if outputs_config:
if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
plain_text_selector = outputs_config.plain_text_selector
if plain_text_selector:
outputs = {
'text': variable_pool.get_variable_value(
variable_selector=plain_text_selector,
target_value_type=ValueType.STRING
)
}
else:
outputs = {
'text': ''
}
elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
structured_variables = outputs_config.structured_variables
if structured_variables:
outputs = {}
for variable_selector in structured_variables:
variable_value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
outputs[variable_selector.variable] = variable_value
else:
outputs = {}
else:
raise ValueError("Not support single step debug.")
outputs = None
if outputs_config:
if outputs_config.type == EndNodeDataOutputs.OutputType.PLAIN_TEXT:
plain_text_selector = outputs_config.plain_text_selector
if plain_text_selector:
outputs = {
'text': variable_pool.get_variable_value(
variable_selector=plain_text_selector,
target_value_type=ValueType.STRING
)
}
else:
outputs = {
'text': ''
}
elif outputs_config.type == EndNodeDataOutputs.OutputType.STRUCTURED:
structured_variables = outputs_config.structured_variables
if structured_variables:
outputs = {}
for variable_selector in structured_variables:
variable_value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
outputs[variable_selector.variable] = variable_value
else:
outputs = {}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=outputs,
outputs=outputs
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[tuple[str, ...], str]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
return {}
from typing import Literal, Union
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class HttpRequestNodeData(BaseNodeData):
"""
Http Request Node Data.
"""
class Authorization(BaseModel):
class Config(BaseModel):
type: Literal[None, 'basic', 'bearer', 'custom']
api_key: Union[None, str]
header: Union[None, str]
type: Literal['no-auth', 'api-key']
config: Config
class Body(BaseModel):
type: Literal[None, 'form-data', 'x-www-form-urlencoded', 'raw', 'json']
data: Union[None, str]
variables: list[VariableSelector]
method: Literal['get', 'post', 'put', 'patch', 'delete']
url: str
authorization: Authorization
headers: str
params: str
body: Body
\ No newline at end of file
import re
from copy import deepcopy
from typing import Any, Union
from urllib.parse import urlencode
import httpx
import requests
import core.helper.ssrf_proxy as ssrf_proxy
from core.workflow.nodes.http_request.entities import HttpRequestNodeData
HTTP_REQUEST_DEFAULT_TIMEOUT = (10, 60)
class HttpExecutorResponse:
status_code: int
headers: dict[str, str]
body: str
def __init__(self, status_code: int, headers: dict[str, str], body: str):
"""
init
"""
self.status_code = status_code
self.headers = headers
self.body = body
class HttpExecutor:
server_url: str
method: str
authorization: HttpRequestNodeData.Authorization
params: dict[str, Any]
headers: dict[str, Any]
body: Union[None, str]
files: Union[None, dict[str, Any]]
def __init__(self, node_data: HttpRequestNodeData, variables: dict[str, Any]):
"""
init
"""
self.server_url = node_data.url
self.method = node_data.method
self.authorization = node_data.authorization
self.params = {}
self.headers = {}
self.body = None
# always initialize files; it is only populated for form-data bodies
self.files = None
# init template
self._init_template(node_data, variables)
def _init_template(self, node_data: HttpRequestNodeData, variables: dict[str, Any]):
"""
init template
"""
# extract all template in url
url_template = re.findall(r'{{(.*?)}}', node_data.url) or []
url_template = list(set(url_template))
original_url = node_data.url
for url in url_template:
if not url:
continue
original_url = original_url.replace(f'{{{{{url}}}}}', str(variables.get(url, '')))
self.server_url = original_url
# extract all template in params
param_template = re.findall(r'{{(.*?)}}', node_data.params) or []
param_template = list(set(param_template))
original_params = node_data.params
for param in param_template:
if not param:
continue
original_params = original_params.replace(f'{{{{{param}}}}}', str(variables.get(param, '')))
# fill in params, one `key: value` pair per line
kv_pairs = original_params.split('\n')
for kv in kv_pairs:
if not kv.strip():
continue
# split on the first colon only, so values may themselves contain colons
kv = kv.split(':', maxsplit=1)
if len(kv) != 2:
raise ValueError(f'Invalid params {kv}')
k, v = kv
self.params[k.strip()] = v.strip()
# extract all template in headers
header_template = re.findall(r'{{(.*?)}}', node_data.headers) or []
header_template = list(set(header_template))
original_headers = node_data.headers
for header in header_template:
if not header:
continue
original_headers = original_headers.replace(f'{{{{{header}}}}}', str(variables.get(header, '')))
# fill in headers, one `key: value` pair per line
kv_pairs = original_headers.split('\n')
for kv in kv_pairs:
if not kv.strip():
continue
kv = kv.split(':', maxsplit=1)
if len(kv) != 2:
raise ValueError(f'Invalid headers {kv}')
k, v = kv
self.headers[k.strip()] = v.strip()
# extract all template in body
body_template = re.findall(r'{{(.*?)}}', node_data.body.data or '') or []
body_template = list(set(body_template))
original_body = node_data.body.data or ''
for body in body_template:
if not body:
continue
original_body = original_body.replace(f'{{{{{body}}}}}', str(variables.get(body, '')))
if node_data.body.type == 'json':
self.headers['Content-Type'] = 'application/json'
elif node_data.body.type == 'x-www-form-urlencoded':
self.headers['Content-Type'] = 'application/x-www-form-urlencoded'
# elif node_data.body.type == 'form-data':
# self.headers['Content-Type'] = 'multipart/form-data'
if node_data.body.type in ['form-data', 'x-www-form-urlencoded']:
body_params = {}
kv_pairs = original_body.split('\n')
for kv in kv_pairs:
if not kv.strip():
continue
kv = kv.split(':', maxsplit=1)
if len(kv) != 2:
raise ValueError(f'Invalid body {kv}')
body_params[kv[0].strip()] = kv[1].strip()
if node_data.body.type == 'form-data':
self.files = {
k: ('', v) for k, v in body_params.items()
}
else:
self.body = urlencode(body_params)
else:
self.body = original_body
def _assembling_headers(self) -> dict[str, Any]:
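# work on copies so configured auth and headers are not mutated; 'api-key' auth writes into its configured header, defaulting to 'Authorization'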
authorization = deepcopy(self.authorization)
headers = deepcopy(self.headers) or {}
if self.authorization.type == 'api-key':
if self.authorization.config.api_key is None:
raise ValueError('api_key is required')
if not self.authorization.config.header:
authorization.config.header = 'Authorization'
if self.authorization.config.type == 'bearer':
headers[authorization.config.header] = f'Bearer {authorization.config.api_key}'
elif self.authorization.config.type == 'basic':
headers[authorization.config.header] = f'Basic {authorization.config.api_key}'
elif self.authorization.config.type == 'custom':
headers[authorization.config.header] = authorization.config.api_key
return headers
def _validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> HttpExecutorResponse:
"""
validate the response
"""
if isinstance(response, (httpx.Response, requests.Response)):
# both clients expose the same surface here; flatten headers into a plain dict
headers = dict(response.headers.items())
return HttpExecutorResponse(response.status_code, headers, response.text)
else:
raise ValueError(f'Invalid response type {type(response)}')
def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response:
"""
do http request depending on api bundle
"""
# do http request
kwargs = {
'url': self.server_url,
'headers': headers,
'params': self.params,
'timeout': HTTP_REQUEST_DEFAULT_TIMEOUT,
'follow_redirects': True
}
if self.method == 'get':
response = ssrf_proxy.get(**kwargs)
elif self.method == 'post':
response = ssrf_proxy.post(data=self.body, files=self.files, **kwargs)
elif self.method == 'put':
response = ssrf_proxy.put(data=self.body, files=self.files, **kwargs)
elif self.method == 'delete':
response = ssrf_proxy.delete(data=self.body, files=self.files, **kwargs)
elif self.method == 'patch':
response = ssrf_proxy.patch(data=self.body, files=self.files, **kwargs)
elif self.method == 'head':
response = ssrf_proxy.head(**kwargs)
elif self.method == 'options':
response = ssrf_proxy.options(**kwargs)
else:
raise ValueError(f'Invalid http method {self.method}')
return response
def invoke(self) -> HttpExecutorResponse:
"""
invoke http request
"""
# assemble headers
headers = self._assembling_headers()
# do http request
response = self._do_http_request(headers)
# validate response
return self._validate_and_parse_response(response)
def to_raw_request(self) -> str:
"""
convert to raw request
"""
server_url = self.server_url
if self.params:
server_url += f'?{urlencode(self.params)}'
raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n'
for k, v in self.headers.items():
raw_request += f'{k}: {v}\n'
raw_request += '\n'
raw_request += self.body or ''
return raw_request
\ No newline at end of file
from typing import cast
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.http_request.entities import HttpRequestNodeData
from core.workflow.nodes.http_request.http_executor import HttpExecutor
from models.workflow import WorkflowNodeExecutionStatus
class HttpRequestNode(BaseNode):
pass
_node_data_cls = HttpRequestNodeData
node_type = NodeType.HTTP_REQUEST
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
node_data: HttpRequestNodeData = cast(self._node_data_cls, self.node_data)
# extract variables
variables = {
variable_selector.variable: variable_pool.get_variable_value(variable_selector=variable_selector.value_selector)
for variable_selector in node_data.variables
}
# init http executor
http_executor = None
try:
http_executor = HttpExecutor(node_data=node_data, variables=variables)
# invoke http executor
response = http_executor.invoke()
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
# the raw request is only available if the executor was constructed successfully
process_data=http_executor.to_raw_request() if http_executor else None
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
outputs={
'status_code': response.status_code,
'body': response.body,
'headers': response.headers
},
process_data=http_executor.to_raw_request()
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[tuple[str, ...], str]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
return {
tuple(variable_selector.value_selector): variable_selector.variable for variable_selector in node_data.variables
}
\ No newline at end of file
from typing import Optional, cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
......@@ -10,12 +11,10 @@ class LLMNode(BaseNode):
_node_data_cls = LLMNodeData
node_type = NodeType.LLM
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> NodeRunResult:
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
node_data = self.node_data
......@@ -23,6 +22,17 @@ class LLMNode(BaseNode):
pass
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[tuple[str, ...], str]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
# TODO extract variable selector to variable mapping for single step debugging
return {}
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
......
from typing import Optional, cast
from typing import cast
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
......@@ -12,12 +13,10 @@ class StartNode(BaseNode):
_node_data_cls = StartNodeData
node_type = NodeType.START
def _run(self, variable_pool: Optional[VariablePool] = None,
run_args: Optional[dict] = None) -> NodeRunResult:
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:param run_args: run args
:return:
"""
node_data = self.node_data
......@@ -25,7 +24,7 @@ class StartNode(BaseNode):
variables = node_data.variables
# Get cleaned inputs
cleaned_inputs = self._get_cleaned_inputs(variables, run_args)
cleaned_inputs = self._get_cleaned_inputs(variables, variable_pool.user_inputs)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
......@@ -68,3 +67,12 @@ class StartNode(BaseNode):
filtered_inputs[variable] = value.replace('\x00', '') if value else None
return filtered_inputs
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[tuple[str, ...], str]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
return {}
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
class TemplateTransformNodeData(BaseNodeData):
"""
Template Transform Node Data.
"""
variables: list[VariableSelector]
template: str
\ No newline at end of file
from typing import Optional
from typing import Optional, cast
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from models.workflow import WorkflowNodeExecutionStatus
class TemplateTransformNode(BaseNode):
_node_data_cls = TemplateTransformNodeData
_node_type = NodeType.TEMPLATE_TRANSFORM
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
......@@ -23,3 +31,53 @@ class TemplateTransformNode(BaseNode):
"template": "{{ arg1 }}"
}
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
"""
Run node
"""
node_data = self.node_data
node_data: TemplateTransformNodeData = cast(self._node_data_cls, node_data)
# Get variables
variables = {}
for variable_selector in node_data.variables:
variable = variable_selector.variable
value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
)
variables[variable] = value
# Run code
try:
result = CodeExecutor.execute_code(
language='jinja2',
code=node_data.template,
inputs=variables
)
except CodeExecutionException as e:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e)
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
outputs={
'output': result['result']
}
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[tuple[str, ...], str]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
return {
tuple(variable_selector.value_selector): variable_selector.variable for variable_selector in node_data.variables
}
\ No newline at end of file
from typing import Literal, Optional, Union
from pydantic import BaseModel, validator
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
ToolParameterValue = Union[str, int, float, bool]
class ToolEntity(BaseModel):
provider_id: str
provider_type: Literal['builtin', 'api']
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
tool_configurations: dict[str, ToolParameterValue]
class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(VariableSelector):
variable_type: Literal['selector', 'static']
value: Optional[str]
@validator('value')
def check_value(cls, value, values, **kwargs):
if values['variable_type'] == 'static' and value is None:
raise ValueError('value is required for static variable')
return value
"""
Tool Node Schema
"""
tool_parameters: list[ToolInput]
from os import path
from typing import cast
from core.file.file_obj import FileTransferMethod
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from models.workflow import WorkflowNodeExecutionStatus
class ToolNode(BaseNode):
pass
"""
Tool Node
"""
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
"""
Run the tool node
"""
node_data = cast(ToolNodeData, self.node_data)
# get parameters
parameters = self._generate_parameters(variable_pool, node_data)
# get tool runtime
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data, None)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters,
error=f'Failed to get tool runtime: {str(e)}'
)
try:
# BaseNode now carries the invoking user's id
messages = tool_runtime.invoke(self.user_id, parameters)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters,
error=f'Failed to invoke tool: {str(e)}'
)
# convert tool messages
plain_text, files = self._convert_tool_messages(messages)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'text': plain_text,
'files': files
},
)
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
"""
Generate parameters
"""
return {
k.variable:
k.value if k.variable_type == 'static' else
variable_pool.get_variable_value(k.value_selector) if k.variable_type == 'selector' else ''
for k in node_data.tool_parameters
}
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage; workflow runs carry no conversation, so none is passed
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None
)
# extract plain text and files
files = self._extract_tool_response_binary(messages)
plain_text = self._extract_tool_response_text(messages)
return plain_text, files
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
"""
Extract tool response binary
"""
result = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
url = response.message
ext = path.splitext(url)[1]
mimetype = response.meta.get('mime_type', 'image/jpeg')
filename = response.save_as or url.split('/')[-1]
result.append({
'type': 'image',
'transfer_method': FileTransferMethod.TOOL_FILE,
'url': url,
'upload_file_id': None,
'filename': filename,
'file-ext': ext,
'mime-type': mimetype,
})
elif response.type == ToolInvokeMessage.MessageType.BLOB:
result.append({
'type': 'image', # TODO: only support image for now
'transfer_method': FileTransferMethod.TOOL_FILE,
'url': response.message,
'upload_file_id': None,
'filename': response.save_as,
'file-ext': path.splitext(response.save_as)[1],
'mime-type': response.meta.get('mime_type', 'application/octet-stream'),
})
elif response.type == ToolInvokeMessage.MessageType.LINK:
pass # TODO:
return result
def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
"""
Extract tool response text
"""
return ''.join([
f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else
f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
for message in tool_response
])
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[tuple[str, ...], str]:
"""
Extract variable selector to variable mapping
"""
# tool node inputs are resolved in _generate_parameters; nothing to map here yet
return {}
\ No newline at end of file
import json
import time
from datetime import datetime
from typing import Optional, Union
from typing import Optional
from core.model_runtime.utils.encoders import jsonable_encoder
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowRunState
from core.workflow.nodes.base_node import BaseNode
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.nodes.base_node import BaseNode, UserFrom
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode
from core.workflow.nodes.end.end_node import EndNode
......@@ -22,17 +20,9 @@ from core.workflow.nodes.template_transform.template_transform_node import Templ
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode
from extensions.ext_database import db
from models.account import Account
from models.model import App, EndUser
from models.workflow import (
CreatedByRole,
Workflow,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
WorkflowType,
)
......@@ -53,20 +43,6 @@ node_classes = {
class WorkflowEngineManager:
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id
).first()
# return workflow
return workflow
def get_default_configs(self) -> list[dict]:
"""
Get default block configs
......@@ -100,16 +76,16 @@ class WorkflowEngineManager:
return default_config
def run_workflow(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
user_id: str,
user_from: UserFrom,
user_inputs: dict,
system_inputs: Optional[dict] = None,
callbacks: list[BaseWorkflowCallback] = None) -> None:
"""
Run workflow
:param workflow: Workflow instance
:param triggered_from: triggered from
:param user: account or end user
:param user_id: user id
:param user_from: user from
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:param callbacks: workflow callbacks
......@@ -130,23 +106,20 @@ class WorkflowEngineManager:
raise ValueError('edges in workflow graph must be a list')
# init workflow run
workflow_run = self._init_workflow_run(
workflow=workflow,
triggered_from=triggered_from,
user=user,
user_inputs=user_inputs,
system_inputs=system_inputs,
callbacks=callbacks
)
if callbacks:
for callback in callbacks:
callback.on_workflow_run_started()
# init workflow run state
workflow_run_state = WorkflowRunState(
workflow_run=workflow_run,
workflow=workflow,
start_at=time.perf_counter(),
user_inputs=user_inputs,
variable_pool=VariablePool(
system_variables=system_inputs,
)
user_inputs=user_inputs
),
user_id=user_id,
user_from=user_from
)
try:
......@@ -155,6 +128,7 @@ class WorkflowEngineManager:
while True:
# get next node, multiple target nodes in the future
next_node = self._get_next_node(
workflow_run_state=workflow_run_state,
graph=graph,
predecessor_node=predecessor_node,
callbacks=callbacks
......@@ -166,7 +140,7 @@ class WorkflowEngineManager:
has_entry_node = True
# max steps 30 reached
if len(workflow_run_state.workflow_node_executions) > 30:
if len(workflow_run_state.workflow_nodes_and_results) > 30:
raise ValueError('Max steps 30 reached.')
# or max execution time 10min reached
......@@ -188,14 +162,14 @@ class WorkflowEngineManager:
if not has_entry_node:
self._workflow_run_failed(
workflow_run_state=workflow_run_state,
error='Start node not found in workflow graph.',
callbacks=callbacks
)
return
except GenerateTaskStoppedException as e:
return
except Exception as e:
self._workflow_run_failed(
workflow_run_state=workflow_run_state,
error=str(e),
callbacks=callbacks
)
......@@ -203,114 +177,36 @@ class WorkflowEngineManager:
# workflow run success
self._workflow_run_success(
workflow_run_state=workflow_run_state,
callbacks=callbacks
)
def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
user_inputs: dict,
system_inputs: Optional[dict] = None,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
"""
Init workflow run
:param workflow: Workflow instance
:param triggered_from: triggered from
:param user: account or end user
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:param callbacks: workflow callbacks
:return:
"""
max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
.filter(WorkflowRun.tenant_id == workflow.tenant_id) \
.filter(WorkflowRun.app_id == workflow.app_id) \
.scalar() or 0
new_sequence_number = max_sequence + 1
# init workflow run
workflow_run = WorkflowRun(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
sequence_number=new_sequence_number,
workflow_id=workflow.id,
type=workflow.type,
triggered_from=triggered_from.value,
version=workflow.version,
graph=workflow.graph,
inputs=json.dumps({**user_inputs, **jsonable_encoder(system_inputs)}),
status=WorkflowRunStatus.RUNNING.value,
created_by_role=(CreatedByRole.ACCOUNT.value
if isinstance(user, Account) else CreatedByRole.END_USER.value),
created_by=user.id
)
db.session.add(workflow_run)
db.session.commit()
if callbacks:
for callback in callbacks:
callback.on_workflow_run_started(workflow_run)
return workflow_run
def _workflow_run_success(self, workflow_run_state: WorkflowRunState,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None:
"""
Workflow run success
:param workflow_run_state: workflow run state
:param callbacks: workflow callbacks
:return:
"""
workflow_run = workflow_run_state.workflow_run
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
# fetch last workflow_node_executions
last_workflow_node_execution = workflow_run_state.workflow_node_executions[-1]
if last_workflow_node_execution:
workflow_run.outputs = last_workflow_node_execution.outputs
workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
workflow_run.total_tokens = workflow_run_state.total_tokens
workflow_run.total_steps = len(workflow_run_state.workflow_node_executions)
workflow_run.finished_at = datetime.utcnow()
db.session.commit()
if callbacks:
for callback in callbacks:
callback.on_workflow_run_finished(workflow_run)
return workflow_run
callback.on_workflow_run_succeeded()
def _workflow_run_failed(self, workflow_run_state: WorkflowRunState,
error: str,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowRun:
def _workflow_run_failed(self, error: str,
callbacks: list[BaseWorkflowCallback] = None) -> None:
"""
Workflow run failed
:param workflow_run_state: workflow run state
:param error: error message
:param callbacks: workflow callbacks
:return:
"""
workflow_run = workflow_run_state.workflow_run
workflow_run.status = WorkflowRunStatus.FAILED.value
workflow_run.error = error
workflow_run.elapsed_time = time.perf_counter() - workflow_run_state.start_at
workflow_run.total_tokens = workflow_run_state.total_tokens
workflow_run.total_steps = len(workflow_run_state.workflow_node_executions)
workflow_run.finished_at = datetime.utcnow()
db.session.commit()
if callbacks:
for callback in callbacks:
callback.on_workflow_run_finished(workflow_run)
return workflow_run
callback.on_workflow_run_failed(
error=error
)
def _get_next_node(self, graph: dict,
def _get_next_node(self, workflow_run_state: WorkflowRunState,
graph: dict,
predecessor_node: Optional[BaseNode] = None,
callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]:
"""
......@@ -328,7 +224,15 @@ class WorkflowEngineManager:
if not predecessor_node:
for node_config in nodes:
if node_config.get('data', {}).get('type', '') == NodeType.START.value:
return StartNode(config=node_config)
return StartNode(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id,
user_id=workflow_run_state.user_id,
user_from=workflow_run_state.user_from,
config=node_config,
callbacks=callbacks
)
else:
edges = graph.get('edges')
source_node_id = predecessor_node.node_id
......@@ -368,6 +272,11 @@ class WorkflowEngineManager:
target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type')))
return target_node(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id,
user_id=workflow_run_state.user_id,
user_from=workflow_run_state.user_from,
config=target_node_config,
callbacks=callbacks
)
......@@ -384,46 +293,62 @@ class WorkflowEngineManager:
def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
node: BaseNode,
predecessor_node: Optional[BaseNode] = None,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
# init workflow node execution
start_at = time.perf_counter()
workflow_node_execution = self._init_node_execution_from_workflow_run(
workflow_run_state=workflow_run_state,
callbacks: list[BaseWorkflowCallback] = None) -> None:
if callbacks:
for callback in callbacks:
callback.on_workflow_node_execute_started(
node_id=node.node_id,
node_type=node.node_type,
node_data=node.node_data,
node_run_index=len(workflow_run_state.workflow_nodes_and_results) + 1,
predecessor_node_id=predecessor_node.node_id if predecessor_node else None
)
db.session.close()
workflow_nodes_and_result = WorkflowNodeAndResult(
node=node,
predecessor_node=predecessor_node,
callbacks=callbacks
result=None
)
# add to workflow node executions
workflow_run_state.workflow_node_executions.append(workflow_node_execution)
# add to workflow_nodes_and_results
workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result)
# run node, result must have inputs, process_data, outputs, execution_metadata
node_run_result = node.run(
variable_pool=workflow_run_state.variable_pool,
run_args=workflow_run_state.user_inputs
if (not predecessor_node and node.node_type == NodeType.START) else None # only on start node
variable_pool=workflow_run_state.variable_pool
)
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
# node run failed
self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=start_at,
error=node_run_result.error,
callbacks=callbacks
)
if callbacks:
for callback in callbacks:
callback.on_workflow_node_execute_failed(
node_id=node.node_id,
node_type=node.node_type,
node_data=node.node_data,
error=node_run_result.error
)
raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}")
# set end node output if in chat
self._set_end_node_output_if_in_chat(workflow_run_state, node, node_run_result)
workflow_nodes_and_result.result = node_run_result
# node run success
self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution,
start_at=start_at,
result=node_run_result,
callbacks=callbacks
)
if callbacks:
for callback in callbacks:
callback.on_workflow_node_execute_succeeded(
node_id=node.node_id,
node_type=node.node_type,
node_data=node.node_data,
inputs=node_run_result.inputs,
process_data=node_run_result.process_data,
outputs=node_run_result.outputs,
execution_metadata=node_run_result.metadata
)
if node_run_result.outputs:
for variable_key, variable_value in node_run_result.outputs.items():
......@@ -438,105 +363,11 @@ class WorkflowEngineManager:
if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
return workflow_node_execution
def _init_node_execution_from_workflow_run(self, workflow_run_state: WorkflowRunState,
node: BaseNode,
predecessor_node: Optional[BaseNode] = None,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
"""
Init workflow node execution from workflow run
:param workflow_run_state: workflow run state
:param node: current node
:param predecessor_node: predecessor node if exists
:param callbacks: workflow callbacks
:return:
"""
workflow_run = workflow_run_state.workflow_run
# init workflow node execution
workflow_node_execution = WorkflowNodeExecution(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
workflow_run_id=workflow_run.id,
predecessor_node_id=predecessor_node.node_id if predecessor_node else None,
index=len(workflow_run_state.workflow_node_executions) + 1,
node_id=node.node_id,
node_type=node.node_type.value,
title=node.node_data.title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by
)
db.session.add(workflow_node_execution)
db.session.commit()
if callbacks:
for callback in callbacks:
callback.on_workflow_node_execute_started(workflow_node_execution)
return workflow_node_execution
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
result: NodeRunResult,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
"""
Workflow node execution success
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param result: node run result
:param callbacks: workflow callbacks
:return:
"""
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.inputs = json.dumps(result.inputs) if result.inputs else None
workflow_node_execution.process_data = json.dumps(result.process_data) if result.process_data else None
workflow_node_execution.outputs = json.dumps(result.outputs) if result.outputs else None
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(result.metadata)) \
if result.metadata else None
workflow_node_execution.finished_at = datetime.utcnow()
db.session.commit()
if callbacks:
for callback in callbacks:
callback.on_workflow_node_execute_finished(workflow_node_execution)
return workflow_node_execution
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
error: str,
callbacks: list[BaseWorkflowCallback] = None) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param error: error message
:param callbacks: workflow callbacks
:return:
"""
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.finished_at = datetime.utcnow()
db.session.commit()
if callbacks:
for callback in callbacks:
callback.on_workflow_node_execute_finished(workflow_node_execution)
return workflow_node_execution
db.session.close()
def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState,
node: BaseNode,
node_run_result: NodeRunResult):
node_run_result: NodeRunResult) -> None:
"""
Set end node output if in chat
:param workflow_run_state: workflow run state
......@@ -544,21 +375,19 @@ class WorkflowEngineManager:
:param node_run_result: node run result
:return:
"""
if workflow_run_state.workflow_run.type == WorkflowType.CHAT.value and node.node_type == NodeType.END:
workflow_node_execution_before_end = workflow_run_state.workflow_node_executions[-2]
if workflow_node_execution_before_end:
if workflow_node_execution_before_end.node_type == NodeType.LLM.value:
if workflow_run_state.workflow_type == WorkflowType.CHAT and node.node_type == NodeType.END:
workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2]
if workflow_nodes_and_result_before_end:
if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM:
if not node_run_result.outputs:
node_run_result.outputs = {}
node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('text')
elif workflow_node_execution_before_end.node_type == NodeType.DIRECT_ANSWER.value:
node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('text')
elif workflow_nodes_and_result_before_end.node.node_type == NodeType.DIRECT_ANSWER:
if not node_run_result.outputs:
node_run_result.outputs = {}
node_run_result.outputs['text'] = workflow_node_execution_before_end.outputs_dict.get('answer')
return node_run_result
node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('answer')
def _append_variables_recursively(self, variable_pool: VariablePool,
node_id: str,
......
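The `_set_end_node_output_if_in_chat` change above reads the previous node's result from `workflow_nodes_and_results` instead of re-reading persisted node executions: in a chat workflow the END node surfaces the `text` output of a preceding LLM node, or the `answer` of a DIRECT_ANSWER node. A minimal sketch of just that rule, with plain strings standing in for the real `NodeType`/`WorkflowType` enums:

```python
from typing import Optional

def forward_end_output(workflow_type: str, node_type: str,
                       prev_node_type: str, prev_outputs: dict,
                       outputs: Optional[dict]) -> Optional[dict]:
    # Only chat workflows rewrite the END node's output in place.
    if workflow_type != 'chat' or node_type != 'end':
        return outputs
    outputs = outputs or {}
    if prev_node_type == 'llm':
        outputs['text'] = prev_outputs.get('text')
    elif prev_node_type == 'direct-answer':
        outputs['text'] = prev_outputs.get('answer')
    return outputs
```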
......@@ -32,8 +32,6 @@ class Mail:
from libs.smtp import SMTPClient
if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'):
raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type')
if not app.config.get('SMTP_USERNAME') or not app.config.get('SMTP_PASSWORD'):
raise ValueError('SMTP_USERNAME and SMTP_PASSWORD are required for smtp mail type')
self._client = SMTPClient(
server=app.config.get('SMTP_SERVER'),
port=app.config.get('SMTP_PORT'),
......
......@@ -16,7 +16,8 @@ class SMTPClient:
smtp = smtplib.SMTP(self.server, self.port)
if self._use_tls:
smtp.starttls()
smtp.login(self.username, self.password)
if self.username:
smtp.login(self.username, self.password)
msg = MIMEMultipart()
msg['Subject'] = mail['subject']
msg['From'] = self._from
......
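The guard above makes SMTP authentication optional, so `SMTP_USERNAME`/`SMTP_PASSWORD` are no longer hard requirements (the `Mail` hunk drops the corresponding check). A hedged sketch of the resulting send path, reconstructed from the field names visible in the diff; the real `libs.smtp.SMTPClient` may differ in details such as recipient handling:

```python
import smtplib
from email.mime.multipart import MIMEMultipart

class SimpleSMTPClient:
    """Illustrative stand-in for libs.smtp.SMTPClient (field names from the diff)."""

    def __init__(self, server: str, port: int, username: str = '',
                 password: str = '', use_tls: bool = False, _from: str = ''):
        self.server = server
        self.port = port
        self.username = username
        self.password = password
        self._use_tls = use_tls
        self._from = _from

    def send(self, mail: dict) -> None:
        smtp = smtplib.SMTP(self.server, self.port, timeout=10)
        try:
            if self._use_tls:
                smtp.starttls()
            # Authenticate only when credentials were configured: this is the
            # relaxation the diff makes, so local/dev relays need no login.
            if self.username:
                smtp.login(self.username, self.password)
            msg = MIMEMultipart()
            msg['Subject'] = mail['subject']
            msg['From'] = self._from
            msg['To'] = mail['to']
            smtp.sendmail(self._from, mail['to'], msg.as_string())
        finally:
            smtp.quit()
```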
......@@ -322,7 +322,7 @@ class AppModelConfig(db.Model):
}
def from_model_config_dict(self, model_config: dict):
self.opening_statement = model_config['opening_statement']
self.opening_statement = model_config.get('opening_statement')
self.suggested_questions = json.dumps(model_config['suggested_questions']) \
if model_config.get('suggested_questions') else None
self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \
......
......@@ -433,6 +433,29 @@ class WorkflowNodeExecution(db.Model):
def execution_metadata_dict(self):
return self.execution_metadata if not self.execution_metadata else json.loads(self.execution_metadata)
class WorkflowAppLogCreatedFrom(Enum):
"""
Workflow App Log Created From Enum
"""
SERVICE_API = 'service-api'
WEB_APP = 'web-app'
INSTALLED_APP = 'installed-app'
@classmethod
def value_of(cls, value: str) -> 'WorkflowAppLogCreatedFrom':
"""
Get enum member for the given created-from value.
:param value: created-from value
:return: matching enum member
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid workflow app log created from value {value}')
class WorkflowAppLog(db.Model):
"""
Workflow App execution log, excluding workflow debugging records.
......
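A quick usage sketch of the new `value_of` helper, assuming the enum exactly as defined in the hunk above:

```python
# Resolve an enum member from its persisted string value.
source = WorkflowAppLogCreatedFrom.value_of('web-app')
assert source is WorkflowAppLogCreatedFrom.WEB_APP

# Unknown values fail fast rather than silently defaulting.
try:
    WorkflowAppLogCreatedFrom.value_of('cli')
except ValueError as exc:
    print(exc)  # invalid workflow app log created from value cli
```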
......@@ -15,7 +15,7 @@ from events.tenant_event import tenant_was_created
from extensions.ext_redis import redis_client
from libs.helper import get_remote_ip
from libs.passport import PassportService
from libs.password import compare_password, hash_password
from libs.password import compare_password, hash_password, valid_password
from libs.rsa import generate_key_pair
from models.account import *
from services.errors.account import (
......@@ -58,7 +58,7 @@ class AccountService:
account.current_tenant_id = available_ta.tenant_id
available_ta.current = True
db.session.commit()
if datetime.utcnow() - account.last_active_at > timedelta(minutes=10):
account.last_active_at = datetime.utcnow()
db.session.commit()
......@@ -104,6 +104,9 @@ class AccountService:
if account.password and not compare_password(password, account.password, account.password_salt):
raise CurrentPasswordIncorrectError("Current password is incorrect.")
# validate the new password format; raises if it does not meet the requirements
valid_password(new_password)
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
......@@ -140,9 +143,9 @@ class AccountService:
account.interface_language = interface_language
account.interface_theme = interface_theme
# Set timezone based on language
account.timezone = language_timezone_mapping.get(interface_language, 'UTC')
account.timezone = language_timezone_mapping.get(interface_language, 'UTC')
db.session.add(account)
db.session.commit()
......@@ -279,7 +282,7 @@ class TenantService:
tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first()
if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else:
else:
TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False})
tenant_account_join.current = True
# Set the current tenant for the account
......@@ -449,7 +452,7 @@ class RegisterService:
return account
@classmethod
def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str:
def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str:
"""Invite new member"""
account = Account.query.filter_by(email=email).first()
......
......@@ -30,16 +30,16 @@ class CompletionService:
invoke_from=invoke_from,
stream=streaming
)
elif app_model.mode == AppMode.CHAT.value:
return ChatAppGenerator().generate(
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
return AgentChatAppGenerator().generate(
app_model=app_model,
user=user,
args=args,
invoke_from=invoke_from,
stream=streaming
)
elif app_model.mode == AppMode.AGENT_CHAT.value:
return AgentChatAppGenerator().generate(
elif app_model.mode == AppMode.CHAT.value:
return ChatAppGenerator().generate(
app_model=app_model,
user=user,
args=args,
......
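The branch reorder above is behavioral, not cosmetic: a legacy chat app can report `mode == AppMode.CHAT.value` while carrying `is_agent == True`, so the agent check has to win. A minimal sketch of the dispatch order (the mode strings are assumptions about `AppMode`'s values):

```python
def pick_generator(mode: str, is_agent: bool) -> str:
    # Agent routing must be evaluated first: a legacy agent app may still
    # be stored with mode == 'chat' but must use the agent pipeline.
    if mode == 'agent-chat' or is_agent:
        return 'AgentChatAppGenerator'
    if mode == 'chat':
        return 'ChatAppGenerator'
    raise ValueError(f'unsupported app mode: {mode}')

assert pick_generator('chat', is_agent=True) == 'AgentChatAppGenerator'
assert pick_generator('chat', is_agent=False) == 'ChatAppGenerator'
```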
......@@ -17,7 +17,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolConfiguration
from core.tools.utils.configuration import ToolConfigurationManager
from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
......@@ -77,7 +77,7 @@ class ToolManageService:
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
tools = provider_controller.get_tools()
tool_provider_configurations = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
# check if user has added the provider
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
......@@ -279,7 +279,7 @@ class ToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
db_provider.credentials_str = json.dumps(encrypted_credentials)
......@@ -366,7 +366,7 @@ class ToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
raise ValueError(f'provider {provider_name} does not need credentials')
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
# get original credentials if exists
if provider is not None:
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
......@@ -450,7 +450,7 @@ class ToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
......@@ -490,7 +490,7 @@ class ToolManageService:
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
......@@ -632,7 +632,7 @@ class ToolManageService:
# decrypt credentials
if db_provider.id:
tool_configuration = ToolConfiguration(
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
provider_controller=provider_controller
)
......
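The `ToolConfiguration` → `ToolConfigurationManager` rename touches every call site without changing behavior. A hedged sketch of the credential round-trip using only the methods that appear in this diff; exact signatures and the `tenant_id`/`provider_controller` objects are assumed from context:

```python
config = ToolConfigurationManager(tenant_id=tenant_id,
                                  provider_controller=provider_controller)

# Store: encrypt plaintext credentials before persisting them.
encrypted = config.encrypt_tool_credentials({'api_key': 'sk-xxx'})

# Read back: decrypt for internal use, mask before returning via the API.
plaintext = config.decrypt_tool_credentials(encrypted)
masked = config.mask_tool_credentials(plaintext)

# Invalidate the cache after updates so stale decrypted values are not served.
config.delete_tool_credentials_cache()
```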
......@@ -5,6 +5,7 @@ from typing import Optional, Union
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
......@@ -44,10 +45,14 @@ class WorkflowService:
if not app_model.workflow_id:
return None
workflow_engine_manager = WorkflowEngineManager()
# fetch published workflow by workflow_id
return workflow_engine_manager.get_workflow(app_model, app_model.workflow_id)
workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == app_model.workflow_id
).first()
return workflow
def sync_draft_workflow(self, app_model: App,
graph: dict,
......@@ -201,6 +206,14 @@ class WorkflowService:
return response
def stop_workflow_task(self, task_id: str,
user: Union[Account, EndUser],
invoke_from: InvokeFrom) -> None:
"""
Stop workflow task
"""
AppQueueManager.set_stop_flag(task_id, invoke_from, user.id)
def convert_to_workflow(self, app_model: App, account: Account) -> App:
"""
Basic mode of chatbot app(expert mode) to workflow
......
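`stop_workflow_task` is a thin wrapper over the queue manager's stop flag; the running app runner polls that flag and aborts. A usage sketch (the `InvokeFrom.DEBUGGER` member and `current_user` are assumptions):

```python
# Request cancellation of a running workflow task; actual shutdown happens
# asynchronously when the runner next checks the stop flag.
WorkflowService().stop_workflow_task(
    task_id=task_id,
    user=current_user,
    invoke_from=InvokeFrom.DEBUGGER,
)
```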
......@@ -66,4 +66,8 @@ JINA_API_KEY=
OLLAMA_BASE_URL=
# Mock Switch
MOCK_SWITCH=false
\ No newline at end of file
MOCK_SWITCH=false
# CODE EXECUTION CONFIGURATION
CODE_EXECUTION_ENDPOINT=
CODE_EXECUTION_API_KEY=
\ No newline at end of file
import os
import pytest
from typing import Literal
from _pytest.monkeypatch import MonkeyPatch
from core.helper.code_executor.code_executor import CodeExecutor
MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
class MockedCodeExecutor:
@classmethod
def invoke(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
# invoke directly
if language == 'python3':
return {
"result": 3
}
@pytest.fixture
def setup_code_executor_mock(request, monkeypatch: MonkeyPatch):
if not MOCK:
yield
return
monkeypatch.setattr(CodeExecutor, "execute_code", MockedCodeExecutor.invoke)
yield
monkeypatch.undo()
import pytest
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.code.code_node import CodeNode
from models.workflow import WorkflowNodeExecutionStatus, WorkflowRunStatus
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True)
def test_execute_code(setup_code_executor_mock):
code = '''
def main(args1: int, args2: int) -> dict:
return {
"result": args1 + args2,
}
'''
# trim first 4 spaces at the beginning of each line
code = '\n'.join([line[4:] for line in code.split('\n')])
node = CodeNode(config={
'id': '1',
'data': {
'outputs': {
'result': {
'type': 'number',
},
},
'title': '123',
'variables': [
{
'variable': 'args1',
'value_selector': ['1', '123', 'args1'],
},
{
'variable': 'args2',
'value_selector': ['1', '123', 'args2']
}
],
'answer': '123',
'code_language': 'python3',
'code': code
}
})
# construct variable pool
pool = VariablePool(system_variables={}, user_inputs={})
pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1)
pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2)
# execute node
result = node.run(pool)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs['result'] == 3
assert result.error is None
@pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True)
def test_execute_code_output_validator(setup_code_executor_mock):
code = '''
def main(args1: int, args2: int) -> dict:
return {
"result": args1 + args2,
}
'''
# trim first 4 spaces at the beginning of each line
code = '\n'.join([line[4:] for line in code.split('\n')])
node = CodeNode(config={
'id': '1',
'data': {
"outputs": {
"result": {
"type": "string",
},
},
'title': '123',
'variables': [
{
'variable': 'args1',
'value_selector': ['1', '123', 'args1'],
},
{
'variable': 'args2',
'value_selector': ['1', '123', 'args2']
}
],
'answer': '123',
'code_language': 'python3',
'code': code
}
})
# construct variable pool
pool = VariablePool(system_variables={}, user_inputs={})
pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1)
pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2)
# execute node
result = node.run(pool)
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error == 'result in output form must be a string'
def test_execute_code_output_validator_depth():
code = '''
def main(args1: int, args2: int) -> dict:
return {
"result": {
"result": args1 + args2,
}
}
'''
# trim first 4 spaces at the beginning of each line
code = '\n'.join([line[4:] for line in code.split('\n')])
node = CodeNode(config={
'id': '1',
'data': {
"outputs": {
"string_validator": {
"type": "string",
},
"number_validator": {
"type": "number",
},
"number_array_validator": {
"type": "array[number]",
},
"string_array_validator": {
"type": "array[string]",
},
"object_validator": {
"type": "object",
"children": {
"result": {
"type": "number",
},
"depth": {
"type": "object",
"children": {
"depth": {
"type": "object",
"children": {
"depth": {
"type": "number",
}
}
}
}
}
}
},
},
'title': '123',
'variables': [
{
'variable': 'args1',
'value_selector': ['1', '123', 'args1'],
},
{
'variable': 'args2',
'value_selector': ['1', '123', 'args2']
}
],
'answer': '123',
'code_language': 'python3',
'code': code
}
})
# construct result
result = {
"number_validator": 1,
"string_validator": "1",
"number_array_validator": [1, 2, 3, 3.333],
"string_array_validator": ["1", "2", "3"],
"object_validator": {
"result": 1,
"depth": {
"depth": {
"depth": 1
}
}
}
}
# validate
node._transform_result(result, node.node_data.outputs)
# construct result
result = {
"number_validator": "1",
"string_validator": 1,
"number_array_validator": ["1", "2", "3", "3.333"],
"string_array_validator": [1, 2, 3],
"object_validator": {
"result": "1",
"depth": {
"depth": {
"depth": "1"
}
}
}
}
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
# construct result
result = {
"number_validator": 1,
"string_validator": "1" * 2000,
"number_array_validator": [1, 2, 3, 3.333],
"string_array_validator": ["1", "2", "3"],
"object_validator": {
"result": 1,
"depth": {
"depth": {
"depth": 1
}
}
}
}
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
# construct result
result = {
"number_validator": 1,
"string_validator": "1",
"number_array_validator": [1, 2, 3, 3.333] * 2000,
"string_array_validator": ["1", "2", "3"],
"object_validator": {
"result": 1,
"depth": {
"depth": {
"depth": 1
}
}
}
}
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
\ No newline at end of file
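The tests above strip a fixed four-space prefix from each line of the embedded snippet by hand. `textwrap.dedent` is the idiomatic alternative when the indentation is uniform; a suggestion only, not what the tests use:

```python
import textwrap

code = textwrap.dedent('''
    def main(args1: int, args2: int) -> dict:
        return {
            "result": args1 + args2,
        }
''').strip()

assert code.startswith('def main')
```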
......@@ -62,8 +62,10 @@ const ActivateForm = () => {
showErrorMessage(t('login.error.passwordEmpty'))
return false
}
if (!validPassword.test(password))
if (!validPassword.test(password)) {
showErrorMessage(t('login.error.passwordInvalid'))
return false
}
return true
}, [name, password, showErrorMessage, t])
......
......@@ -24,7 +24,7 @@ const WarningMask: FC<IWarningMaskProps> = ({
return (
<div className={`${s.mask} absolute z-10 inset-0 pt-16`}
>
<div className='mx-auto w-[535px]'>
<div className='mx-auto px-10'>
<div className={`${s.icon} flex items-center justify-center w-11 h-11 rounded-xl bg-white`}>{warningIcon}</div>
<div className='mt-4 text-[24px] leading-normal font-semibold text-gray-800'>
{title}
......
......@@ -26,6 +26,7 @@ import { ArrowNarrowRight } from '@/app/components/base/icons/src/vender/line/ar
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { ADD_EXTERNAL_DATA_TOOL } from '@/app/components/app/configuration/config-var'
import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block'
import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block'
export type ISimplePromptInput = {
mode: AppType
......@@ -125,6 +126,10 @@ const Prompt: FC<ISimplePromptInput> = ({
if (mode === AppType.chat)
setIntroduction(res.opening_statement)
showAutomaticFalse()
eventEmitter?.emit({
type: PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER,
payload: res.prompt,
} as any)
}
const minHeight = 228
const [editorHeight, setEditorHeight] = useState(minHeight)
......
......@@ -12,6 +12,7 @@ import { SimpleSelect } from '@/app/components/base/select'
import type { AppDetailResponse } from '@/models/app'
import type { Language } from '@/types/app'
import EmojiPicker from '@/app/components/base/emoji-picker'
import { useToastContext } from '@/app/components/base/toast'
import { languages } from '@/i18n/language'
......@@ -42,6 +43,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
onClose,
onSave,
}) => {
const { notify } = useToastContext()
const [isShowMore, setIsShowMore] = useState(false)
const { icon, icon_background } = appInfo
const { title, description, copyright, privacy_policy, default_language } = appInfo.site
......@@ -67,6 +69,10 @@ const SettingsModal: FC<ISettingsModalProps> = ({
}
const onClickSave = async () => {
if (!inputInfo.title) {
notify({ type: 'error', message: t('app.newApp.nameNotEmpty') })
return
}
setSaveLoading(true)
const params = {
title: inputInfo.title,
......
import type { FC } from 'react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import cn from 'classnames'
import Uploader from './uploader'
import ImageLinkInput from './image-link-input'
import { ImagePlus } from '@/app/components/base/icons/src/vender/line/images'
......@@ -25,16 +26,16 @@ const UploadOnlyFromLocal: FC<UploadOnlyFromLocalProps> = ({
}) => {
return (
<Uploader onUpload={onUpload} disabled={disabled} limit={limit}>
{
hovering => (
<div className={`
{hovering => (
<div
className={`
relative flex items-center justify-center w-8 h-8 rounded-lg cursor-pointer
${hovering && 'bg-gray-100'}
`}>
<ImagePlus className='w-4 h-4 text-gray-500' />
</div>
)
}
`}
>
<ImagePlus className="w-4 h-4 text-gray-500" />
</div>
)}
</Uploader>
)
}
......@@ -54,13 +55,16 @@ const UploaderButton: FC<UploaderButtonProps> = ({
const { t } = useTranslation()
const [open, setOpen] = useState(false)
const hasUploadFromLocal = methods.find(method => method === TransferMethod.local_file)
const hasUploadFromLocal = methods.find(
method => method === TransferMethod.local_file,
)
const handleUpload = (imageFile: ImageFile) => {
setOpen(false)
onUpload(imageFile)
}
const closePopover = () => setOpen(false)
const handleToggle = () => {
if (disabled)
return
......@@ -72,43 +76,46 @@ const UploaderButton: FC<UploaderButtonProps> = ({
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement='top-start'
placement="top-start"
>
<PortalToFollowElemTrigger onClick={handleToggle}>
<div className={`
relative flex items-center justify-center w-8 h-8 hover:bg-gray-100 rounded-lg
${disabled ? 'cursor-not-allowed' : 'cursor-pointer'}
`}>
<ImagePlus className='w-4 h-4 text-gray-500' />
</div>
<button
type="button"
disabled={disabled}
className="relative flex items-center justify-center w-8 h-8 enabled:hover:bg-gray-100 rounded-lg disabled:cursor-not-allowed"
>
<ImagePlus className="w-4 h-4 text-gray-500" />
</button>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-50'>
<div className='p-2 w-[260px] bg-white rounded-lg border-[0.5px] border-gray-200 shadow-lg'>
<PortalToFollowElemContent className="z-50">
<div className="p-2 w-[260px] bg-white rounded-lg border-[0.5px] border-gray-200 shadow-lg">
<ImageLinkInput onUpload={handleUpload} />
{
hasUploadFromLocal && (
<>
<div className='flex items-center mt-2 px-2 text-xs font-medium text-gray-400'>
<div className='mr-3 w-[93px] h-[1px] bg-gradient-to-l from-[#F3F4F6]' />
OR
<div className='ml-3 w-[93px] h-[1px] bg-gradient-to-r from-[#F3F4F6]' />
</div>
<Uploader onUpload={handleUpload} limit={limit}>
{
hovering => (
<div className={`
flex items-center justify-center h-8 text-[13px] font-medium text-[#155EEF] rounded-lg cursor-pointer
${hovering && 'bg-primary-50'}
`}>
<Upload03 className='mr-1 w-4 h-4' />
{t('common.imageUploader.uploadFromComputer')}
</div>
)
}
</Uploader>
</>
)
}
{hasUploadFromLocal && (
<>
<div className="flex items-center mt-2 px-2 text-xs font-medium text-gray-400">
<div className="mr-3 w-[93px] h-[1px] bg-gradient-to-l from-[#F3F4F6]" />
OR
<div className="ml-3 w-[93px] h-[1px] bg-gradient-to-r from-[#F3F4F6]" />
</div>
<Uploader
onUpload={handleUpload}
limit={limit}
closePopover={closePopover}
>
{hovering => (
<div
className={cn(
'flex items-center justify-center h-8 text-[13px] font-medium text-[#155EEF] rounded-lg cursor-pointer',
hovering && 'bg-primary-50',
)}
>
<Upload03 className="mr-1 w-4 h-4" />
{t('common.imageUploader.uploadFromComputer')}
</div>
)}
</Uploader>
</>
)}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
......@@ -125,7 +132,9 @@ const ChatImageUploader: FC<ChatImageUploaderProps> = ({
onUpload,
disabled,
}) => {
const onlyUploadLocal = settings.transfer_methods.length === 1 && settings.transfer_methods[0] === TransferMethod.local_file
const onlyUploadLocal
= settings.transfer_methods.length === 1
&& settings.transfer_methods[0] === TransferMethod.local_file
if (onlyUploadLocal) {
return (
......
......@@ -30,6 +30,7 @@ const ImageLinkInput: FC<ImageLinkInputProps> = ({
return (
<div className='flex items-center pl-1.5 pr-1 h-8 border border-gray-200 bg-white shadow-xs rounded-lg'>
<input
type="text"
className='grow mr-0.5 px-1 h-[18px] text-[13px] outline-none appearance-none'
value={imageLink}
onChange={e => setImageLink(e.target.value)}
......
import type { FC } from 'react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import { Loading02, XClose } from '@/app/components/base/icons/src/vender/line/general'
import cn from 'classnames'
import {
Loading02,
XClose,
} from '@/app/components/base/icons/src/vender/line/general'
import { RefreshCcw01 } from '@/app/components/base/icons/src/vender/line/arrows'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import TooltipPlus from '@/app/components/base/tooltip-plus'
......@@ -30,7 +34,11 @@ const ImageList: FC<ImageListProps> = ({
const [imagePreviewUrl, setImagePreviewUrl] = useState('')
const handleImageLinkLoadSuccess = (item: ImageFile) => {
if (item.type === TransferMethod.remote_url && onImageLinkLoadSuccess && item.progress !== -1)
if (
item.type === TransferMethod.remote_url
&& onImageLinkLoadSuccess
&& item.progress !== -1
)
onImageLinkLoadSuccess(item._id)
}
const handleImageLinkLoadError = (item: ImageFile) => {
......@@ -39,89 +47,95 @@ const ImageList: FC<ImageListProps> = ({
}
return (
<div className='flex flex-wrap'>
{
list.map(item => (
<div
key={item._id}
className='group relative mr-1 border-[0.5px] border-black/5 rounded-lg'
>
{
item.type === TransferMethod.local_file && item.progress !== 100 && (
<>
<div
className='absolute inset-0 flex items-center justify-center z-[1] bg-black/30'
style={{ left: item.progress > -1 ? `${item.progress}%` : 0 }}
>
{
item.progress === -1 && (
<RefreshCcw01 className='w-5 h-5 text-white' onClick={() => onReUpload && onReUpload(item._id)} />
)
}
</div>
{
item.progress > -1 && (
<span className='absolute top-[50%] left-[50%] translate-x-[-50%] translate-y-[-50%] text-sm text-white mix-blend-lighten z-[1]'>{item.progress}%</span>
)
}
</>
)
}
{
item.type === TransferMethod.remote_url && item.progress !== 100 && (
<div className={`
<div className="flex flex-wrap">
{list.map(item => (
<div
key={item._id}
className="group relative mr-1 border-[0.5px] border-black/5 rounded-lg"
>
{item.type === TransferMethod.local_file && item.progress !== 100 && (
<>
<div
className="absolute inset-0 flex items-center justify-center z-[1] bg-black/30"
style={{ left: item.progress > -1 ? `${item.progress}%` : 0 }}
>
{item.progress === -1 && (
<RefreshCcw01
className="w-5 h-5 text-white"
onClick={() => onReUpload && onReUpload(item._id)}
/>
)}
</div>
{item.progress > -1 && (
<span className="absolute top-[50%] left-[50%] translate-x-[-50%] translate-y-[-50%] text-sm text-white mix-blend-lighten z-[1]">
{item.progress}%
</span>
)}
</>
)}
{item.type === TransferMethod.remote_url && item.progress !== 100 && (
<div
className={`
absolute inset-0 flex items-center justify-center rounded-lg z-[1] border
${item.progress === -1 ? 'bg-[#FEF0C7] border-[#DC6803]' : 'bg-black/[0.16] border-transparent'}
`}>
{
item.progress > -1 && (
<Loading02 className='animate-spin w-5 h-5 text-white' />
)
}
{
item.progress === -1 && (
<TooltipPlus popupContent={t('common.imageUploader.pasteImageLinkInvalid')}>
<AlertTriangle className='w-4 h-4 text-[#DC6803]' />
</TooltipPlus>
)
}
</div>
)
${
item.progress === -1
? 'bg-[#FEF0C7] border-[#DC6803]'
: 'bg-black/[0.16] border-transparent'
}
<img
className='w-16 h-16 rounded-lg object-cover cursor-pointer border-[0.5px] border-black/5'
alt=''
onLoad={() => handleImageLinkLoadSuccess(item)}
onError={() => handleImageLinkLoadError(item)}
src={item.type === TransferMethod.remote_url ? item.url : item.base64Url}
onClick={() => item.progress === 100 && setImagePreviewUrl((item.type === TransferMethod.remote_url ? item.url : item.base64Url) as string)}
/>
{
!readonly && (
<div
className={`
absolute z-10 -top-[9px] -right-[9px] items-center justify-center w-[18px] h-[18px]
bg-white hover:bg-gray-50 border-[0.5px] border-black/[0.02] rounded-2xl shadow-lg
cursor-pointer
${item.progress === -1 ? 'flex' : 'hidden group-hover:flex'}
`}
onClick={() => onRemove && onRemove(item._id)}
`}
>
{item.progress > -1 && (
<Loading02 className="animate-spin w-5 h-5 text-white" />
)}
{item.progress === -1 && (
<TooltipPlus
popupContent={t('common.imageUploader.pasteImageLinkInvalid')}
>
<XClose className='w-3 h-3 text-gray-500' />
</div>
<AlertTriangle className="w-4 h-4 text-[#DC6803]" />
</TooltipPlus>
)}
</div>
)}
<img
className="w-16 h-16 rounded-lg object-cover cursor-pointer border-[0.5px] border-black/5"
alt={item.file?.name}
onLoad={() => handleImageLinkLoadSuccess(item)}
onError={() => handleImageLinkLoadError(item)}
src={
item.type === TransferMethod.remote_url
? item.url
: item.base64Url
}
onClick={() =>
item.progress === 100
&& setImagePreviewUrl(
(item.type === TransferMethod.remote_url
? item.url
: item.base64Url) as string,
)
}
</div>
))
}
{
imagePreviewUrl && (
<ImagePreview
url={imagePreviewUrl}
onCancel={() => setImagePreviewUrl('')}
/>
)
}
{!readonly && (
<button
type="button"
className={cn(
'absolute z-10 -top-[9px] -right-[9px] items-center justify-center w-[18px] h-[18px]',
'bg-white hover:bg-gray-50 border-[0.5px] border-black/[0.02] rounded-2xl shadow-lg',
item.progress === -1 ? 'flex' : 'hidden group-hover:flex',
)}
onClick={() => onRemove && onRemove(item._id)}
>
<XClose className="w-3 h-3 text-gray-500" />
</button>
)}
</div>
))}
{imagePreviewUrl && (
<ImagePreview
url={imagePreviewUrl}
onCancel={() => setImagePreviewUrl('')}
/>
)}
</div>
)
}
......
......@@ -7,6 +7,7 @@ import { ALLOW_FILE_EXTENSIONS } from '@/types/app'
type UploaderProps = {
children: (hovering: boolean) => JSX.Element
onUpload: (imageFile: ImageFile) => void
closePopover?: () => void
limit?: number
disabled?: boolean
}
......@@ -14,11 +15,16 @@ type UploaderProps = {
const Uploader: FC<UploaderProps> = ({
children,
onUpload,
closePopover,
limit,
disabled,
}) => {
const [hovering, setHovering] = useState(false)
const { handleLocalFileUpload } = useLocalFileUploader({ limit, onUpload, disabled })
const { handleLocalFileUpload } = useLocalFileUploader({
limit,
onUpload,
disabled,
})
const handleChange = (e: ChangeEvent<HTMLInputElement>) => {
const file = e.target.files?.[0]
......@@ -27,6 +33,7 @@ const Uploader: FC<UploaderProps> = ({
return
handleLocalFileUpload(file)
closePopover?.()
}
return (
......@@ -37,11 +44,8 @@ const Uploader: FC<UploaderProps> = ({
>
{children(hovering)}
<input
className={`
absolute block inset-0 opacity-0 text-[0] w-full
${disabled ? 'cursor-not-allowed' : 'cursor-pointer'}
`}
onClick={e => (e.target as HTMLInputElement).value = ''}
className='absolute block inset-0 opacity-0 text-[0] w-full disabled:cursor-not-allowed cursor-pointer'
onClick={e => ((e.target as HTMLInputElement).value = '')}
type='file'
accept={ALLOW_FILE_EXTENSIONS.map(ext => `.${ext}`).join(',')}
onChange={handleChange}
......
......@@ -32,6 +32,7 @@ import VariableValueBlock from './plugins/variable-value-block'
import { VariableValueBlockNode } from './plugins/variable-value-block/node'
import { CustomTextNode } from './plugins/custom-text/node'
import OnBlurBlock from './plugins/on-blur-block'
import UpdateBlock from './plugins/update-block'
import { textToEditorState } from './utils'
import type { Dataset } from './plugins/context-block'
import type { RoleName } from './plugins/history-block'
......@@ -222,6 +223,7 @@ const PromptEditor: FC<PromptEditorProps> = ({
<VariableValueBlock />
<OnChangePlugin onChange={handleEditorChange} />
<OnBlurBlock onBlur={onBlur} />
<UpdateBlock />
{/* <TreeView /> */}
</div>
</LexicalComposer>
......
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { textToEditorState } from '../utils'
import { useEventEmitterContextContext } from '@/context/event-emitter'
export const PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER = 'PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER'
const UpdateBlock = () => {
const { eventEmitter } = useEventEmitterContextContext()
const [editor] = useLexicalComposerContext()
eventEmitter?.useSubscription((v: any) => {
if (v.type === PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER) {
const editorState = editor.parseEditorState(textToEditorState(v.payload))
editor.setEditorState(editorState)
}
})
return null
}
export default UpdateBlock
......@@ -838,7 +838,7 @@ const StepTwo = ({
{!isSetting
? (
<div className='flex items-center mt-8 py-2'>
<Button onClick={() => onStepChange && onStepChange(-1)}>{t('datasetCreation.stepTwo.lastStep')}</Button>
<Button onClick={() => onStepChange && onStepChange(-1)}>{t('datasetCreation.stepTwo.previousStep')}</Button>
<div className={s.divider} />
<Button loading={isCreating} type='primary' onClick={createHandle}>{t('datasetCreation.stepTwo.nextStep')}</Button>
</div>
......
......@@ -427,6 +427,53 @@ Chat applications support session persistence, allowing previous chat history to
---
<Heading
url='/messages/{message_id}/suggested'
method='GET'
title='next suggested questions'
name='#suggested'
/>
<Row>
<Col>
Get suggested questions for the next round of the current message.
### Path Params
<Properties>
<Property name='message_id' type='string' key='message_id'>
Message ID
</Property>
</Properties>
</Col>
<Col sticky>
<CodeGroup title="Request" tag="GET" label="/messages/{message_id}/suggested" targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested \\\n--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \\\n--header 'Content-Type: application/json'`}>
```bash {{ title: 'cURL' }}
curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested' \
--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \
--header 'Content-Type: application/json'
```
</CodeGroup>
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"result": "success",
"data": [
"a",
"b",
"c"
]
}
```
</CodeGroup>
</Col>
</Row>
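For client code, the same request from Python; the endpoint and headers are exactly as documented above, while `requests`, the base URL, and the placeholder IDs are illustrative assumptions:

```python
import requests

API_BASE_URL = 'https://api.dify.ai/v1'  # placeholder: use your api_base_url
MESSAGE_ID = 'your-message-id'           # placeholder

resp = requests.get(
    f'{API_BASE_URL}/messages/{MESSAGE_ID}/suggested',
    headers={'Authorization': 'Bearer ENTER-YOUR-SECRET-KEY'},
    timeout=10,
)
resp.raise_for_status()
print(resp.json()['data'])  # e.g. ["a", "b", "c"]
```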
---
<Heading
url='/messages'
method='GET'
......
......@@ -442,6 +442,55 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
---
<Heading
url='/messages/{message_id}/suggested'
method='GET'
title='获取下一轮建议问题列表'
name='#suggested'
/>
<Row>
<Col>
获取下一轮建议问题列表。
### Path Params
<Properties>
<Property name='message_id' type='string' key='message_id'>
Message ID
</Property>
</Properties>
</Col>
<Col sticky>
<CodeGroup title="Request" tag="GET" label="/messages/{message_id}/suggested" targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested \\\n--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \\\n--header 'Content-Type: application/json'`}>
```bash {{ title: 'cURL' }}
curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested' \
--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \
--header 'Content-Type: application/json'
```
</CodeGroup>
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"result": "success",
"data": [
"a",
"b",
"c"
]
}
```
</CodeGroup>
</Col>
</Row>
---
<Heading
url='/messages'
method='GET'
......
......@@ -71,10 +71,14 @@ export default function AccountPage() {
showErrorMessage(t('login.error.passwordEmpty'))
return false
}
if (!validPassword.test(password))
if (!validPassword.test(password)) {
showErrorMessage(t('login.error.passwordInvalid'))
if (password !== confirmPassword)
return false
}
if (password !== confirmPassword) {
showErrorMessage(t('common.account.notEqual'))
return false
}
return true
}
......
......@@ -89,7 +89,7 @@ const translation = {
other: 'and other ',
fileUnit: ' files',
notionUnit: ' pages',
lastStep: 'Last step',
previousStep: 'Previous step',
nextStep: 'Save & Process',
save: 'Save & Process',
cancel: 'Cancel',
......
......@@ -89,7 +89,7 @@ const translation = {
other: 'その他',
fileUnit: 'ファイル',
notionUnit: 'ページ',
lastStep: '最後のステップ',
previousStep: '前のステップ',
nextStep: '保存して処理',
save: '保存して処理',
cancel: 'キャンセル',
......
......@@ -89,7 +89,7 @@ const translation = {
other: 'e outros ',
fileUnit: ' arquivos',
notionUnit: ' páginas',
lastStep: 'Última etapa',
previousStep: 'Passo anterior',
nextStep: 'Salvar e Processar',
save: 'Salvar e Processar',
cancel: 'Cancelar',
......
......@@ -89,7 +89,7 @@ const translation = {
other: ' та інші ',
fileUnit: ' файли',
notionUnit: ' сторінки',
lastStep: 'Попередній крок',
previousStep: 'Попередній крок',
nextStep: 'Зберегти та обробити',
save: 'Зберегти та обробити',
cancel: 'Скасувати',
......
......@@ -89,7 +89,7 @@ const translation = {
other: '和其他 ',
fileUnit: ' 个文件',
notionUnit: ' 个页面',
lastStep: '上一步',
previousStep: '上一步',
nextStep: '保存并处理',
save: '保存并处理',
cancel: '取消',
......