Commit 11e1b569 authored by takatost's avatar takatost

move workflow_id to app

parent 2bbf96d7
...@@ -7,8 +7,7 @@ default_app_templates = { ...@@ -7,8 +7,7 @@ default_app_templates = {
'mode': AppMode.WORKFLOW.value, 'mode': AppMode.WORKFLOW.value,
'enable_site': True, 'enable_site': True,
'enable_api': True 'enable_api': True
}, }
'model_config': {}
}, },
# chat default mode # chat default mode
...@@ -34,14 +33,6 @@ default_app_templates = { ...@@ -34,14 +33,6 @@ default_app_templates = {
'mode': AppMode.ADVANCED_CHAT.value, 'mode': AppMode.ADVANCED_CHAT.value,
'enable_site': True, 'enable_site': True,
'enable_api': True 'enable_api': True
},
'model_config': {
'model': {
"provider": "openai",
"name": "gpt-4",
"mode": "chat",
"completion_params": {}
}
} }
}, },
......
...@@ -41,10 +41,16 @@ class DraftWorkflowApi(Resource): ...@@ -41,10 +41,16 @@ class DraftWorkflowApi(Resource):
""" """
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow_service.sync_draft_workflow(app_model=app_model, graph=args.get('graph'), account=current_user) workflow_service.sync_draft_workflow(
app_model=app_model,
graph=args.get('graph'),
features=args.get('features'),
account=current_user
)
return { return {
"result": "success" "result": "success"
......
import logging import logging
from typing import Optional
from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager, PublishFrom
from core.app.base_app_runner import AppRunner from core.app.base_app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ( from core.entities.application_entities import (
ApplicationGenerateEntity, ApplicationGenerateEntity,
DatasetEntity,
InvokeFrom,
ModelConfigEntity,
) )
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.moderation.base import ModerationException from core.moderation.base import ModerationException
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, AppMode, Conversation, Message from models.model import App, Conversation, Message
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -145,18 +141,23 @@ class ChatAppRunner(AppRunner): ...@@ -145,18 +141,23 @@ class ChatAppRunner(AppRunner):
# get context from datasets # get context from datasets
context = None context = None
if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids:
context = self.retrieve_dataset_context( hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
app_record.id,
message.id,
application_generate_entity.user_id,
application_generate_entity.invoke_from
)
dataset_retrieval = DatasetRetrieval()
context = dataset_retrieval.retrieve(
tenant_id=app_record.tenant_id, tenant_id=app_record.tenant_id,
app_record=app_record,
queue_manager=queue_manager,
model_config=app_orchestration_config.model_config, model_config=app_orchestration_config.model_config,
show_retrieve_source=app_orchestration_config.show_retrieve_source, config=app_orchestration_config.dataset,
dataset_config=app_orchestration_config.dataset,
message=message,
inputs=inputs,
query=query, query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_orchestration_config.show_retrieve_source,
hit_callback=hit_callback,
memory=memory memory=memory
) )
...@@ -212,57 +213,3 @@ class ChatAppRunner(AppRunner): ...@@ -212,57 +213,3 @@ class ChatAppRunner(AppRunner):
queue_manager=queue_manager, queue_manager=queue_manager,
stream=application_generate_entity.stream stream=application_generate_entity.stream
) )
def retrieve_dataset_context(self, tenant_id: str,
app_record: App,
queue_manager: AppQueueManager,
model_config: ModelConfigEntity,
dataset_config: DatasetEntity,
show_retrieve_source: bool,
message: Message,
inputs: dict,
query: str,
user_id: str,
invoke_from: InvokeFrom,
memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
"""
Retrieve dataset context
:param tenant_id: tenant id
:param app_record: app record
:param queue_manager: queue manager
:param model_config: model config
:param dataset_config: dataset config
:param show_retrieve_source: show retrieve source
:param message: message
:param inputs: inputs
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:param memory: memory
:return:
"""
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
app_record.id,
message.id,
user_id,
invoke_from
)
# TODO
if (app_record.mode == AppMode.COMPLETION.value and dataset_config
and dataset_config.retrieve_config.query_variable):
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
dataset_retrieval = DatasetRetrieval()
return dataset_retrieval.retrieve(
tenant_id=tenant_id,
model_config=model_config,
config=dataset_config,
query=query,
invoke_from=invoke_from,
show_retrieve_source=show_retrieve_source,
hit_callback=hit_callback,
memory=memory
)
\ No newline at end of file
import logging import logging
from typing import Optional
from core.app.app_queue_manager import AppQueueManager, PublishFrom from core.app.app_queue_manager import AppQueueManager
from core.app.base_app_runner import AppRunner from core.app.base_app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ( from core.entities.application_entities import (
ApplicationGenerateEntity, ApplicationGenerateEntity,
DatasetEntity,
InvokeFrom,
ModelConfigEntity,
) )
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.moderation.base import ModerationException from core.moderation.base import ModerationException
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, AppMode, Conversation, Message from models.model import App, Message
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -27,13 +22,11 @@ class CompletionAppRunner(AppRunner): ...@@ -27,13 +22,11 @@ class CompletionAppRunner(AppRunner):
def run(self, application_generate_entity: ApplicationGenerateEntity, def run(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation,
message: Message) -> None: message: Message) -> None:
""" """
Run application Run application
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
:param queue_manager: application queue manager :param queue_manager: application queue manager
:param conversation: conversation
:param message: message :param message: message
:return: :return:
""" """
...@@ -61,30 +54,15 @@ class CompletionAppRunner(AppRunner): ...@@ -61,30 +54,15 @@ class CompletionAppRunner(AppRunner):
query=query query=query
) )
memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
model=app_orchestration_config.model_config.model
)
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
# organize all inputs and template to prompt messages # organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional) # Include: prompt template, inputs, query(optional), files(optional)
# memory(optional)
prompt_messages, stop = self.organize_prompt_messages( prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record, app_record=app_record,
model_config=app_orchestration_config.model_config, model_config=app_orchestration_config.model_config,
prompt_template_entity=app_orchestration_config.prompt_template, prompt_template_entity=app_orchestration_config.prompt_template,
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query
memory=memory
) )
# moderation # moderation
...@@ -107,30 +85,6 @@ class CompletionAppRunner(AppRunner): ...@@ -107,30 +85,6 @@ class CompletionAppRunner(AppRunner):
) )
return return
if query:
# annotation reply
annotation_reply = self.query_app_annotations_to_reply(
app_record=app_record,
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from
)
if annotation_reply:
queue_manager.publish_annotation_reply(
message_annotation_id=annotation_reply.id,
pub_from=PublishFrom.APPLICATION_MANAGER
)
self.direct_output(
queue_manager=queue_manager,
app_orchestration_config=app_orchestration_config,
prompt_messages=prompt_messages,
text=annotation_reply.content,
stream=application_generate_entity.stream
)
return
# fill in variable inputs from external data tools if exists # fill in variable inputs from external data tools if exists
external_data_tools = app_orchestration_config.external_data_variables external_data_tools = app_orchestration_config.external_data_variables
if external_data_tools: if external_data_tools:
...@@ -145,19 +99,27 @@ class CompletionAppRunner(AppRunner): ...@@ -145,19 +99,27 @@ class CompletionAppRunner(AppRunner):
# get context from datasets # get context from datasets
context = None context = None
if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids: if app_orchestration_config.dataset and app_orchestration_config.dataset.dataset_ids:
context = self.retrieve_dataset_context( hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
app_record.id,
message.id,
application_generate_entity.user_id,
application_generate_entity.invoke_from
)
dataset_config = app_orchestration_config.dataset
if dataset_config and dataset_config.retrieve_config.query_variable:
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
dataset_retrieval = DatasetRetrieval()
context = dataset_retrieval.retrieve(
tenant_id=app_record.tenant_id, tenant_id=app_record.tenant_id,
app_record=app_record,
queue_manager=queue_manager,
model_config=app_orchestration_config.model_config, model_config=app_orchestration_config.model_config,
show_retrieve_source=app_orchestration_config.show_retrieve_source, config=dataset_config,
dataset_config=app_orchestration_config.dataset,
message=message,
inputs=inputs,
query=query, query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
memory=memory show_retrieve_source=app_orchestration_config.show_retrieve_source,
hit_callback=hit_callback
) )
# reorganize all inputs and template to prompt messages # reorganize all inputs and template to prompt messages
...@@ -170,8 +132,7 @@ class CompletionAppRunner(AppRunner): ...@@ -170,8 +132,7 @@ class CompletionAppRunner(AppRunner):
inputs=inputs, inputs=inputs,
files=files, files=files,
query=query, query=query,
context=context, context=context
memory=memory
) )
# check hosting moderation # check hosting moderation
...@@ -210,57 +171,4 @@ class CompletionAppRunner(AppRunner): ...@@ -210,57 +171,4 @@ class CompletionAppRunner(AppRunner):
queue_manager=queue_manager, queue_manager=queue_manager,
stream=application_generate_entity.stream stream=application_generate_entity.stream
) )
def retrieve_dataset_context(self, tenant_id: str,
app_record: App,
queue_manager: AppQueueManager,
model_config: ModelConfigEntity,
dataset_config: DatasetEntity,
show_retrieve_source: bool,
message: Message,
inputs: dict,
query: str,
user_id: str,
invoke_from: InvokeFrom,
memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
"""
Retrieve dataset context
:param tenant_id: tenant id
:param app_record: app record
:param queue_manager: queue manager
:param model_config: model config
:param dataset_config: dataset config
:param show_retrieve_source: show retrieve source
:param message: message
:param inputs: inputs
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:param memory: memory
:return:
"""
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager,
app_record.id,
message.id,
user_id,
invoke_from
)
# TODO
if (app_record.mode == AppMode.COMPLETION.value and dataset_config
and dataset_config.retrieve_config.query_variable):
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
dataset_retrieval = DatasetRetrieval()
return dataset_retrieval.retrieve(
tenant_id=tenant_id,
model_config=model_config,
config=dataset_config,
query=query,
invoke_from=invoke_from,
show_retrieve_source=show_retrieve_source,
hit_callback=hit_callback,
memory=memory
)
\ No newline at end of file
import json
from flask_restful import fields from flask_restful import fields
from fields.member_fields import simple_account_fields from fields.member_fields import simple_account_fields
...@@ -7,7 +5,8 @@ from libs.helper import TimestampField ...@@ -7,7 +5,8 @@ from libs.helper import TimestampField
workflow_fields = { workflow_fields = {
'id': fields.String, 'id': fields.String,
'graph': fields.Raw(attribute=lambda x: json.loads(x.graph) if hasattr(x, 'graph') else None), 'graph': fields.Nested(simple_account_fields, attribute='graph_dict'),
'features': fields.Nested(simple_account_fields, attribute='features_dict'),
'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'),
'created_at': TimestampField, 'created_at': TimestampField,
'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True),
......
...@@ -97,6 +97,7 @@ def upgrade(): ...@@ -97,6 +97,7 @@ def upgrade():
sa.Column('type', sa.String(length=255), nullable=False), sa.Column('type', sa.String(length=255), nullable=False),
sa.Column('version', sa.String(length=255), nullable=False), sa.Column('version', sa.String(length=255), nullable=False),
sa.Column('graph', sa.Text(), nullable=True), sa.Column('graph', sa.Text(), nullable=True),
sa.Column('features', sa.Text(), nullable=True),
sa.Column('created_by', postgresql.UUID(), nullable=False), sa.Column('created_by', postgresql.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_by', postgresql.UUID(), nullable=True), sa.Column('updated_by', postgresql.UUID(), nullable=True),
...@@ -106,7 +107,7 @@ def upgrade(): ...@@ -106,7 +107,7 @@ def upgrade():
with op.batch_alter_table('workflows', schema=None) as batch_op: with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'version'], unique=False) batch_op.create_index('workflow_version_idx', ['tenant_id', 'app_id', 'version'], unique=False)
with op.batch_alter_table('app_model_configs', schema=None) as batch_op: with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.add_column(sa.Column('workflow_id', postgresql.UUID(), nullable=True)) batch_op.add_column(sa.Column('workflow_id', postgresql.UUID(), nullable=True))
with op.batch_alter_table('messages', schema=None) as batch_op: with op.batch_alter_table('messages', schema=None) as batch_op:
...@@ -120,7 +121,7 @@ def downgrade(): ...@@ -120,7 +121,7 @@ def downgrade():
with op.batch_alter_table('messages', schema=None) as batch_op: with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.drop_column('workflow_run_id') batch_op.drop_column('workflow_run_id')
with op.batch_alter_table('app_model_configs', schema=None) as batch_op: with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.drop_column('workflow_id') batch_op.drop_column('workflow_id')
with op.batch_alter_table('workflows', schema=None) as batch_op: with op.batch_alter_table('workflows', schema=None) as batch_op:
......
...@@ -63,6 +63,7 @@ class App(db.Model): ...@@ -63,6 +63,7 @@ class App(db.Model):
icon = db.Column(db.String(255)) icon = db.Column(db.String(255))
icon_background = db.Column(db.String(255)) icon_background = db.Column(db.String(255))
app_model_config_id = db.Column(UUID, nullable=True) app_model_config_id = db.Column(UUID, nullable=True)
workflow_id = db.Column(UUID, nullable=True)
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
enable_site = db.Column(db.Boolean, nullable=False) enable_site = db.Column(db.Boolean, nullable=False)
enable_api = db.Column(db.Boolean, nullable=False) enable_api = db.Column(db.Boolean, nullable=False)
...@@ -85,6 +86,14 @@ class App(db.Model): ...@@ -85,6 +86,14 @@ class App(db.Model):
AppModelConfig.id == self.app_model_config_id).first() AppModelConfig.id == self.app_model_config_id).first()
return app_model_config return app_model_config
@property
def workflow(self):
if self.workflow_id:
from api.models.workflow import Workflow
return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
return None
@property @property
def api_base_url(self): def api_base_url(self):
return (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL'] return (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
...@@ -176,7 +185,6 @@ class AppModelConfig(db.Model): ...@@ -176,7 +185,6 @@ class AppModelConfig(db.Model):
dataset_configs = db.Column(db.Text) dataset_configs = db.Column(db.Text)
external_data_tools = db.Column(db.Text) external_data_tools = db.Column(db.Text)
file_upload = db.Column(db.Text) file_upload = db.Column(db.Text)
workflow_id = db.Column(UUID)
@property @property
def app(self): def app(self):
...@@ -276,14 +284,6 @@ class AppModelConfig(db.Model): ...@@ -276,14 +284,6 @@ class AppModelConfig(db.Model):
"image": {"enabled": False, "number_limits": 3, "detail": "high", "image": {"enabled": False, "number_limits": 3, "detail": "high",
"transfer_methods": ["remote_url", "local_file"]}} "transfer_methods": ["remote_url", "local_file"]}}
@property
def workflow(self):
if self.workflow_id:
from api.models.workflow import Workflow
return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
return None
def to_dict(self) -> dict: def to_dict(self) -> dict:
return { return {
"opening_statement": self.opening_statement, "opening_statement": self.opening_statement,
...@@ -343,7 +343,6 @@ class AppModelConfig(db.Model): ...@@ -343,7 +343,6 @@ class AppModelConfig(db.Model):
if model_config.get('dataset_configs') else None if model_config.get('dataset_configs') else None
self.file_upload = json.dumps(model_config.get('file_upload')) \ self.file_upload = json.dumps(model_config.get('file_upload')) \
if model_config.get('file_upload') else None if model_config.get('file_upload') else None
self.workflow_id = model_config.get('workflow_id')
return self return self
def copy(self): def copy(self):
...@@ -368,8 +367,7 @@ class AppModelConfig(db.Model): ...@@ -368,8 +367,7 @@ class AppModelConfig(db.Model):
chat_prompt_config=self.chat_prompt_config, chat_prompt_config=self.chat_prompt_config,
completion_prompt_config=self.completion_prompt_config, completion_prompt_config=self.completion_prompt_config,
dataset_configs=self.dataset_configs, dataset_configs=self.dataset_configs,
file_upload=self.file_upload, file_upload=self.file_upload
workflow_id=self.workflow_id
) )
return new_app_model_config return new_app_model_config
......
import json
from enum import Enum from enum import Enum
from typing import Union from typing import Union
...@@ -106,6 +107,7 @@ class Workflow(db.Model): ...@@ -106,6 +107,7 @@ class Workflow(db.Model):
type = db.Column(db.String(255), nullable=False) type = db.Column(db.String(255), nullable=False)
version = db.Column(db.String(255), nullable=False) version = db.Column(db.String(255), nullable=False)
graph = db.Column(db.Text) graph = db.Column(db.Text)
features = db.Column(db.Text)
created_by = db.Column(UUID, nullable=False) created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(UUID) updated_by = db.Column(UUID)
...@@ -119,6 +121,14 @@ class Workflow(db.Model): ...@@ -119,6 +121,14 @@ class Workflow(db.Model):
def updated_by_account(self): def updated_by_account(self):
return Account.query.get(self.updated_by) return Account.query.get(self.updated_by)
@property
def graph_dict(self):
return self.graph if not self.graph else json.loads(self.graph)
@property
def features_dict(self):
return self.features if not self.features else json.loads(self.features)
class WorkflowRunTriggeredFrom(Enum): class WorkflowRunTriggeredFrom(Enum):
""" """
......
...@@ -64,8 +64,8 @@ class AppService: ...@@ -64,8 +64,8 @@ class AppService:
app_template = default_app_templates[app_mode] app_template = default_app_templates[app_mode]
# get model config # get model config
default_model_config = app_template['model_config'] default_model_config = app_template.get('model_config')
if 'model' in default_model_config: if default_model_config and 'model' in default_model_config:
# get model provider # get model provider
model_manager = ModelManager() model_manager = ModelManager()
...@@ -110,12 +110,15 @@ class AppService: ...@@ -110,12 +110,15 @@ class AppService:
db.session.add(app) db.session.add(app)
db.session.flush() db.session.flush()
app_model_config = AppModelConfig(**default_model_config) if default_model_config:
app_model_config.app_id = app.id app_model_config = AppModelConfig(**default_model_config)
db.session.add(app_model_config) app_model_config.app_id = app.id
db.session.flush() db.session.add(app_model_config)
db.session.flush()
app.app_model_config_id = app_model_config.id
app.app_model_config_id = app_model_config.id db.session.commit()
app_was_created.send(app, account=account) app_was_created.send(app, account=account)
...@@ -135,16 +138,22 @@ class AppService: ...@@ -135,16 +138,22 @@ class AppService:
app_data = import_data.get('app') app_data = import_data.get('app')
model_config_data = import_data.get('model_config') model_config_data = import_data.get('model_config')
workflow_graph = import_data.get('workflow_graph') workflow = import_data.get('workflow')
if not app_data or not model_config_data: if not app_data:
raise ValueError("Missing app or model_config in data argument") raise ValueError("Missing app in data argument")
app_mode = AppMode.value_of(app_data.get('mode')) app_mode = AppMode.value_of(app_data.get('mode'))
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
if not workflow_graph: if not workflow:
raise ValueError("Missing workflow_graph in data argument " raise ValueError("Missing workflow in data argument "
"when mode is advanced-chat or workflow") "when app mode is advanced-chat or workflow")
elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT]:
if not model_config_data:
raise ValueError("Missing model_config in data argument "
"when app mode is chat or agent-chat")
else:
raise ValueError("Invalid app mode")
app = App( app = App(
tenant_id=tenant_id, tenant_id=tenant_id,
...@@ -161,26 +170,32 @@ class AppService: ...@@ -161,26 +170,32 @@ class AppService:
db.session.add(app) db.session.add(app)
db.session.commit() db.session.commit()
if workflow_graph: app_was_created.send(app, account=account)
if workflow:
# init draft workflow # init draft workflow
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow_service.sync_draft_workflow(app, workflow_graph, account) workflow_service.sync_draft_workflow(
app_model=app,
app_model_config = AppModelConfig() graph=workflow.get('graph'),
app_model_config = app_model_config.from_model_config_dict(model_config_data) features=workflow.get('features'),
app_model_config.app_id = app.id account=account
)
db.session.add(app_model_config) if model_config_data:
db.session.commit() app_model_config = AppModelConfig()
app_model_config = app_model_config.from_model_config_dict(model_config_data)
app_model_config.app_id = app.id
app.app_model_config_id = app_model_config.id db.session.add(app_model_config)
db.session.commit()
app_was_created.send(app, account=account) app.app_model_config_id = app_model_config.id
app_model_config_was_updated.send( app_model_config_was_updated.send(
app, app,
app_model_config=app_model_config app_model_config=app_model_config
) )
return app return app
...@@ -190,7 +205,7 @@ class AppService: ...@@ -190,7 +205,7 @@ class AppService:
:param app: App instance :param app: App instance
:return: :return:
""" """
app_model_config = app.app_model_config app_mode = AppMode.value_of(app.mode)
export_data = { export_data = {
"app": { "app": {
...@@ -198,16 +213,27 @@ class AppService: ...@@ -198,16 +213,27 @@ class AppService:
"mode": app.mode, "mode": app.mode,
"icon": app.icon, "icon": app.icon,
"icon_background": app.icon_background "icon_background": app.icon_background
}, }
"model_config": app_model_config.to_dict(),
} }
if app_model_config.workflow_id: if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
export_data['workflow_graph'] = json.loads(app_model_config.workflow.graph) if app.workflow_id:
workflow = app.workflow
export_data['workflow'] = {
"graph": workflow.graph_dict,
"features": workflow.features_dict
}
else:
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app)
export_data['workflow'] = {
"graph": workflow.graph_dict,
"features": workflow.features_dict
}
else: else:
workflow_service = WorkflowService() app_model_config = app.app_model_config
workflow = workflow_service.get_draft_workflow(app)
export_data['workflow_graph'] = json.loads(workflow.graph) export_data['model_config'] = app_model_config.to_dict()
return yaml.dump(export_data) return yaml.dump(export_data)
......
...@@ -44,13 +44,10 @@ class WorkflowConverter: ...@@ -44,13 +44,10 @@ class WorkflowConverter:
:param account: Account :param account: Account
:return: new App instance :return: new App instance
""" """
# get original app config
app_model_config = app_model.app_model_config
# convert app model config # convert app model config
workflow = self.convert_app_model_config_to_workflow( workflow = self.convert_app_model_config_to_workflow(
app_model=app_model, app_model=app_model,
app_model_config=app_model_config, app_model_config=app_model.app_model_config,
account_id=account.id account_id=account.id
) )
...@@ -58,8 +55,9 @@ class WorkflowConverter: ...@@ -58,8 +55,9 @@ class WorkflowConverter:
new_app = App() new_app = App()
new_app.tenant_id = app_model.tenant_id new_app.tenant_id = app_model.tenant_id
new_app.name = app_model.name + '(workflow)' new_app.name = app_model.name + '(workflow)'
new_app.mode = AppMode.CHAT.value \ new_app.mode = AppMode.ADVANCED_CHAT.value \
if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value
new_app.workflow_id = workflow.id
new_app.icon = app_model.icon new_app.icon = app_model.icon
new_app.icon_background = app_model.icon_background new_app.icon_background = app_model.icon_background
new_app.enable_site = app_model.enable_site new_app.enable_site = app_model.enable_site
...@@ -69,28 +67,6 @@ class WorkflowConverter: ...@@ -69,28 +67,6 @@ class WorkflowConverter:
new_app.is_demo = False new_app.is_demo = False
new_app.is_public = app_model.is_public new_app.is_public = app_model.is_public
db.session.add(new_app) db.session.add(new_app)
db.session.flush()
# create new app model config record
new_app_model_config = app_model_config.copy()
new_app_model_config.id = None
new_app_model_config.app_id = new_app.id
new_app_model_config.external_data_tools = ''
new_app_model_config.model = ''
new_app_model_config.user_input_form = ''
new_app_model_config.dataset_query_variable = None
new_app_model_config.pre_prompt = None
new_app_model_config.agent_mode = ''
new_app_model_config.prompt_type = 'simple'
new_app_model_config.chat_prompt_config = ''
new_app_model_config.completion_prompt_config = ''
new_app_model_config.dataset_configs = ''
new_app_model_config.workflow_id = workflow.id
db.session.add(new_app_model_config)
db.session.flush()
new_app.app_model_config_id = new_app_model_config.id
db.session.commit() db.session.commit()
app_was_created.send(new_app, account=account) app_was_created.send(new_app, account=account)
...@@ -110,11 +86,13 @@ class WorkflowConverter: ...@@ -110,11 +86,13 @@ class WorkflowConverter:
# get new app mode # get new app mode
new_app_mode = self._get_new_app_mode(app_model) new_app_mode = self._get_new_app_mode(app_model)
app_model_config_dict = app_model_config.to_dict()
# convert app model config # convert app model config
application_manager = AppManager() application_manager = AppManager()
app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict( app_orchestration_config_entity = application_manager.convert_from_app_model_config_dict(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
app_model_config_dict=app_model_config.to_dict(), app_model_config_dict=app_model_config_dict,
skip_check=True skip_check=True
) )
...@@ -177,6 +155,25 @@ class WorkflowConverter: ...@@ -177,6 +155,25 @@ class WorkflowConverter:
graph = self._append_node(graph, end_node) graph = self._append_node(graph, end_node)
# features
if new_app_mode == AppMode.ADVANCED_CHAT:
features = {
"opening_statement": app_model_config_dict.get("opening_statement"),
"suggested_questions": app_model_config_dict.get("suggested_questions"),
"suggested_questions_after_answer": app_model_config_dict.get("suggested_questions_after_answer"),
"speech_to_text": app_model_config_dict.get("speech_to_text"),
"text_to_speech": app_model_config_dict.get("text_to_speech"),
"file_upload": app_model_config_dict.get("file_upload"),
"sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
"retriever_resource": app_model_config_dict.get("retriever_resource"),
}
else:
features = {
"text_to_speech": app_model_config_dict.get("text_to_speech"),
"file_upload": app_model_config_dict.get("file_upload"),
"sensitive_word_avoidance": app_model_config_dict.get("sensitive_word_avoidance"),
}
# create workflow record # create workflow record
workflow = Workflow( workflow = Workflow(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
...@@ -184,6 +181,7 @@ class WorkflowConverter: ...@@ -184,6 +181,7 @@ class WorkflowConverter:
type=WorkflowType.from_app_mode(new_app_mode).value, type=WorkflowType.from_app_mode(new_app_mode).value,
version='draft', version='draft',
graph=json.dumps(graph), graph=json.dumps(graph),
features=json.dumps(features),
created_by=account_id, created_by=account_id,
created_at=app_model_config.created_at created_at=app_model_config.created_at
) )
......
...@@ -33,29 +33,31 @@ class WorkflowService: ...@@ -33,29 +33,31 @@ class WorkflowService:
""" """
Get published workflow Get published workflow
""" """
app_model_config = app_model.app_model_config if not app_model.workflow_id:
if not app_model_config.workflow_id:
return None return None
# fetch published workflow by workflow_id # fetch published workflow by workflow_id
workflow = db.session.query(Workflow).filter( workflow = db.session.query(Workflow).filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id, Workflow.app_id == app_model.id,
Workflow.id == app_model_config.workflow_id Workflow.id == app_model.workflow_id
).first() ).first()
# return published workflow # return published workflow
return workflow return workflow
def sync_draft_workflow(self, app_model: App,
def sync_draft_workflow(self, app_model: App, graph: dict, account: Account) -> Workflow: graph: dict,
features: dict,
account: Account) -> Workflow:
""" """
Sync draft workflow Sync draft workflow
""" """
# fetch draft workflow by app_model # fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model) workflow = self.get_draft_workflow(app_model=app_model)
# TODO validate features
# create draft workflow if not found # create draft workflow if not found
if not workflow: if not workflow:
workflow = Workflow( workflow = Workflow(
...@@ -64,12 +66,14 @@ class WorkflowService: ...@@ -64,12 +66,14 @@ class WorkflowService:
type=WorkflowType.from_app_mode(app_model.mode).value, type=WorkflowType.from_app_mode(app_model.mode).value,
version='draft', version='draft',
graph=json.dumps(graph), graph=json.dumps(graph),
features=json.dumps(features),
created_by=account.id created_by=account.id
) )
db.session.add(workflow) db.session.add(workflow)
# update draft workflow if found # update draft workflow if found
else: else:
workflow.graph = json.dumps(graph) workflow.graph = json.dumps(graph)
workflow.features = json.dumps(features)
workflow.updated_by = account.id workflow.updated_by = account.id
workflow.updated_at = datetime.utcnow() workflow.updated_at = datetime.utcnow()
...@@ -112,28 +116,7 @@ class WorkflowService: ...@@ -112,28 +116,7 @@ class WorkflowService:
db.session.add(workflow) db.session.add(workflow)
db.session.commit() db.session.commit()
app_model_config = app_model.app_model_config app_model.workflow_id = workflow.id
# create new app model config record
new_app_model_config = app_model_config.copy()
new_app_model_config.id = None
new_app_model_config.app_id = app_model.id
new_app_model_config.external_data_tools = ''
new_app_model_config.model = ''
new_app_model_config.user_input_form = ''
new_app_model_config.dataset_query_variable = None
new_app_model_config.pre_prompt = None
new_app_model_config.agent_mode = ''
new_app_model_config.prompt_type = 'simple'
new_app_model_config.chat_prompt_config = ''
new_app_model_config.completion_prompt_config = ''
new_app_model_config.dataset_configs = ''
new_app_model_config.workflow_id = workflow.id
db.session.add(new_app_model_config)
db.session.flush()
app_model.app_model_config_id = new_app_model_config.id
db.session.commit() db.session.commit()
# TODO update app related datasets # TODO update app related datasets
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment