Commit 4e4b07ce authored by takatost

Merge branch 'feat/workflow-backend' into deploy/dev

parents 8d4d0a29 5fe0d50c
@@ -135,4 +135,4 @@ BATCH_UPLOAD_LIMIT=10
 # CODE EXECUTION CONFIGURATION
 CODE_EXECUTION_ENDPOINT=
-CODE_EXECUTINO_API_KEY=
+CODE_EXECUTION_API_KEY=
@@ -23,7 +23,8 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.input_moderation import InputModeration
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
-from core.prompt.simple_prompt_transform import SimplePromptTransform
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
+from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
 from models.model import App, AppMode, Message, MessageAnnotation
@@ -155,13 +156,39 @@ class AppRunner:
                 model_config=model_config
             )
         else:
+            memory_config = MemoryConfig(
+                window=MemoryConfig.WindowConfig(
+                    enabled=False
+                )
+            )
+
+            model_mode = ModelMode.value_of(model_config.mode)
+            if model_mode == ModelMode.COMPLETION:
+                advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
+                prompt_template = CompletionModelPromptTemplate(
+                    text=advanced_completion_prompt_template.prompt
+                )
+
+                memory_config.role_prefix = MemoryConfig.RolePrefix(
+                    user=advanced_completion_prompt_template.role_prefix.user,
+                    assistant=advanced_completion_prompt_template.role_prefix.assistant
+                )
+            else:
+                prompt_template = []
+                for message in prompt_template_entity.advanced_chat_prompt_template.messages:
+                    prompt_template.append(ChatModelMessage(
+                        text=message.text,
+                        role=message.role
+                    ))
+
             prompt_transform = AdvancedPromptTransform()
             prompt_messages = prompt_transform.get_prompt(
-                prompt_template_entity=prompt_template_entity,
+                prompt_template=prompt_template,
                 inputs=inputs,
                 query=query if query else '',
                 files=files,
                 context=context,
+                memory_config=memory_config,
                 memory=memory,
                 model_config=model_config
             )
...
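Note: for orientation, a minimal sketch of the new calling convention, using the ChatModelMessage and MemoryConfig entities this commit introduces in later hunks. model_config is an assumption here (a ModelConfigWithCredentialsEntity from the surrounding runner state); the template values are illustrative.

from core.model_runtime.entities.message_entities import PromptMessageRole
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, MemoryConfig

# Illustrative chat-mode template; in AppRunner these are built from
# prompt_template_entity as shown in the hunk above.
prompt_template = [
    ChatModelMessage(text='You are a helpful assistant.', role=PromptMessageRole.SYSTEM),
    ChatModelMessage(text='{{query}}', role=PromptMessageRole.USER),
]
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))

prompt_messages = AdvancedPromptTransform().get_prompt(
    prompt_template=prompt_template,
    inputs={},
    query='Hello',
    files=[],
    context=None,
    memory_config=memory_config,
    memory=None,                # no conversation history in this sketch
    model_config=model_config,  # assumed: a ModelConfigWithCredentialsEntity
)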
@@ -30,17 +30,12 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
-    ImagePromptMessageContent,
-    PromptMessage,
-    PromptMessageContentType,
-    PromptMessageRole,
-    TextPromptMessageContent,
 )
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.moderation.output_moderation import ModerationRule, OutputModeration
-from core.prompt.simple_prompt_transform import ModelMode
+from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.tools.tool_file_manager import ToolFileManager
 from events.message_event import message_was_created
@@ -438,7 +433,10 @@ class EasyUIBasedGenerateTaskPipeline:
         self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
         self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()

-        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
+        self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
+            self._model_config.mode,
+            self._task_state.llm_result.prompt_messages
+        )
         self._message.message_tokens = usage.prompt_tokens
         self._message.message_unit_price = usage.prompt_unit_price
         self._message.message_price_unit = usage.prompt_price_unit
@@ -582,77 +580,6 @@ class EasyUIBasedGenerateTaskPipeline:
         """
         return "data: " + json.dumps(response) + "\n\n"
-    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
-        """
-        Prompt messages to prompt for saving.
-        :param prompt_messages: prompt messages
-        :return:
-        """
-        prompts = []
-        if self._model_config.mode == ModelMode.CHAT.value:
-            for prompt_message in prompt_messages:
-                if prompt_message.role == PromptMessageRole.USER:
-                    role = 'user'
-                elif prompt_message.role == PromptMessageRole.ASSISTANT:
-                    role = 'assistant'
-                elif prompt_message.role == PromptMessageRole.SYSTEM:
-                    role = 'system'
-                else:
-                    continue
-
-                text = ''
-                files = []
-                if isinstance(prompt_message.content, list):
-                    for content in prompt_message.content:
-                        if content.type == PromptMessageContentType.TEXT:
-                            content = cast(TextPromptMessageContent, content)
-                            text += content.data
-                        else:
-                            content = cast(ImagePromptMessageContent, content)
-                            files.append({
-                                "type": 'image',
-                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
-                                "detail": content.detail.value
-                            })
-                else:
-                    text = prompt_message.content
-
-                prompts.append({
-                    "role": role,
-                    "text": text,
-                    "files": files
-                })
-        else:
-            prompt_message = prompt_messages[0]
-            text = ''
-            files = []
-            if isinstance(prompt_message.content, list):
-                for content in prompt_message.content:
-                    if content.type == PromptMessageContentType.TEXT:
-                        content = cast(TextPromptMessageContent, content)
-                        text += content.data
-                    else:
-                        content = cast(ImagePromptMessageContent, content)
-                        files.append({
-                            "type": 'image',
-                            "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
-                            "detail": content.detail.value
-                        })
-            else:
-                text = prompt_message.content
-
-            params = {
-                "role": 'user',
-                "text": text,
-            }
-
-            if files:
-                params['files'] = files
-
-            prompts.append(params)
-
-        return prompts
-
     def _init_output_moderation(self) -> Optional[OutputModeration]:
         """
         Init output moderation.
...
@@ -5,6 +5,7 @@ from httpx import post
 from pydantic import BaseModel
 from yarl import URL

+from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
 from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer
 from core.helper.code_executor.python_transformer import PythonTemplateTransformer
@@ -39,17 +40,20 @@ class CodeExecutor:
             template_transformer = PythonTemplateTransformer
         elif language == 'jinja2':
             template_transformer = Jinja2TemplateTransformer
+        elif language == 'javascript':
+            template_transformer = NodeJsTemplateTransformer
         else:
            raise CodeExecutionException('Unsupported language')

        runner = template_transformer.transform_caller(code, inputs)

        url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run'
        headers = {
            'X-Api-Key': CODE_EXECUTION_API_KEY
        }
        data = {
-            'language': language if language != 'jinja2' else 'python3',
+            'language': 'python3' if language == 'jinja2' else
+                        'nodejs' if language == 'javascript' else
+                        'python3' if language == 'python3' else None,
            'code': runner,
        }
...
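Note: the chained conditional expression that maps a node language to the sandbox's language identifier is easy to misread. A table-driven lookup expresses the same dispatch more plainly; this is an alternative sketch, not what the commit ships.

# Alternative sketch (not in the commit): table-driven language dispatch.
# Keys are node languages; values are the sandbox's language identifiers.
SANDBOX_LANGUAGE = {
    'python3': 'python3',
    'jinja2': 'python3',      # jinja2 templates run inside a python3 runner
    'javascript': 'nodejs',
}

def to_sandbox_language(language: str) -> str:
    try:
        return SANDBOX_LANGUAGE[language]
    except KeyError:
        raise CodeExecutionException('Unsupported language')  # same error as above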
-# TODO
\ No newline at end of file
+import json
+import re
+
+from core.helper.code_executor.template_transformer import TemplateTransformer
+
+NODEJS_RUNNER = """// declare main function here
+{{code}}
+
+// execute main function, and return the result
+// inputs is a dict, unstructured inputs
+output = main({{inputs}})
+
+// convert output to json and print
+output = JSON.stringify(output)
+
+result = `<<RESULT>>${output}<<RESULT>>`
+
+console.log(result)
+"""
+
+
+class NodeJsTemplateTransformer(TemplateTransformer):
+    @classmethod
+    def transform_caller(cls, code: str, inputs: dict) -> str:
+        """
+        Transform code to nodejs runner
+        :param code: code
+        :param inputs: inputs
+        :return:
+        """
+        # transform inputs to json string
+        inputs_str = json.dumps(inputs, indent=4)
+
+        # replace code and inputs
+        runner = NODEJS_RUNNER.replace('{{code}}', code)
+        runner = runner.replace('{{inputs}}', inputs_str)
+
+        return runner
+
+    @classmethod
+    def transform_response(cls, response: str) -> dict:
+        """
+        Transform response to dict
+        :param response: response
+        :return:
+        """
+        # extract result
+        result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
+        if not result:
+            raise ValueError('Failed to parse result')
+        result = result.group(1)
+
+        return json.loads(result)
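Note: a quick round trip through the new transformer, assuming a sandbox that runs the generated script under Node.js and returns its stdout:

runner = NodeJsTemplateTransformer.transform_caller(
    code='function main({a, b}) {\n    return {sum: a + b}\n}',
    inputs={'a': 1, 'b': 2},
)
# Running `runner` under node prints: <<RESULT>>{"sum":3}<<RESULT>>
parsed = NodeJsTemplateTransformer.transform_response('<<RESULT>>{"sum":3}<<RESULT>>')
assert parsed == {'sum': 3}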
@@ -24,11 +24,11 @@ class ModelInstance:
     """

     def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
-        self._provider_model_bundle = provider_model_bundle
+        self.provider_model_bundle = provider_model_bundle
         self.model = model
         self.provider = provider_model_bundle.configuration.provider.provider
         self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
-        self.model_type_instance = self._provider_model_bundle.model_type_instance
+        self.model_type_instance = self.provider_model_bundle.model_type_instance

     def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
         """
...
-from typing import Optional
+from typing import Optional, Union

-from core.app.app_config.entities import AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file.file_obj import FileObj
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -12,6 +11,7 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.prompt_transform import PromptTransform
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
@@ -22,11 +22,12 @@ class AdvancedPromptTransform(PromptTransform):
     Advanced Prompt Transform for Workflow LLM Node.
     """

-    def get_prompt(self, prompt_template_entity: PromptTemplateEntity,
+    def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
                    inputs: dict,
                    query: str,
                    files: list[FileObj],
                    context: Optional[str],
+                   memory_config: Optional[MemoryConfig],
                    memory: Optional[TokenBufferMemory],
                    model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
         prompt_messages = []
@@ -34,21 +35,23 @@ class AdvancedPromptTransform(PromptTransform):
         model_mode = ModelMode.value_of(model_config.mode)
         if model_mode == ModelMode.COMPLETION:
             prompt_messages = self._get_completion_model_prompt_messages(
-                prompt_template_entity=prompt_template_entity,
+                prompt_template=prompt_template,
                 inputs=inputs,
                 query=query,
                 files=files,
                 context=context,
+                memory_config=memory_config,
                 memory=memory,
                 model_config=model_config
             )
         elif model_mode == ModelMode.CHAT:
             prompt_messages = self._get_chat_model_prompt_messages(
-                prompt_template_entity=prompt_template_entity,
+                prompt_template=prompt_template,
                 inputs=inputs,
                 query=query,
                 files=files,
                 context=context,
+                memory_config=memory_config,
                 memory=memory,
                 model_config=model_config
             )
@@ -56,17 +59,18 @@ class AdvancedPromptTransform(PromptTransform):
         return prompt_messages

     def _get_completion_model_prompt_messages(self,
-                                              prompt_template_entity: PromptTemplateEntity,
+                                              prompt_template: CompletionModelPromptTemplate,
                                               inputs: dict,
                                               query: Optional[str],
                                               files: list[FileObj],
                                               context: Optional[str],
+                                              memory_config: Optional[MemoryConfig],
                                               memory: Optional[TokenBufferMemory],
                                               model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
         """
         Get completion model prompt messages.
         """
-        raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt
+        raw_prompt = prompt_template.text

         prompt_messages = []
@@ -75,15 +79,17 @@ class AdvancedPromptTransform(PromptTransform):
         prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)

-        role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix
-        prompt_inputs = self._set_histories_variable(
-            memory=memory,
-            raw_prompt=raw_prompt,
-            role_prefix=role_prefix,
-            prompt_template=prompt_template,
-            prompt_inputs=prompt_inputs,
-            model_config=model_config
-        )
+        if memory and memory_config:
+            role_prefix = memory_config.role_prefix
+            prompt_inputs = self._set_histories_variable(
+                memory=memory,
+                memory_config=memory_config,
+                raw_prompt=raw_prompt,
+                role_prefix=role_prefix,
+                prompt_template=prompt_template,
+                prompt_inputs=prompt_inputs,
+                model_config=model_config
+            )

         if query:
             prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
@@ -104,17 +110,18 @@ class AdvancedPromptTransform(PromptTransform):
         return prompt_messages

     def _get_chat_model_prompt_messages(self,
-                                        prompt_template_entity: PromptTemplateEntity,
+                                        prompt_template: list[ChatModelMessage],
                                         inputs: dict,
                                         query: Optional[str],
                                         files: list[FileObj],
                                         context: Optional[str],
+                                        memory_config: Optional[MemoryConfig],
                                         memory: Optional[TokenBufferMemory],
                                         model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
         """
         Get chat model prompt messages.
         """
-        raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages
+        raw_prompt_list = prompt_template

         prompt_messages = []
@@ -137,8 +144,8 @@ class AdvancedPromptTransform(PromptTransform):
             elif prompt_item.role == PromptMessageRole.ASSISTANT:
                 prompt_messages.append(AssistantPromptMessage(content=prompt))

-        if memory:
-            prompt_messages = self._append_chat_histories(memory, prompt_messages, model_config)
+        if memory and memory_config:
+            prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)

         if files:
             prompt_message_contents = [TextPromptMessageContent(data=query)]
@@ -195,8 +202,9 @@ class AdvancedPromptTransform(PromptTransform):
         return prompt_inputs

     def _set_histories_variable(self, memory: TokenBufferMemory,
+                                memory_config: MemoryConfig,
                                 raw_prompt: str,
-                                role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
+                                role_prefix: MemoryConfig.RolePrefix,
                                 prompt_template: PromptTemplateParser,
                                 prompt_inputs: dict,
                                 model_config: ModelConfigWithCredentialsEntity) -> dict:
@@ -213,6 +221,7 @@ class AdvancedPromptTransform(PromptTransform):
             histories = self._get_history_messages_from_memory(
                 memory=memory,
+                memory_config=memory_config,
                 max_token_limit=rest_tokens,
                 human_prefix=role_prefix.user,
                 ai_prefix=role_prefix.assistant
...
+from typing import Optional
+
+from pydantic import BaseModel
+
+from core.model_runtime.entities.message_entities import PromptMessageRole
+
+
+class ChatModelMessage(BaseModel):
+    """
+    Chat Message.
+    """
+    text: str
+    role: PromptMessageRole
+
+
+class CompletionModelPromptTemplate(BaseModel):
+    """
+    Completion Model Prompt Template.
+    """
+    text: str
+
+
+class MemoryConfig(BaseModel):
+    """
+    Memory Config.
+    """
+    class RolePrefix(BaseModel):
+        """
+        Role Prefix.
+        """
+        user: str
+        assistant: str
+
+    class WindowConfig(BaseModel):
+        """
+        Window Config.
+        """
+        enabled: bool
+        size: Optional[int] = None
+
+    role_prefix: Optional[RolePrefix] = None
+    window: WindowConfig
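Note: these are plain pydantic models. For instance, a completion-style memory with role prefixes and a 10-message window (illustrative values) is declared as:

memory_config = MemoryConfig(
    role_prefix=MemoryConfig.RolePrefix(user='Human', assistant='Assistant'),
    window=MemoryConfig.WindowConfig(enabled=True, size=10),
)
# PromptTransform (next hunk) only applies the window when it is enabled
# and size is a positive integer.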
@@ -5,19 +5,22 @@ from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import PromptMessage
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig


 class PromptTransform:
     def _append_chat_histories(self, memory: TokenBufferMemory,
+                               memory_config: MemoryConfig,
                                prompt_messages: list[PromptMessage],
                                model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
         rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
-        histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
+        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
         prompt_messages.extend(histories)

         return prompt_messages

-    def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity) -> int:
+    def _calculate_rest_token(self, prompt_messages: list[PromptMessage],
+                              model_config: ModelConfigWithCredentialsEntity) -> int:
         rest_tokens = 2000

         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
@@ -44,6 +47,7 @@ class PromptTransform:
         return rest_tokens

     def _get_history_messages_from_memory(self, memory: TokenBufferMemory,
+                                          memory_config: MemoryConfig,
                                           max_token_limit: int,
                                           human_prefix: Optional[str] = None,
                                           ai_prefix: Optional[str] = None) -> str:
@@ -58,13 +62,22 @@ class PromptTransform:
         if ai_prefix:
             kwargs['ai_prefix'] = ai_prefix

+        if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
+            kwargs['message_limit'] = memory_config.window.size
+
         return memory.get_history_prompt_text(
             **kwargs
         )

     def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory,
+                                               memory_config: MemoryConfig,
                                                max_token_limit: int) -> list[PromptMessage]:
         """Get memory messages."""
         return memory.get_history_prompt_messages(
-            max_token_limit=max_token_limit
+            max_token_limit=max_token_limit,
+            message_limit=memory_config.window.size
+            if (memory_config.window.enabled
+                and memory_config.window.size is not None
+                and memory_config.window.size > 0)
+            else 10
         )
@@ -13,6 +13,7 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.prompt.prompt_transform import PromptTransform
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from models.model import AppMode
@@ -182,6 +183,11 @@ class SimplePromptTransform(PromptTransform):
         if memory:
             prompt_messages = self._append_chat_histories(
                 memory=memory,
+                memory_config=MemoryConfig(
+                    window=MemoryConfig.WindowConfig(
+                        enabled=False,
+                    )
+                ),
                 prompt_messages=prompt_messages,
                 model_config=model_config
             )
@@ -220,6 +226,11 @@ class SimplePromptTransform(PromptTransform):
             rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
             histories = self._get_history_messages_from_memory(
                 memory=memory,
+                memory_config=MemoryConfig(
+                    window=MemoryConfig.WindowConfig(
+                        enabled=False,
+                    )
+                ),
                 max_token_limit=rest_tokens,
                 ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
                 human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
...
+from typing import cast
+
+from core.model_runtime.entities.message_entities import (
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContentType,
+    PromptMessageRole,
+    TextPromptMessageContent,
+)
+from core.prompt.simple_prompt_transform import ModelMode
+
+
+class PromptMessageUtil:
+    @staticmethod
+    def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]:
+        """
+        Prompt messages to prompt for saving.
+        :param model_mode: model mode
+        :param prompt_messages: prompt messages
+        :return:
+        """
+        prompts = []
+        if model_mode == ModelMode.CHAT.value:
+            for prompt_message in prompt_messages:
+                if prompt_message.role == PromptMessageRole.USER:
+                    role = 'user'
+                elif prompt_message.role == PromptMessageRole.ASSISTANT:
+                    role = 'assistant'
+                elif prompt_message.role == PromptMessageRole.SYSTEM:
+                    role = 'system'
+                else:
+                    continue
+
+                text = ''
+                files = []
+                if isinstance(prompt_message.content, list):
+                    for content in prompt_message.content:
+                        if content.type == PromptMessageContentType.TEXT:
+                            content = cast(TextPromptMessageContent, content)
+                            text += content.data
+                        else:
+                            content = cast(ImagePromptMessageContent, content)
+                            files.append({
+                                "type": 'image',
+                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
+                                "detail": content.detail.value
+                            })
+                else:
+                    text = prompt_message.content
+
+                prompts.append({
+                    "role": role,
+                    "text": text,
+                    "files": files
+                })
+        else:
+            prompt_message = prompt_messages[0]
+            text = ''
+            files = []
+            if isinstance(prompt_message.content, list):
+                for content in prompt_message.content:
+                    if content.type == PromptMessageContentType.TEXT:
+                        content = cast(TextPromptMessageContent, content)
+                        text += content.data
+                    else:
+                        content = cast(ImagePromptMessageContent, content)
+                        files.append({
+                            "type": 'image',
+                            "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
+                            "detail": content.detail.value
+                        })
+            else:
+                text = prompt_message.content
+
+            params = {
+                "role": 'user',
+                "text": text,
+            }
+
+            if files:
+                params['files'] = files
+
+            prompts.append(params)
+
+        return prompts
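Note: a usage sketch for the extracted helper, assuming the plain-content message classes from core.model_runtime.entities.message_entities (UserPromptMessage appears elsewhere in this diff; SystemPromptMessage is assumed to sit alongside it):

from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage

saved = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
    'chat',
    [
        SystemPromptMessage(content='You are a helpful assistant.'),
        UserPromptMessage(content='Hi.'),
    ],
)
# -> [{'role': 'system', 'text': 'You are a helpful assistant.', 'files': []},
#     {'role': 'user', 'text': 'Hi.', 'files': []}]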
@@ -12,7 +12,7 @@ class NodeType(Enum):
     """
     START = 'start'
     END = 'end'
-    DIRECT_ANSWER = 'direct-answer'
+    ANSWER = 'answer'
     LLM = 'llm'
     KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval'
     IF_ELSE = 'if-else'
...
@@ -5,14 +5,14 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import ValueType, VariablePool
+from core.workflow.nodes.answer.entities import AnswerNodeData
 from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.direct_answer.entities import DirectAnswerNodeData
 from models.workflow import WorkflowNodeExecutionStatus


-class DirectAnswerNode(BaseNode):
-    _node_data_cls = DirectAnswerNodeData
-    node_type = NodeType.DIRECT_ANSWER
+class AnswerNode(BaseNode):
+    _node_data_cls = AnswerNodeData
+    node_type = NodeType.ANSWER

     def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         """
...
@@ -2,9 +2,9 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.variable_entities import VariableSelector


-class DirectAnswerNodeData(BaseNodeData):
+class AnswerNodeData(BaseNodeData):
     """
-    DirectAnswer Node Data.
+    Answer Node Data.
     """
     variables: list[VariableSelector] = []
     answer: str
@@ -15,6 +15,16 @@ MAX_STRING_LENGTH = 1000
 MAX_STRING_ARRAY_LENGTH = 30
 MAX_NUMBER_ARRAY_LENGTH = 1000

+JAVASCRIPT_DEFAULT_CODE = """function main({args1, args2}) {
+    return {
+        result: args1 + args2
+    }
+}"""
+
+PYTHON_DEFAULT_CODE = """def main(args1: int, args2: int) -> dict:
+    return {
+        "result": args1 + args2,
+    }"""

 class CodeNode(BaseNode):
     _node_data_cls = CodeNodeData
@@ -42,9 +52,7 @@ class CodeNode(BaseNode):
                 }
             ],
             "code_language": "javascript",
-            "code": "async function main(arg1, arg2) {\n    return new Promise((resolve, reject) => {"
-                    "\n        if (true) {\n            resolve({\n                \"result\": arg1 + arg2"
-                    "\n            });\n        } else {\n            reject(\"e\");\n        }\n    });\n}",
+            "code": JAVASCRIPT_DEFAULT_CODE,
             "outputs": [
                 {
                     "variable": "result",
@@ -68,8 +76,7 @@ class CodeNode(BaseNode):
                 }
             ],
             "code_language": "python3",
-            "code": "def main(\n    arg1: int,\n    arg2: int,\n) -> int:\n    return {\n        \"result\": arg1 "
-                    "+ arg2\n    }",
+            "code": PYTHON_DEFAULT_CODE,
             "outputs": [
                 {
                     "variable": "result",
...
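Note: both default snippets implement the same two-argument sum, so a node created from either template yields {"result": args1 + args2}. As a quick sanity check, the Python default's main behaves like:

# Sanity check of PYTHON_DEFAULT_CODE's main():
def main(args1: int, args2: int) -> dict:
    return {
        "result": args1 + args2,
    }

assert main(1, 2) == {"result": 3}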
@@ -17,4 +17,4 @@ class CodeNodeData(BaseNodeData):
     variables: list[VariableSelector]
     code_language: Literal['python3', 'javascript']
     code: str
-    outputs: dict[str, Output]
\ No newline at end of file
+    outputs: dict[str, Output]
+from typing import Any, Literal, Optional, Union
+
+from pydantic import BaseModel
+
+from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
 from core.workflow.entities.base_node_data_entities import BaseNodeData
+from core.workflow.entities.variable_entities import VariableSelector
+
+
+class ModelConfig(BaseModel):
+    """
+    Model Config.
+    """
+    provider: str
+    name: str
+    mode: str
+    completion_params: dict[str, Any] = {}
+
+
+class ContextConfig(BaseModel):
+    """
+    Context Config.
+    """
+    enabled: bool
+    variable_selector: Optional[list[str]] = None
+
+
+class VisionConfig(BaseModel):
+    """
+    Vision Config.
+    """
+    class Configs(BaseModel):
+        """
+        Configs.
+        """
+        detail: Literal['low', 'high']
+
+    enabled: bool
+    configs: Optional[Configs] = None


 class LLMNodeData(BaseNodeData):
     """
     LLM Node Data.
     """
-    pass
+    model: ModelConfig
+    variables: list[VariableSelector] = []
+    prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
+    memory: Optional[MemoryConfig] = None
+    context: ContextConfig
+    vision: VisionConfig
[The diff for one file is collapsed and not shown.]
@@ -7,9 +7,9 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu
 from core.workflow.entities.variable_pool import VariablePool, VariableValue
 from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
 from core.workflow.errors import WorkflowNodeRunFailedError
+from core.workflow.nodes.answer.answer_node import AnswerNode
 from core.workflow.nodes.base_node import BaseNode, UserFrom
 from core.workflow.nodes.code.code_node import CodeNode
-from core.workflow.nodes.direct_answer.direct_answer_node import DirectAnswerNode
 from core.workflow.nodes.end.end_node import EndNode
 from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
 from core.workflow.nodes.if_else.if_else_node import IfElseNode
@@ -24,13 +24,12 @@ from extensions.ext_database import db
 from models.workflow import (
     Workflow,
     WorkflowNodeExecutionStatus,
-    WorkflowType,
 )

 node_classes = {
     NodeType.START: StartNode,
     NodeType.END: EndNode,
-    NodeType.DIRECT_ANSWER: DirectAnswerNode,
+    NodeType.ANSWER: AnswerNode,
     NodeType.LLM: LLMNode,
     NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
     NodeType.IF_ELSE: IfElseNode,
@@ -156,7 +155,7 @@ class WorkflowEngineManager:
                 callbacks=callbacks
             )

-            if next_node.node_type == NodeType.END:
+            if next_node.node_type in [NodeType.END, NodeType.ANSWER]:
                 break

             predecessor_node = next_node
@@ -402,10 +401,16 @@ class WorkflowEngineManager:
         # add to workflow_nodes_and_results
         workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result)

-        # run node, result must have inputs, process_data, outputs, execution_metadata
-        node_run_result = node.run(
-            variable_pool=workflow_run_state.variable_pool
-        )
+        try:
+            # run node, result must have inputs, process_data, outputs, execution_metadata
+            node_run_result = node.run(
+                variable_pool=workflow_run_state.variable_pool
+            )
+        except Exception as e:
+            node_run_result = NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED,
+                error=str(e)
+            )

         if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
             # node run failed
@@ -420,9 +425,6 @@ class WorkflowEngineManager:
             raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}")

-        # set end node output if in chat
-        self._set_end_node_output_if_in_chat(workflow_run_state, node, node_run_result)
-
         workflow_nodes_and_result.result = node_run_result

         # node run success
@@ -453,29 +455,6 @@ class WorkflowEngineManager:
             db.session.close()

-    def _set_end_node_output_if_in_chat(self, workflow_run_state: WorkflowRunState,
-                                        node: BaseNode,
-                                        node_run_result: NodeRunResult) -> None:
-        """
-        Set end node output if in chat
-        :param workflow_run_state: workflow run state
-        :param node: current node
-        :param node_run_result: node run result
-        :return:
-        """
-        if workflow_run_state.workflow_type == WorkflowType.CHAT and node.node_type == NodeType.END:
-            workflow_nodes_and_result_before_end = workflow_run_state.workflow_nodes_and_results[-2]
-            if workflow_nodes_and_result_before_end:
-                if workflow_nodes_and_result_before_end.node.node_type == NodeType.LLM:
-                    if not node_run_result.outputs:
-                        node_run_result.outputs = {}
-
-                    node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('text')
-                elif workflow_nodes_and_result_before_end.node.node_type == NodeType.DIRECT_ANSWER:
-                    if not node_run_result.outputs:
-                        node_run_result.outputs = {}
-
-                    node_run_result.outputs['text'] = workflow_nodes_and_result_before_end.result.outputs.get('answer')
-
     def _append_variables_recursively(self, variable_pool: VariablePool,
                                       node_id: str,
...
@@ -270,28 +270,48 @@ class WorkflowService:
             return workflow_node_execution

-        # create workflow node execution
-        workflow_node_execution = WorkflowNodeExecution(
-            tenant_id=app_model.tenant_id,
-            app_id=app_model.id,
-            workflow_id=draft_workflow.id,
-            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
-            index=1,
-            node_id=node_id,
-            node_type=node_instance.node_type.value,
-            title=node_instance.node_data.title,
-            inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
-            process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
-            outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else None,
-            execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata))
-                                if node_run_result.metadata else None),
-            status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
-            elapsed_time=time.perf_counter() - start_at,
-            created_by_role=CreatedByRole.ACCOUNT.value,
-            created_by=account.id,
-            created_at=datetime.utcnow(),
-            finished_at=datetime.utcnow()
-        )
+        if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+            # create workflow node execution
+            workflow_node_execution = WorkflowNodeExecution(
+                tenant_id=app_model.tenant_id,
+                app_id=app_model.id,
+                workflow_id=draft_workflow.id,
+                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
+                index=1,
+                node_id=node_id,
+                node_type=node_instance.node_type.value,
+                title=node_instance.node_data.title,
+                inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
+                process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
+                outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else None,
+                execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata))
+                                    if node_run_result.metadata else None),
+                status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
+                elapsed_time=time.perf_counter() - start_at,
+                created_by_role=CreatedByRole.ACCOUNT.value,
+                created_by=account.id,
+                created_at=datetime.utcnow(),
+                finished_at=datetime.utcnow()
+            )
+        else:
+            # create workflow node execution
+            workflow_node_execution = WorkflowNodeExecution(
+                tenant_id=app_model.tenant_id,
+                app_id=app_model.id,
+                workflow_id=draft_workflow.id,
+                triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
+                index=1,
+                node_id=node_id,
+                node_type=node_instance.node_type.value,
+                title=node_instance.node_data.title,
+                status=node_run_result.status.value,
+                error=node_run_result.error,
+                elapsed_time=time.perf_counter() - start_at,
+                created_by_role=CreatedByRole.ACCOUNT.value,
+                created_by=account.id,
+                created_at=datetime.utcnow(),
+                finished_at=datetime.utcnow()
+            )

         db.session.add(workflow_node_execution)
         db.session.commit()
...
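Note: the two branches above differ only in the result fields they persist. One way to avoid repeating the shared fields (a sketch only, not part of the commit) is to build the common kwargs once:

# Sketch only: factor the fields shared by both branches into one dict.
common = dict(
    tenant_id=app_model.tenant_id,
    app_id=app_model.id,
    workflow_id=draft_workflow.id,
    triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
    index=1,
    node_id=node_id,
    node_type=node_instance.node_type.value,
    title=node_instance.node_data.title,
    elapsed_time=time.perf_counter() - start_at,
    created_by_role=CreatedByRole.ACCOUNT.value,
    created_by=account.id,
    created_at=datetime.utcnow(),
    finished_at=datetime.utcnow(),
)
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
    workflow_node_execution = WorkflowNodeExecution(
        **common,
        inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
        process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
        outputs=json.dumps(node_run_result.outputs) if node_run_result.outputs else None,
        execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata))
                            if node_run_result.metadata else None),
        status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
    )
else:
    workflow_node_execution = WorkflowNodeExecution(
        **common,
        status=node_run_result.status.value,
        error=node_run_result.error,
    )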
+import os
+from unittest.mock import MagicMock
+
+import pytest
+
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.entities.provider_configuration import ProviderModelBundle, ProviderConfiguration
+from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, CustomProviderConfiguration
+from core.model_manager import ModelInstance
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers import ModelProviderFactory
+from core.workflow.entities.node_entities import SystemVariable
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import UserFrom
+from core.workflow.nodes.llm.llm_node import LLMNode
+from extensions.ext_database import db
+from models.provider import ProviderType
+from models.workflow import WorkflowNodeExecutionStatus
+
+"""FOR MOCK FIXTURES, DO NOT REMOVE"""
+from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
+
+
+@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
+def test_execute_llm(setup_openai_mock):
+    node = LLMNode(
+        tenant_id='1',
+        app_id='1',
+        workflow_id='1',
+        user_id='1',
+        user_from=UserFrom.ACCOUNT,
+        config={
+            'id': 'llm',
+            'data': {
+                'title': '123',
+                'type': 'llm',
+                'model': {
+                    'provider': 'openai',
+                    'name': 'gpt-3.5.turbo',
+                    'mode': 'chat',
+                    'completion_params': {}
+                },
+                'variables': [
+                    {
+                        'variable': 'weather',
+                        'value_selector': ['abc', 'output'],
+                    },
+                    {
+                        'variable': 'query',
+                        'value_selector': ['sys', 'query']
+                    }
+                ],
+                'prompt_template': [
+                    {
+                        'role': 'system',
+                        'text': 'you are a helpful assistant.\ntoday\'s weather is {{weather}}.'
+                    },
+                    {
+                        'role': 'user',
+                        'text': '{{query}}'
+                    }
+                ],
+                'memory': {
+                    'window': {
+                        'enabled': True,
+                        'size': 2
+                    }
+                },
+                'context': {
+                    'enabled': False
+                },
+                'vision': {
+                    'enabled': False
+                }
+            }
+        }
+    )
+
+    # construct variable pool
+    pool = VariablePool(system_variables={
+        SystemVariable.QUERY: 'what\'s the weather today?',
+        SystemVariable.FILES: [],
+        SystemVariable.CONVERSATION: 'abababa'
+    }, user_inputs={})
+    pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
+
+    credentials = {
+        'openai_api_key': os.environ.get('OPENAI_API_KEY')
+    }
+
+    provider_instance = ModelProviderFactory().get_provider_instance('openai')
+    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+    provider_model_bundle = ProviderModelBundle(
+        configuration=ProviderConfiguration(
+            tenant_id='1',
+            provider=provider_instance.get_provider_schema(),
+            preferred_provider_type=ProviderType.CUSTOM,
+            using_provider_type=ProviderType.CUSTOM,
+            system_configuration=SystemConfiguration(
+                enabled=False
+            ),
+            custom_configuration=CustomConfiguration(
+                provider=CustomProviderConfiguration(
+                    credentials=credentials
+                )
+            )
+        ),
+        provider_instance=provider_instance,
+        model_type_instance=model_type_instance
+    )
+    model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
+    model_config = ModelConfigWithCredentialsEntity(
+        model='gpt-3.5-turbo',
+        provider='openai',
+        mode='chat',
+        credentials=credentials,
+        parameters={},
+        model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
+        provider_model_bundle=provider_model_bundle
+    )
+
+    # Mock db.session.close()
+    db.session.close = MagicMock()
+
+    node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
+
+    # execute node
+    result = node.run(pool)
+
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs['text'] is not None
+    assert result.outputs['usage']['total_tokens'] > 0
 import pytest

-from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.nodes.base_node import UserFrom
 from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
 from models.workflow import WorkflowNodeExecutionStatus
 from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@@ -14,7 +14,7 @@ def test_execute_code(setup_code_executor_mock):
         app_id='1',
         workflow_id='1',
         user_id='1',
-        user_from=InvokeFrom.WEB_APP,
+        user_from=UserFrom.END_USER,
         config={
             'id': '1',
             'data': {
...
@@ -2,12 +2,12 @@ from unittest.mock import MagicMock
 import pytest

-from core.app.app_config.entities import PromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, \
-    ModelConfigEntity, AdvancedChatPromptTemplateEntity, AdvancedChatMessageEntity, FileUploadEntity
+from core.app.app_config.entities import ModelConfigEntity, FileUploadEntity
 from core.file.file_obj import FileObj, FileType, FileTransferMethod
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessageRole
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig, ChatModelMessage
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from models.model import Conversation
@@ -18,16 +18,20 @@ def test__get_completion_model_prompt_messages():
     model_config_mock.model = 'gpt-3.5-turbo-instruct'

     prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
-    prompt_template_entity = PromptTemplateEntity(
-        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
-        advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
-            prompt=prompt_template,
-            role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(
-                user="Human",
-                assistant="Assistant"
-            )
-        )
-    )
+    prompt_template_config = CompletionModelPromptTemplate(
+        text=prompt_template
+    )
+
+    memory_config = MemoryConfig(
+        role_prefix=MemoryConfig.RolePrefix(
+            user="Human",
+            assistant="Assistant"
+        ),
+        window=MemoryConfig.WindowConfig(
+            enabled=False
+        )
+    )

     inputs = {
         "name": "John"
     }
@@ -48,11 +52,12 @@ def test__get_completion_model_prompt_messages():
     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     prompt_messages = prompt_transform._get_completion_model_prompt_messages(
-        prompt_template_entity=prompt_template_entity,
+        prompt_template=prompt_template_config,
         inputs=inputs,
         query=None,
         files=files,
         context=context,
+        memory_config=memory_config,
        memory=memory,
        model_config=model_config_mock
    )
@@ -67,7 +72,7 @@ def test__get_completion_model_prompt_messages():

 def test__get_chat_model_prompt_messages(get_chat_model_args):
-    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+    model_config_mock, memory_config, messages, inputs, context = get_chat_model_args

     files = []
     query = "Hi2."
@@ -86,11 +91,12 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     prompt_messages = prompt_transform._get_chat_model_prompt_messages(
-        prompt_template_entity=prompt_template_entity,
+        prompt_template=messages,
         inputs=inputs,
         query=query,
         files=files,
         context=context,
+        memory_config=memory_config,
         memory=memory,
         model_config=model_config_mock
     )
@@ -98,24 +104,25 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
     assert len(prompt_messages) == 6
     assert prompt_messages[0].role == PromptMessageRole.SYSTEM
     assert prompt_messages[0].content == PromptTemplateParser(
-        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+        template=messages[0].text
     ).format({**inputs, "#context#": context})
     assert prompt_messages[5].content == query


 def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
-    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+    model_config_mock, _, messages, inputs, context = get_chat_model_args

     files = []

     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     prompt_messages = prompt_transform._get_chat_model_prompt_messages(
-        prompt_template_entity=prompt_template_entity,
+        prompt_template=messages,
         inputs=inputs,
         query=None,
         files=files,
         context=context,
+        memory_config=None,
         memory=None,
         model_config=model_config_mock
     )
@@ -123,12 +130,12 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
     assert len(prompt_messages) == 3
     assert prompt_messages[0].role == PromptMessageRole.SYSTEM
     assert prompt_messages[0].content == PromptTemplateParser(
-        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+        template=messages[0].text
     ).format({**inputs, "#context#": context})


 def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
-    model_config_mock, prompt_template_entity, inputs, context = get_chat_model_args
+    model_config_mock, _, messages, inputs, context = get_chat_model_args

     files = [
         FileObj(
@@ -148,11 +155,12 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
     prompt_transform = AdvancedPromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
     prompt_messages = prompt_transform._get_chat_model_prompt_messages(
-        prompt_template_entity=prompt_template_entity,
+        prompt_template=messages,
         inputs=inputs,
         query=None,
         files=files,
         context=context,
+        memory_config=None,
         memory=None,
         model_config=model_config_mock
     )
@@ -160,7 +168,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
     assert len(prompt_messages) == 4
     assert prompt_messages[0].role == PromptMessageRole.SYSTEM
     assert prompt_messages[0].content == PromptTemplateParser(
-        template=prompt_template_entity.advanced_chat_prompt_template.messages[0].text
+        template=messages[0].text
     ).format({**inputs, "#context#": context})
     assert isinstance(prompt_messages[3].content, list)
     assert len(prompt_messages[3].content) == 2
@@ -173,22 +181,31 @@ def get_chat_model_args():
     model_config_mock.provider = 'openai'
     model_config_mock.model = 'gpt-4'

-    prompt_template_entity = PromptTemplateEntity(
-        prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
-        advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
-            messages=[
-                AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
-                                          role=PromptMessageRole.SYSTEM),
-                AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
-                AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
-            ]
-        )
-    )
+    memory_config = MemoryConfig(
+        window=MemoryConfig.WindowConfig(
+            enabled=False
+        )
+    )
+
+    prompt_messages = [
+        ChatModelMessage(
+            text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
+            role=PromptMessageRole.SYSTEM
+        ),
+        ChatModelMessage(
+            text="Hi.",
+            role=PromptMessageRole.USER
+        ),
+        ChatModelMessage(
+            text="Hello!",
+            role=PromptMessageRole.ASSISTANT
+        )
+    ]

     inputs = {
         "name": "John"
     }

     context = "I am superman."

-    return model_config_mock, prompt_template_entity, inputs, context
+    return model_config_mock, memory_config, prompt_messages, inputs, context