Commit 707b5b72 authored by John Wang

feat: complete universal chat completion api

parent aa13b8db
...@@ -16,7 +16,7 @@ def universal_chat_app_required(view=None): ...@@ -16,7 +16,7 @@ def universal_chat_app_required(view=None):
# get universal chat app # get universal chat app
universal_app = db.session.query(App).filter( universal_app = db.session.query(App).filter(
App.tenant_id == current_user.current_tenant_id, App.tenant_id == current_user.current_tenant_id,
App.is_universal is True App.is_universal == True
).first() ).first()
if universal_app is None: if universal_app is None:
...@@ -28,9 +28,10 @@ def universal_chat_app_required(view=None): ...@@ -28,9 +28,10 @@ def universal_chat_app_required(view=None):
is_universal=True, is_universal=True,
icon='', icon='',
icon_background='', icon_background='',
description='Universal Chat',
api_rpm=0, api_rpm=0,
api_rph=0, api_rph=0,
enable_site=False,
enable_api=False,
status='normal' status='normal'
) )
...@@ -60,13 +61,16 @@ def universal_chat_app_required(view=None): ...@@ -60,13 +61,16 @@ def universal_chat_app_required(view=None):
}), }),
user_input_form=json.dumps([]), user_input_form=json.dumps([]),
pre_prompt=None, pre_prompt=None,
agent_mode=json.dumps({"enabled": True, "strategy": None, "tools": []}), agent_mode=json.dumps({"enabled": True, "strategy": "function_call", "tools": []}),
) )
app_model_config.app_id = universal_app.id app_model_config.app_id = universal_app.id
db.session.add(app_model_config) db.session.add(app_model_config)
db.session.flush() db.session.flush()
universal_app.app_model_config_id = app_model_config.id
db.session.commit()
return view(universal_app, *args, **kwargs) return view(universal_app, *args, **kwargs)
return decorated return decorated
......
...@@ -5,6 +5,8 @@ from langchain.base_language import BaseLanguageModel ...@@ -5,6 +5,8 @@ from langchain.base_language import BaseLanguageModel
from langchain.chat_models.openai import ChatOpenAI from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import BaseMessage from langchain.schema import BaseMessage
from core.constant import llm_constant
class CalcTokenMixin: class CalcTokenMixin:
...@@ -21,7 +23,7 @@ class CalcTokenMixin: ...@@ -21,7 +23,7 @@ class CalcTokenMixin:
:return: :return:
""" """
llm = cast(ChatOpenAI, llm) llm = cast(ChatOpenAI, llm)
llm_max_tokens = OpenAI.modelname_to_contextsize(llm.model_name) llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
completion_max_tokens = llm.max_tokens completion_max_tokens = llm.max_tokens
used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs) used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
......
...@@ -73,7 +73,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): ...@@ -73,7 +73,6 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
), ),
**kwargs: Any, **kwargs: Any,
) -> BaseSingleActionAgent: ) -> BaseSingleActionAgent:
tools = [t for t in tools if isinstance(t, DatasetRetrieverTool)]
llm.model_name = 'gpt-3.5-turbo' llm.model_name = 'gpt-3.5-turbo'
return super().from_llm_and_tools( return super().from_llm_and_tools(
llm=llm, llm=llm,
......
import json
import re
from typing import Union
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser as LCStructuredChatOutputParser, \
logger
from langchain.schema import AgentAction, AgentFinish, OutputParserException
class StructuredChatOutputParser(LCStructuredChatOutputParser):
    """Structured-chat output parser whose fenced action block may carry a ``json`` tag."""

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Turn raw LLM output into an agent action or a final answer.

        The first fenced block in ``text`` is decoded as JSON. The opening
        fence may be followed by ``json``, and the third backtick of the
        closing fence is optional. Output without any fenced block is
        returned unchanged as the final answer.

        :param text: raw completion text produced by the LLM
        :return: ``AgentFinish`` when the action is "Final Answer" or no
            fenced block is present, otherwise an ``AgentAction``
        :raises OutputParserException: on any failure while parsing
        """
        try:
            fenced = re.search(r"```(json)?(.*?)```?", text, re.DOTALL)
            if fenced is None:
                # No action block at all: the whole text is the answer.
                return AgentFinish({"output": text}, text)

            # strict=False lets control characters (e.g. raw newlines)
            # appear inside JSON strings.
            payload = json.loads(fenced.group(2).strip(), strict=False)
            if isinstance(payload, list):
                # gpt turbo frequently ignores the directive to emit a single action
                logger.warning("Got multiple action responses: %s", payload)
                payload = payload[0]

            action_name = payload["action"]
            if action_name == "Final Answer":
                return AgentFinish({"output": payload["action_input"]}, text)
            return AgentAction(action_name, payload.get("action_input", {}), text)
        except Exception as e:
            raise OutputParserException(f"Could not parse LLM output: {text}") from e
...@@ -11,9 +11,12 @@ from pydantic import BaseModel, Extra ...@@ -11,9 +11,12 @@ from pydantic import BaseModel, Extra
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor from langchain.agents import AgentExecutor as LCAgentExecutor
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum): class PlanningStrategy(str, enum.Enum):
ROUTER = 'router' ROUTER = 'router'
...@@ -44,6 +47,7 @@ class AgentConfiguration(BaseModel): ...@@ -44,6 +47,7 @@ class AgentConfiguration(BaseModel):
class AgentExecuteResult(BaseModel): class AgentExecuteResult(BaseModel):
strategy: PlanningStrategy strategy: PlanningStrategy
output: str output: str
configuration: AgentConfiguration
class AgentExecutor: class AgentExecutor:
...@@ -56,6 +60,7 @@ class AgentExecutor: ...@@ -56,6 +60,7 @@ class AgentExecutor:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools( agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
llm=self.configuration.llm, llm=self.configuration.llm,
tools=self.configuration.tools, tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_llm, summary_llm=self.configuration.summary_llm,
verbose=True verbose=True
) )
...@@ -76,6 +81,7 @@ class AgentExecutor: ...@@ -76,6 +81,7 @@ class AgentExecutor:
verbose=True verbose=True
) )
elif self.configuration.strategy == PlanningStrategy.ROUTER: elif self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools( agent = MultiDatasetRouterAgent.from_llm_and_tools(
llm=self.configuration.llm, llm=self.configuration.llm,
tools=self.configuration.tools, tools=self.configuration.tools,
...@@ -105,5 +111,6 @@ class AgentExecutor: ...@@ -105,5 +111,6 @@ class AgentExecutor:
return AgentExecuteResult( return AgentExecuteResult(
output=output, output=output,
strategy=self.configuration.strategy strategy=self.configuration.strategy,
configuration=self.configuration
) )
...@@ -65,7 +65,8 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -65,7 +65,8 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
# kwargs={} # kwargs={}
if self._current_loop and self._current_loop.status == 'llm_started': if self._current_loop and self._current_loop.status == 'llm_started':
self._current_loop.status = 'llm_end' self._current_loop.status = 'llm_end'
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] if response.llm_output:
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
completion_generation = response.generations[0][0] completion_generation = response.generations[0][0]
if isinstance(completion_generation, ChatGeneration): if isinstance(completion_generation, ChatGeneration):
completion_message = completion_generation.message completion_message = completion_generation.message
...@@ -77,7 +78,8 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): ...@@ -77,7 +78,8 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
else: else:
self._current_loop.completion = completion_generation.text self._current_loop.completion = completion_generation.text
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens'] if response.llm_output:
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
def on_llm_error( def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
......
...@@ -10,9 +10,9 @@ class AgentLoop(BaseModel): ...@@ -10,9 +10,9 @@ class AgentLoop(BaseModel):
tool_output: str = None tool_output: str = None
prompt: str = None prompt: str = None
prompt_tokens: int = None prompt_tokens: int = 0
completion: str = None completion: str = None
completion_tokens: int = None completion_tokens: int = 0
latency: float = None latency: float = None
......
...@@ -67,9 +67,8 @@ class LLMCallbackHandler(BaseCallbackHandler): ...@@ -67,9 +67,8 @@ class LLMCallbackHandler(BaseCallbackHandler):
if not self.conversation_message_task.streaming: if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(response.generations[0][0].text) self.conversation_message_task.append_message_text(response.generations[0][0].text)
self.llm_message.completion = response.generations[0][0].text self.llm_message.completion = response.generations[0][0].text
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else: self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
self.conversation_message_task.save_message(self.llm_message) self.conversation_message_task.save_message(self.llm_message)
......
...@@ -128,7 +128,9 @@ class Completion: ...@@ -128,7 +128,9 @@ class Completion:
# the output of the agent can be used directly as the main output content without calling LLM again # the output of the agent can be used directly as the main output content without calling LLM again
if not app_model_config.pre_prompt and agent_execute_result \ if not app_model_config.pre_prompt and agent_execute_result \
and agent_execute_result.strategy != PlanningStrategy.ROUTER: and agent_execute_result.strategy != PlanningStrategy.ROUTER:
final_llm = FakeLLM(response=agent_execute_result.output, streaming=streaming) final_llm = FakeLLM(response=agent_execute_result.output,
origin_llm=agent_execute_result.configuration.llm,
streaming=streaming)
final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task) final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
response = final_llm.generate([[HumanMessage(content=query)]]) response = final_llm.generate([[HumanMessage(content=query)]])
return response return response
......
...@@ -3,7 +3,7 @@ from typing import List, Optional, Any, Mapping ...@@ -3,7 +3,7 @@ from typing import List, Optional, Any, Mapping
from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration, BaseLanguageModel
class FakeLLM(SimpleChatModel): class FakeLLM(SimpleChatModel):
...@@ -12,6 +12,7 @@ class FakeLLM(SimpleChatModel): ...@@ -12,6 +12,7 @@ class FakeLLM(SimpleChatModel):
streaming: bool = False streaming: bool = False
"""Whether to stream the results or not.""" """Whether to stream the results or not."""
response: str response: str
origin_llm: Optional[BaseLanguageModel] = None
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
...@@ -32,10 +33,7 @@ class FakeLLM(SimpleChatModel): ...@@ -32,10 +33,7 @@ class FakeLLM(SimpleChatModel):
return {"response": self.response} return {"response": self.response}
def get_num_tokens(self, text: str) -> int: def get_num_tokens(self, text: str) -> int:
return 0 return self.origin_llm.get_num_tokens(text) if self.origin_llm else 0
def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
return 0
def _generate( def _generate(
self, self,
......
...@@ -3,6 +3,7 @@ from typing import List, Optional, Any, Dict ...@@ -3,6 +3,7 @@ from typing import List, Optional, Any, Dict
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult from langchain.schema import BaseMessage, LLMResult
from pydantic import root_validator
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
...@@ -12,6 +13,12 @@ class StreamableChatAnthropic(ChatAnthropic): ...@@ -12,6 +13,12 @@ class StreamableChatAnthropic(ChatAnthropic):
Wrapper around Anthropic's large language model. Wrapper around Anthropic's large language model.
""" """
@root_validator()
def prepare_params(cls, values: Dict) -> Dict:
    """Copy the Anthropic-specific settings under generic key names.

    Stores the value of ``model`` under ``model_name`` and the value of
    ``max_tokens_to_sample`` under ``max_tokens``; a missing source key
    yields ``None``.
    NOTE(review): presumably this lets code that reads ``model_name`` /
    ``max_tokens`` work with this wrapper -- confirm against callers.

    :param values: field values handed to the validator
    :return: the same mapping with the two extra keys set
    """
    values.update(
        model_name=values.get('model'),
        max_tokens=values.get('max_tokens_to_sample'),
    )
    return values
@handle_anthropic_exceptions @handle_anthropic_exceptions
def generate( def generate(
self, self,
......
...@@ -71,7 +71,7 @@ class DatasetRetrieverTool(BaseTool): ...@@ -71,7 +71,7 @@ class DatasetRetrieverTool(BaseTool):
else: else:
model_credentials = LLMBuilder.get_model_credentials( model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id), model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002' model_name='text-embedding-ada-002'
) )
......
...@@ -119,7 +119,10 @@ def get_url(url: str) -> str: ...@@ -119,7 +119,10 @@ def get_url(url: str) -> str:
head_response = requests.head(url, headers=headers, allow_redirects=True) head_response = requests.head(url, headers=headers, allow_redirects=True)
# 检查响应的Content-Type头部是否在支持的类型范围内 if head_response.status_code != 200:
return "URL returned status code {}.".format(head_response.status_code)
# check content-type
main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip() main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip()
if main_content_type not in supported_content_types: if main_content_type not in supported_content_types:
return "Unsupported content-type [{}] of URL.".format(main_content_type) return "Unsupported content-type [{}] of URL.".format(main_content_type)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment