Unverified Commit ad84b996 authored by Yeuoly

feat: model invoke api

parent 282922f3
import json
from typing import Generator, Union

from flask import Response
from flask.helpers import stream_with_context
from flask_restful import Resource, reqparse

from controllers.console.setup import setup_required
from controllers.inner_api import api
from controllers.inner_api.wraps import inner_api_only
from services.completion_service import CompletionService

class EnterpriseModelInvokeLLMApi(Resource):
    """Model invoke API for enterprise edition"""
@@ -11,6 +18,41 @@ class EnterpriseModelInvokeLLMApi(Resource):
    @setup_required
    @inner_api_only
    def post(self):
        request_parser = reqparse.RequestParser()
        request_parser.add_argument('tenant_id', type=str, required=True, nullable=False, location='json')
        request_parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
        request_parser.add_argument('model', type=str, required=True, nullable=False, location='json')
        request_parser.add_argument('completion_params', type=dict, required=True, nullable=False, location='json')
        request_parser.add_argument('prompt_messages', type=list, required=True, nullable=False, location='json')
        request_parser.add_argument('tools', type=list, required=False, nullable=True, location='json')
        request_parser.add_argument('stop', type=list, required=False, nullable=True, location='json')
        request_parser.add_argument('stream', type=bool, required=False, nullable=True, location='json')
        request_parser.add_argument('user', type=str, required=False, nullable=True, location='json')
        args = request_parser.parse_args()

        response = CompletionService.invoke_model(
            tenant_id=args['tenant_id'],
            provider=args['provider'],
            model=args['model'],
            completion_params=args['completion_params'],
            prompt_messages=args['prompt_messages'],
            tools=args['tools'],
            stop=args['stop'],
            stream=args['stream'],
            user=args['user'],
        )

        return compact_response(response)

def compact_response(response: Union[dict, Generator]) -> Response:
    if isinstance(response, dict):
        return Response(response=json.dumps(response), status=200, mimetype='application/json')
    else:
        def generate() -> Generator:
            yield from response

        return Response(stream_with_context(generate()), status=200,
                        mimetype='text/event-stream')
api.add_resource(EnterpriseModelInvokeLLMApi, '/model/invoke/llm')
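
For reference, a request that satisfies the parser above might look like the sketch below. The /inner/api URL prefix, the local port, the X-Inner-Api-Key header, and all payload values are assumptions for illustration and are not confirmed by this commit.

import requests

# Hypothetical caller of EnterpriseModelInvokeLLMApi; URL prefix, port and auth
# header are assumptions, and the payload values are placeholders.
payload = {
    'tenant_id': 'tenant-123',
    'provider': 'openai',
    'model': 'gpt-3.5-turbo',
    'completion_params': {'temperature': 0.7, 'max_tokens': 256},
    'prompt_messages': [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'Hello!'},
    ],
    'stream': True,
    'user': 'inner-api-caller',
}

resp = requests.post(
    'http://localhost:5001/inner/api/model/invoke/llm',  # assumed mount point
    headers={'X-Inner-Api-Key': '<inner-api-key>'},      # assumed auth header
    json=payload,
    stream=True,
)

# With stream=True the endpoint returns text/event-stream; each line is an
# encoded LLMResultChunk produced by the handlers below.
for line in resp.iter_lines():
    if line:
        print(line.decode())
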
from core.provider_manager import ProviderManager, ProviderModelBundle
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from events.inner_event import model_was_invoked
from typing import Generator, Union, cast, Optional
class ModelRunner:
    """
    Model runner
    """
    @staticmethod
    def run(
        provider_model_bundle: ProviderModelBundle,
        model: str,
        prompt_messages: list[PromptMessage],
        model_parameters: Optional[dict] = None,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[Generator, dict]:
        """
        Run model
        """
        llm_model = cast(LargeLanguageModel, provider_model_bundle.model_type_instance)
        credentials = provider_model_bundle.configuration.get_current_credentials(
            model_type=ModelType.LLM,
            model=model,
        )
        if not credentials:
            raise ValueError('No credentials found for model')

        response = llm_model.invoke(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
        )

        if stream:
            return ModelRunner.handle_streaming_response(
                tenant_id=provider_model_bundle.configuration.tenant_id,
                provider=provider_model_bundle.configuration.provider,
                model=model,
                model_type=ModelType.LLM.value,
                response=response,
            )

        return ModelRunner.handle_blocking_response(
            tenant_id=provider_model_bundle.configuration.tenant_id,
            provider=provider_model_bundle.configuration.provider,
            model=model,
            model_type=ModelType.LLM.value,
            response=response,
        )
    @staticmethod
    def handle_streaming_response(
        tenant_id: str,
        provider: str,
        model: str,
        model_type: str,
        response: Generator[LLMResultChunk, None, None],
    ) -> Generator[dict, None, None]:
        """
        Handle streaming response
        """
        usage = LLMUsage.empty_usage()

        # aggregate usage across all streamed chunks while forwarding them
        for chunk in response:
            if chunk.delta.usage:
                usage.completion_price += chunk.delta.usage.completion_price
                usage.prompt_price += chunk.delta.usage.prompt_price
                usage.prompt_price_unit = chunk.delta.usage.prompt_price_unit
                usage.prompt_unit_price = chunk.delta.usage.prompt_unit_price
                usage.completion_price_unit = chunk.delta.usage.completion_price_unit
                usage.completion_unit_price = chunk.delta.usage.completion_unit_price
                usage.prompt_tokens += chunk.delta.usage.prompt_tokens
                usage.completion_tokens += chunk.delta.usage.completion_tokens
                usage.currency = chunk.delta.usage.currency

            yield jsonable_encoder(chunk)

        # report the aggregated token counts once the stream is exhausted
        model_was_invoked(
            None,
            tenant_id=tenant_id,
            model_config={
                'provider': provider,
                'model_type': model_type,
                'model': model,
            },
            message_tokens=usage.prompt_tokens,
            answer_tokens=usage.completion_tokens,
        )
    @staticmethod
    def handle_blocking_response(
        tenant_id: str,
        provider: str,
        model: str,
        model_type: str,
        response: LLMResult,
    ) -> dict:
        """
        Handle blocking response
        """
        usage = response.usage or LLMUsage.empty_usage()

        model_was_invoked(
            None,
            tenant_id=tenant_id,
            model_config={
                'provider': provider,
                'model_type': model_type,
                'model': model,
            },
            message_tokens=usage.prompt_tokens,
            answer_tokens=usage.completion_tokens,
        )

        return jsonable_encoder(response)
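
For illustration, a minimal sketch of driving ModelRunner directly, mirroring how CompletionService.invoke_model (below) wires it up; the tenant, provider, and model values are placeholders, not part of this commit.

from core.app_runner.model_runner import ModelRunner
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager

# Placeholder identifiers; a real caller would use a tenant's configured provider.
bundle = ProviderManager().get_provider_model_bundle(
    tenant_id='tenant-123',
    provider='openai',
    model_type=ModelType.LLM,
)

result = ModelRunner.run(
    provider_model_bundle=bundle,
    model='gpt-3.5-turbo',
    prompt_messages=[UserPromptMessage(content='Hello!')],
    model_parameters={'temperature': 0.7},
    stream=True,
)

# With stream=True, run() yields JSON-encodable chunk dicts as they arrive.
for chunk in result:
    print(chunk)
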
@@ -5,8 +5,22 @@ from typing import Any, Union
from sqlalchemy import and_
from core.application_manager import ApplicationManager
from core.provider_manager import ProviderManager
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.entities.application_entities import InvokeFrom
from core.entities.model_entities import ModelStatus
from core.model_runtime.entities.model_entities import ModelType
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    UserPromptMessage,
    SystemPromptMessage,
    AssistantPromptMessage,
    ToolPromptMessage,
    PromptMessageRole,
    PromptMessageTool
)
from core.app_runner.model_runner import ModelRunner
from extensions.ext_database import db
from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message
from services.app_model_config_service import AppModelConfigService
@@ -15,7 +29,6 @@ from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
from services.errors.message import MessageNotExistsError
class CompletionService:
    @classmethod
@@ -256,3 +269,90 @@ class CompletionService:
            filtered_inputs[variable] = value.replace('\x00', '') if value else None

        return filtered_inputs
    @staticmethod
    def invoke_model(
        tenant_id: str, provider: str, model: str,
        completion_params: dict,
        prompt_messages: list[dict],
        tools: list[dict],
        stop: list[str],
        stream: bool,
        user: str
    ) -> Union[Generator, dict]:
"""
invoke model
:param tenant_id: the tenant id
:param provider: the provider
:param model: the model
:param mode: the mode
:param completion_params: the completion params
:param prompt_messages: the prompt messages
:param stream: the stream
:return: the model result
"""
        # convert raw role/content dicts into typed prompt messages
        converted_prompt_messages: list[PromptMessage] = []
        for message in prompt_messages:
            role = message.get('role')
            if not role:
                raise ValueError('role is required')

            if role == PromptMessageRole.USER.value:
                converted_prompt_messages.append(UserPromptMessage(content=message['content']))
            elif role == PromptMessageRole.ASSISTANT.value:
                converted_prompt_messages.append(AssistantPromptMessage(
                    content=message['content'],
                    tool_calls=message.get('tool_calls', [])
                ))
            elif role == PromptMessageRole.SYSTEM.value:
                converted_prompt_messages.append(SystemPromptMessage(content=message['content']))
            elif role == PromptMessageRole.TOOL.value:
                converted_prompt_messages.append(ToolPromptMessage(
                    content=message['content'],
                    tool_call_id=message['tool_call_id']
                ))
            else:
                raise ValueError(f'Unknown role: {role}')
        # check if the model is available
        bundle = ProviderManager().get_provider_model_bundle(
            tenant_id=tenant_id,
            provider=provider,
            model_type=ModelType.LLM,
        )

        provider_model = bundle.configuration.get_provider_model(
            model_type=ModelType.LLM,
            model=model,
        )

        if not provider_model:
            raise ModelCurrentlyNotSupportError(f"Could not find model {model} in provider {provider}.")

        if provider_model.status == ModelStatus.NO_CONFIGURE:
            raise ProviderTokenNotInitError(f"Model {model} credentials are not initialized.")
        elif provider_model.status == ModelStatus.NO_PERMISSION:
            raise ModelCurrentlyNotSupportError(f"Model {model} in provider {provider} is currently not supported.")
        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
            raise QuotaExceededError(f"Model provider {provider} quota exceeded.")
        # convert tool dicts into typed tool definitions; tools may be None
        converted_tools = []
        for tool in tools or []:
            converted_tools.append(PromptMessageTool(
                name=tool['name'],
                description=tool['description'],
                parameters=tool['parameters']
            ))
        # invoke model
        return ModelRunner.run(
            provider_model_bundle=bundle,
            model=model,
            prompt_messages=converted_prompt_messages,
            model_parameters=completion_params,
            tools=converted_tools,
            stop=stop,
            stream=stream,
            user=user
        )
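
For reference, a sketch of the dict shapes invoke_model accepts, based on the conversion logic above; the concrete values and the JSON-schema-style tool parameters are illustrative assumptions.

# Illustrative inputs matching the conversion logic in CompletionService.invoke_model.
prompt_messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'What is the weather in Berlin?'},
    # assistant messages may include 'tool_calls'; tool messages require 'tool_call_id'
]

tools = [
    {
        'name': 'get_weather',
        'description': 'Look up the current weather for a city.',
        'parameters': {  # assumed JSON-schema-like shape
            'type': 'object',
            'properties': {'city': {'type': 'string'}},
            'required': ['city'],
        },
    },
]

response = CompletionService.invoke_model(
    tenant_id='tenant-123',      # placeholder
    provider='openai',           # placeholder
    model='gpt-3.5-turbo',       # placeholder
    completion_params={'temperature': 0.7},
    prompt_messages=prompt_messages,
    tools=tools,
    stop=None,
    stream=False,
    user='inner-api-caller',
)
# With stream=False the result is a JSON-encodable dict of the full LLMResult.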