Commit 44eefb20 authored by John Wang

feat: add function call feature & embedding batch size of azure models

parent a03817be
...@@ -9,7 +9,7 @@ from core.llm.provider.errors import ValidateFailedError ...@@ -9,7 +9,7 @@ from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName from models.provider import ProviderName
AZURE_OPENAI_API_VERSION = '2023-06-01-preview' AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
class AzureProvider(BaseProvider): class AzureProvider(BaseProvider):
...@@ -45,9 +45,10 @@ class AzureProvider(BaseProvider): ...@@ -45,9 +45,10 @@ class AzureProvider(BaseProvider):
""" """
config = self.get_provider_api_key(model_id=model_id) config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure' config['openai_api_type'] = 'azure'
config['openai_api_version'] = AZURE_OPENAI_API_VERSION
if model_id == 'text-embedding-ada-002': if model_id == 'text-embedding-ada-002':
config['deployment'] = model_id.replace('.', '') if model_id else None config['deployment'] = model_id.replace('.', '') if model_id else None
config['chunk_size'] = 1 config['chunk_size'] = 16
else: else:
config['deployment_name'] = model_id.replace('.', '') if model_id else None config['deployment_name'] = model_id.replace('.', '') if model_id else None
return config return config
......
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun
from langchain.schema import BaseMessage, LLMResult from langchain.chat_models.openai import _convert_dict_to_message
from langchain.schema import BaseMessage, LLMResult, ChatResult, ChatGeneration
from langchain.chat_models import AzureChatOpenAI from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
...@@ -71,3 +72,43 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI): ...@@ -71,3 +72,43 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
params['model_kwargs'] = model_kwargs params['model_kwargs'] = model_kwargs
return params return params
def _generate(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> ChatResult:
    """Run one chat completion and wrap the reply in a ``ChatResult``.

    When ``self.streaming`` is set, the response is consumed chunk by chunk:
    content tokens are forwarded to ``run_manager`` as they arrive and any
    ``function_call`` fragments are stitched back together into one dict.
    Otherwise a single blocking request is made and delegated to
    ``self._create_chat_result``.

    Args:
        messages: Conversation to send; converted to request dicts via
            ``self._create_message_dicts``.
        stop: Optional stop sequences, passed to ``_create_message_dicts``.
        run_manager: Optional callback manager; receives
            ``on_llm_new_token`` for every processed streaming chunk.
        **kwargs: Extra request parameters; these override the values
            produced by ``_create_message_dicts`` on key collision.

    Returns:
        A ``ChatResult`` holding a single generation in streaming mode, or
        whatever ``self._create_chat_result`` builds in non-streaming mode.
    """
    message_dicts, params = self._create_message_dicts(messages, stop)
    params = {**params, **kwargs}

    if not self.streaming:
        response = self.completion_with_retry(messages=message_dicts, **params)
        return self._create_chat_result(response)

    params["stream"] = True
    role = "assistant"
    content_parts: List[str] = []
    function_call: Optional[dict] = None

    for stream_resp in self.completion_with_retry(
        messages=message_dicts, **params
    ):
        choices = stream_resp["choices"]
        # Chunks without any choice carry nothing to accumulate, so skip
        # them entirely (no token callback either).
        # NOTE(review): presumably Azure emits such chunks at stream start —
        # confirm against the service behaviour.
        if len(choices) == 0:
            continue

        delta = choices[0]["delta"]
        role = delta.get("role", role)
        token = delta.get("content") or ""
        content_parts.append(token)

        call_delta = delta.get("function_call")
        if call_delta:
            if function_call is None:
                # Copy so that later concatenation does not mutate the
                # response object yielded by the stream.
                function_call = dict(call_delta)
            else:
                # Either side may lack "arguments" (or hold None); treat
                # that as an empty fragment instead of raising.
                function_call["arguments"] = (
                    (function_call.get("arguments") or "")
                    + (call_delta.get("arguments") or "")
                )

        if run_manager:
            run_manager.on_llm_new_token(token)

    message = _convert_dict_to_message(
        {
            "content": "".join(content_parts),
            "role": role,
            "function_call": function_call,
        }
    )
    return ChatResult(generations=[ChatGeneration(message=message)])
...@@ -3,6 +3,7 @@ from typing import Optional ...@@ -3,6 +3,7 @@ from typing import Optional
from langchain import WikipediaAPIWrapper from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatOpenAI
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
...@@ -15,7 +16,6 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan ...@@ -15,7 +16,6 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.conversation_message_task import ConversationMessageTask from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
...@@ -64,8 +64,8 @@ class OrchestratorRuleParser: ...@@ -64,8 +64,8 @@ class OrchestratorRuleParser:
planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router')) planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))
# only OpenAI chat model support function call, use ReACT instead # only OpenAI chat model (include Azure) support function call, use ReACT instead
if not isinstance(agent_llm, StreamableChatOpenAI) \ if not isinstance(agent_llm, ChatOpenAI) \
and planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]: and planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
planning_strategy = PlanningStrategy.REACT planning_strategy = PlanningStrategy.REACT
......
...@@ -10,7 +10,7 @@ flask-session2==1.3.1 ...@@ -10,7 +10,7 @@ flask-session2==1.3.1
flask-cors==3.0.10 flask-cors==3.0.10
gunicorn~=20.1.0 gunicorn~=20.1.0
gevent~=22.10.2 gevent~=22.10.2
langchain==0.0.230 langchain==0.0.239
openai~=0.27.8 openai~=0.27.8
psycopg2-binary~=2.9.6 psycopg2-binary~=2.9.6
pycryptodome==3.17 pycryptodome==3.17
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment