Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
44eefb20
Commit
44eefb20
authored
Jul 23, 2023
by
John Wang
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: add function call feature & embedding batch size of azure models
parent
a03817be
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
50 additions
and
8 deletions
+50
-8
azure_provider.py
api/core/llm/provider/azure_provider.py
+3
-2
streamable_azure_chat_open_ai.py
api/core/llm/streamable_azure_chat_open_ai.py
+43
-2
orchestrator_rule_parser.py
api/core/orchestrator_rule_parser.py
+3
-3
requirements.txt
api/requirements.txt
+1
-1
No files found.
api/core/llm/provider/azure_provider.py
View file @
44eefb20
...
@@ -9,7 +9,7 @@ from core.llm.provider.errors import ValidateFailedError
...
@@ -9,7 +9,7 @@ from core.llm.provider.errors import ValidateFailedError
from
models.provider
import
ProviderName
from
models.provider
import
ProviderName
AZURE_OPENAI_API_VERSION
=
'2023-0
6
-01-preview'
AZURE_OPENAI_API_VERSION
=
'2023-0
7
-01-preview'
class
AzureProvider
(
BaseProvider
):
class
AzureProvider
(
BaseProvider
):
...
@@ -45,9 +45,10 @@ class AzureProvider(BaseProvider):
...
@@ -45,9 +45,10 @@ class AzureProvider(BaseProvider):
"""
"""
config
=
self
.
get_provider_api_key
(
model_id
=
model_id
)
config
=
self
.
get_provider_api_key
(
model_id
=
model_id
)
config
[
'openai_api_type'
]
=
'azure'
config
[
'openai_api_type'
]
=
'azure'
config
[
'openai_api_version'
]
=
AZURE_OPENAI_API_VERSION
if
model_id
==
'text-embedding-ada-002'
:
if
model_id
==
'text-embedding-ada-002'
:
config
[
'deployment'
]
=
model_id
.
replace
(
'.'
,
''
)
if
model_id
else
None
config
[
'deployment'
]
=
model_id
.
replace
(
'.'
,
''
)
if
model_id
else
None
config
[
'chunk_size'
]
=
1
config
[
'chunk_size'
]
=
1
6
else
:
else
:
config
[
'deployment_name'
]
=
model_id
.
replace
(
'.'
,
''
)
if
model_id
else
None
config
[
'deployment_name'
]
=
model_id
.
replace
(
'.'
,
''
)
if
model_id
else
None
return
config
return
config
...
...
api/core/llm/streamable_azure_chat_open_ai.py
View file @
44eefb20
from
langchain.callbacks.manager
import
Callbacks
from
langchain.callbacks.manager
import
Callbacks
,
CallbackManagerForLLMRun
from
langchain.schema
import
BaseMessage
,
LLMResult
from
langchain.chat_models.openai
import
_convert_dict_to_message
from
langchain.schema
import
BaseMessage
,
LLMResult
,
ChatResult
,
ChatGeneration
from
langchain.chat_models
import
AzureChatOpenAI
from
langchain.chat_models
import
AzureChatOpenAI
from
typing
import
Optional
,
List
,
Dict
,
Any
from
typing
import
Optional
,
List
,
Dict
,
Any
...
@@ -71,3 +72,43 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
...
@@ -71,3 +72,43 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
params
[
'model_kwargs'
]
=
model_kwargs
params
[
'model_kwargs'
]
=
model_kwargs
return
params
return
params
def
_generate
(
self
,
messages
:
List
[
BaseMessage
],
stop
:
Optional
[
List
[
str
]]
=
None
,
run_manager
:
Optional
[
CallbackManagerForLLMRun
]
=
None
,
**
kwargs
:
Any
,
)
->
ChatResult
:
message_dicts
,
params
=
self
.
_create_message_dicts
(
messages
,
stop
)
params
=
{
**
params
,
**
kwargs
}
if
self
.
streaming
:
inner_completion
=
""
role
=
"assistant"
params
[
"stream"
]
=
True
function_call
:
Optional
[
dict
]
=
None
for
stream_resp
in
self
.
completion_with_retry
(
messages
=
message_dicts
,
**
params
):
if
len
(
stream_resp
[
"choices"
])
>
0
:
role
=
stream_resp
[
"choices"
][
0
][
"delta"
]
.
get
(
"role"
,
role
)
token
=
stream_resp
[
"choices"
][
0
][
"delta"
]
.
get
(
"content"
)
or
""
inner_completion
+=
token
_function_call
=
stream_resp
[
"choices"
][
0
][
"delta"
]
.
get
(
"function_call"
)
if
_function_call
:
if
function_call
is
None
:
function_call
=
_function_call
else
:
function_call
[
"arguments"
]
+=
_function_call
[
"arguments"
]
if
run_manager
:
run_manager
.
on_llm_new_token
(
token
)
message
=
_convert_dict_to_message
(
{
"content"
:
inner_completion
,
"role"
:
role
,
"function_call"
:
function_call
,
}
)
return
ChatResult
(
generations
=
[
ChatGeneration
(
message
=
message
)])
response
=
self
.
completion_with_retry
(
messages
=
message_dicts
,
**
params
)
return
self
.
_create_chat_result
(
response
)
api/core/orchestrator_rule_parser.py
View file @
44eefb20
...
@@ -3,6 +3,7 @@ from typing import Optional
...
@@ -3,6 +3,7 @@ from typing import Optional
from
langchain
import
WikipediaAPIWrapper
from
langchain
import
WikipediaAPIWrapper
from
langchain.callbacks.manager
import
Callbacks
from
langchain.callbacks.manager
import
Callbacks
from
langchain.chat_models
import
ChatOpenAI
from
langchain.memory.chat_memory
import
BaseChatMemory
from
langchain.memory.chat_memory
import
BaseChatMemory
from
langchain.tools
import
BaseTool
,
Tool
,
WikipediaQueryRun
from
langchain.tools
import
BaseTool
,
Tool
,
WikipediaQueryRun
from
pydantic
import
BaseModel
,
Field
from
pydantic
import
BaseModel
,
Field
...
@@ -15,7 +16,6 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
...
@@ -15,7 +16,6 @@ from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHan
from
core.chain.sensitive_word_avoidance_chain
import
SensitiveWordAvoidanceChain
from
core.chain.sensitive_word_avoidance_chain
import
SensitiveWordAvoidanceChain
from
core.conversation_message_task
import
ConversationMessageTask
from
core.conversation_message_task
import
ConversationMessageTask
from
core.llm.llm_builder
import
LLMBuilder
from
core.llm.llm_builder
import
LLMBuilder
from
core.llm.streamable_chat_open_ai
import
StreamableChatOpenAI
from
core.tool.dataset_retriever_tool
import
DatasetRetrieverTool
from
core.tool.dataset_retriever_tool
import
DatasetRetrieverTool
from
core.tool.provider.serpapi_provider
import
SerpAPIToolProvider
from
core.tool.provider.serpapi_provider
import
SerpAPIToolProvider
from
core.tool.serpapi_wrapper
import
OptimizedSerpAPIWrapper
,
OptimizedSerpAPIInput
from
core.tool.serpapi_wrapper
import
OptimizedSerpAPIWrapper
,
OptimizedSerpAPIInput
...
@@ -64,8 +64,8 @@ class OrchestratorRuleParser:
...
@@ -64,8 +64,8 @@ class OrchestratorRuleParser:
planning_strategy
=
PlanningStrategy
(
agent_mode_config
.
get
(
'strategy'
,
'router'
))
planning_strategy
=
PlanningStrategy
(
agent_mode_config
.
get
(
'strategy'
,
'router'
))
# only OpenAI chat model support function call, use ReACT instead
# only OpenAI chat model
(include Azure)
support function call, use ReACT instead
if
not
isinstance
(
agent_llm
,
Streamable
ChatOpenAI
)
\
if
not
isinstance
(
agent_llm
,
ChatOpenAI
)
\
and
planning_strategy
in
[
PlanningStrategy
.
FUNCTION_CALL
,
PlanningStrategy
.
MULTI_FUNCTION_CALL
]:
and
planning_strategy
in
[
PlanningStrategy
.
FUNCTION_CALL
,
PlanningStrategy
.
MULTI_FUNCTION_CALL
]:
planning_strategy
=
PlanningStrategy
.
REACT
planning_strategy
=
PlanningStrategy
.
REACT
...
...
api/requirements.txt
View file @
44eefb20
...
@@ -10,7 +10,7 @@ flask-session2==1.3.1
...
@@ -10,7 +10,7 @@ flask-session2==1.3.1
flask-cors==3.0.10
flask-cors==3.0.10
gunicorn~=20.1.0
gunicorn~=20.1.0
gevent~=22.10.2
gevent~=22.10.2
langchain==0.0.23
0
langchain==0.0.23
9
openai~=0.27.8
openai~=0.27.8
psycopg2-binary~=2.9.6
psycopg2-binary~=2.9.6
pycryptodome==3.17
pycryptodome==3.17
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment