Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
D
dify
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ai-tech
dify
Commits
f9082104
Unverified
Commit
f9082104
authored
Sep 12, 2023
by
takatost
Committed by
GitHub
Sep 12, 2023
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: add hosted moderation (#1158)
parent
983834cd
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
241 additions
and
70 deletions
+241
-70
config.py
api/config.py
+5
-0
agent_executor.py
api/core/agent/agent_executor.py
+17
-1
agent_loop_gather_callback_handler.py
...re/callback_handler/agent_loop_gather_callback_handler.py
+22
-7
llm_message.py
api/core/callback_handler/entity/llm_message.py
+0
-1
llm_callback_handler.py
api/core/callback_handler/llm_callback_handler.py
+0
-9
sensitive_word_avoidance_chain.py
api/core/chain/sensitive_word_avoidance_chain.py
+52
-8
completion.py
api/core/completion.py
+3
-6
conversation_message_task.py
api/core/conversation_message_task.py
+13
-10
moderation.py
api/core/helper/moderation.py
+32
-0
model_factory.py
api/core/model_providers/model_factory.py
+2
-1
base.py
api/core/model_providers/models/llm/base.py
+10
-0
base.py
api/core/model_providers/models/moderation/base.py
+29
-0
openai_moderation.py
...re/model_providers/models/moderation/openai_moderation.py
+18
-12
orchestrator_rule_parser.py
api/core/orchestrator_rule_parser.py
+35
-11
test_openai_moderation.py
...gration_tests/models/moderation/test_openai_moderation.py
+3
-4
No files found.
api/config.py
View file @
f9082104
...
@@ -61,6 +61,8 @@ DEFAULTS = {
...
@@ -61,6 +61,8 @@ DEFAULTS = {
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'
:
1000000
,
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'
:
1000000
,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'
:
20
,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'
:
20
,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'
:
100
,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'
:
100
,
'HOSTED_MODERATION_ENABLED'
:
'False'
,
'HOSTED_MODERATION_PROVIDERS'
:
''
,
'TENANT_DOCUMENT_COUNT'
:
100
,
'TENANT_DOCUMENT_COUNT'
:
100
,
'CLEAN_DAY_SETTING'
:
30
,
'CLEAN_DAY_SETTING'
:
30
,
'UPLOAD_FILE_SIZE_LIMIT'
:
15
,
'UPLOAD_FILE_SIZE_LIMIT'
:
15
,
...
@@ -230,6 +232,9 @@ class Config:
...
@@ -230,6 +232,9 @@ class Config:
self
.
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY
=
int
(
get_env
(
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'
))
self
.
HOSTED_ANTHROPIC_PAID_MIN_QUANTITY
=
int
(
get_env
(
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'
))
self
.
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY
=
int
(
get_env
(
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'
))
self
.
HOSTED_ANTHROPIC_PAID_MAX_QUANTITY
=
int
(
get_env
(
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'
))
self
.
HOSTED_MODERATION_ENABLED
=
get_bool_env
(
'HOSTED_MODERATION_ENABLED'
)
self
.
HOSTED_MODERATION_PROVIDERS
=
get_env
(
'HOSTED_MODERATION_PROVIDERS'
)
self
.
STRIPE_API_KEY
=
get_env
(
'STRIPE_API_KEY'
)
self
.
STRIPE_API_KEY
=
get_env
(
'STRIPE_API_KEY'
)
self
.
STRIPE_WEBHOOK_SECRET
=
get_env
(
'STRIPE_WEBHOOK_SECRET'
)
self
.
STRIPE_WEBHOOK_SECRET
=
get_env
(
'STRIPE_WEBHOOK_SECRET'
)
...
...
api/core/agent/agent_executor.py
View file @
f9082104
...
@@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
...
@@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
from
core.agent.agent.structured_chat
import
AutoSummarizingStructuredChatAgent
from
core.agent.agent.structured_chat
import
AutoSummarizingStructuredChatAgent
from
langchain.agents
import
AgentExecutor
as
LCAgentExecutor
from
langchain.agents
import
AgentExecutor
as
LCAgentExecutor
from
core.helper
import
moderation
from
core.model_providers.error
import
LLMError
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.tool.dataset_retriever_tool
import
DatasetRetrieverTool
from
core.tool.dataset_retriever_tool
import
DatasetRetrieverTool
...
@@ -116,6 +118,18 @@ class AgentExecutor:
...
@@ -116,6 +118,18 @@ class AgentExecutor:
return
self
.
agent
.
should_use_agent
(
query
)
return
self
.
agent
.
should_use_agent
(
query
)
def
run
(
self
,
query
:
str
)
->
AgentExecuteResult
:
def
run
(
self
,
query
:
str
)
->
AgentExecuteResult
:
moderation_result
=
moderation
.
check_moderation
(
self
.
configuration
.
model_instance
.
model_provider
,
query
)
if
not
moderation_result
:
return
AgentExecuteResult
(
output
=
"I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest."
,
strategy
=
self
.
configuration
.
strategy
,
configuration
=
self
.
configuration
)
agent_executor
=
LCAgentExecutor
.
from_agent_and_tools
(
agent_executor
=
LCAgentExecutor
.
from_agent_and_tools
(
agent
=
self
.
agent
,
agent
=
self
.
agent
,
tools
=
self
.
configuration
.
tools
,
tools
=
self
.
configuration
.
tools
,
...
@@ -128,7 +142,9 @@ class AgentExecutor:
...
@@ -128,7 +142,9 @@ class AgentExecutor:
try
:
try
:
output
=
agent_executor
.
run
(
query
)
output
=
agent_executor
.
run
(
query
)
except
Exception
:
except
LLMError
as
ex
:
raise
ex
except
Exception
as
ex
:
logging
.
exception
(
"agent_executor run failed"
)
logging
.
exception
(
"agent_executor run failed"
)
output
=
None
output
=
None
...
...
api/core/callback_handler/agent_loop_gather_callback_handler.py
View file @
f9082104
...
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional
...
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional
from
langchain.agents
import
openai_functions_agent
,
openai_functions_multi_agent
from
langchain.agents
import
openai_functions_agent
,
openai_functions_multi_agent
from
langchain.callbacks.base
import
BaseCallbackHandler
from
langchain.callbacks.base
import
BaseCallbackHandler
from
langchain.schema
import
AgentAction
,
AgentFinish
,
LLMResult
,
ChatGeneration
from
langchain.schema
import
AgentAction
,
AgentFinish
,
LLMResult
,
ChatGeneration
,
BaseMessage
from
core.callback_handler.entity.agent_loop
import
AgentLoop
from
core.callback_handler.entity.agent_loop
import
AgentLoop
from
core.conversation_message_task
import
ConversationMessageTask
from
core.conversation_message_task
import
ConversationMessageTask
...
@@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
...
@@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
"""Callback Handler that prints to std out."""
raise_error
:
bool
=
True
raise_error
:
bool
=
True
def
__init__
(
self
,
model_instan
t
:
BaseLLM
,
conversation_message_task
:
ConversationMessageTask
)
->
None
:
def
__init__
(
self
,
model_instan
ce
:
BaseLLM
,
conversation_message_task
:
ConversationMessageTask
)
->
None
:
"""Initialize callback handler."""
"""Initialize callback handler."""
self
.
model_instan
t
=
model_instant
self
.
model_instan
ce
=
model_instance
self
.
conversation_message_task
=
conversation_message_task
self
.
conversation_message_task
=
conversation_message_task
self
.
_agent_loops
=
[]
self
.
_agent_loops
=
[]
self
.
_current_loop
=
None
self
.
_current_loop
=
None
...
@@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
...
@@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Whether to ignore chain callbacks."""
"""Whether to ignore chain callbacks."""
return
True
return
True
def
on_chat_model_start
(
self
,
serialized
:
Dict
[
str
,
Any
],
messages
:
List
[
List
[
BaseMessage
]],
**
kwargs
:
Any
)
->
Any
:
if
not
self
.
_current_loop
:
# Agent start with a LLM query
self
.
_current_loop
=
AgentLoop
(
position
=
len
(
self
.
_agent_loops
)
+
1
,
prompt
=
"
\n
"
.
join
([
message
.
content
for
message
in
messages
[
0
]]),
status
=
'llm_started'
,
started_at
=
time
.
perf_counter
()
)
def
on_llm_start
(
def
on_llm_start
(
self
,
serialized
:
Dict
[
str
,
Any
],
prompts
:
List
[
str
],
**
kwargs
:
Any
self
,
serialized
:
Dict
[
str
,
Any
],
prompts
:
List
[
str
],
**
kwargs
:
Any
)
->
None
:
)
->
None
:
...
@@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
...
@@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if
response
.
llm_output
:
if
response
.
llm_output
:
self
.
_current_loop
.
prompt_tokens
=
response
.
llm_output
[
'token_usage'
][
'prompt_tokens'
]
self
.
_current_loop
.
prompt_tokens
=
response
.
llm_output
[
'token_usage'
][
'prompt_tokens'
]
else
:
else
:
self
.
_current_loop
.
prompt_tokens
=
self
.
model_instan
t
.
get_num_tokens
(
self
.
_current_loop
.
prompt_tokens
=
self
.
model_instan
ce
.
get_num_tokens
(
[
PromptMessage
(
content
=
self
.
_current_loop
.
prompt
)]
[
PromptMessage
(
content
=
self
.
_current_loop
.
prompt
)]
)
)
completion_generation
=
response
.
generations
[
0
][
0
]
completion_generation
=
response
.
generations
[
0
][
0
]
...
@@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
...
@@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if
response
.
llm_output
:
if
response
.
llm_output
:
self
.
_current_loop
.
completion_tokens
=
response
.
llm_output
[
'token_usage'
][
'completion_tokens'
]
self
.
_current_loop
.
completion_tokens
=
response
.
llm_output
[
'token_usage'
][
'completion_tokens'
]
else
:
else
:
self
.
_current_loop
.
completion_tokens
=
self
.
model_instan
t
.
get_num_tokens
(
self
.
_current_loop
.
completion_tokens
=
self
.
model_instan
ce
.
get_num_tokens
(
[
PromptMessage
(
content
=
self
.
_current_loop
.
completion
)]
[
PromptMessage
(
content
=
self
.
_current_loop
.
completion
)]
)
)
...
@@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
...
@@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self
.
_current_loop
.
latency
=
self
.
_current_loop
.
completed_at
-
self
.
_current_loop
.
started_at
self
.
_current_loop
.
latency
=
self
.
_current_loop
.
completed_at
-
self
.
_current_loop
.
started_at
self
.
conversation_message_task
.
on_agent_end
(
self
.
conversation_message_task
.
on_agent_end
(
self
.
_message_agent_thought
,
self
.
model_instan
t
,
self
.
_current_loop
self
.
_message_agent_thought
,
self
.
model_instan
ce
,
self
.
_current_loop
)
)
self
.
_agent_loops
.
append
(
self
.
_current_loop
)
self
.
_agent_loops
.
append
(
self
.
_current_loop
)
...
@@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
...
@@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
)
)
self
.
conversation_message_task
.
on_agent_end
(
self
.
conversation_message_task
.
on_agent_end
(
self
.
_message_agent_thought
,
self
.
model_instan
t
,
self
.
_current_loop
self
.
_message_agent_thought
,
self
.
model_instan
ce
,
self
.
_current_loop
)
)
self
.
_agent_loops
.
append
(
self
.
_current_loop
)
self
.
_agent_loops
.
append
(
self
.
_current_loop
)
...
...
api/core/callback_handler/entity/llm_message.py
View file @
f9082104
...
@@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
...
@@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
prompt_tokens
:
int
=
0
prompt_tokens
:
int
=
0
completion
:
str
=
''
completion
:
str
=
''
completion_tokens
:
int
=
0
completion_tokens
:
int
=
0
latency
:
float
=
0.0
api/core/callback_handler/llm_callback_handler.py
View file @
f9082104
import
logging
import
logging
import
time
from
typing
import
Any
,
Dict
,
List
,
Union
from
typing
import
Any
,
Dict
,
List
,
Union
from
langchain.callbacks.base
import
BaseCallbackHandler
from
langchain.callbacks.base
import
BaseCallbackHandler
...
@@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
...
@@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
messages
:
List
[
List
[
BaseMessage
]],
messages
:
List
[
List
[
BaseMessage
]],
**
kwargs
:
Any
**
kwargs
:
Any
)
->
Any
:
)
->
Any
:
self
.
start_at
=
time
.
perf_counter
()
real_prompts
=
[]
real_prompts
=
[]
for
message
in
messages
[
0
]:
for
message
in
messages
[
0
]:
if
message
.
type
==
'human'
:
if
message
.
type
==
'human'
:
...
@@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
...
@@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
def
on_llm_start
(
def
on_llm_start
(
self
,
serialized
:
Dict
[
str
,
Any
],
prompts
:
List
[
str
],
**
kwargs
:
Any
self
,
serialized
:
Dict
[
str
,
Any
],
prompts
:
List
[
str
],
**
kwargs
:
Any
)
->
None
:
)
->
None
:
self
.
start_at
=
time
.
perf_counter
()
self
.
llm_message
.
prompt
=
[{
self
.
llm_message
.
prompt
=
[{
"role"
:
'user'
,
"role"
:
'user'
,
"text"
:
prompts
[
0
]
"text"
:
prompts
[
0
]
...
@@ -63,9 +59,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
...
@@ -63,9 +59,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
self
.
llm_message
.
prompt_tokens
=
self
.
model_instance
.
get_num_tokens
([
PromptMessage
(
content
=
prompts
[
0
])])
self
.
llm_message
.
prompt_tokens
=
self
.
model_instance
.
get_num_tokens
([
PromptMessage
(
content
=
prompts
[
0
])])
def
on_llm_end
(
self
,
response
:
LLMResult
,
**
kwargs
:
Any
)
->
None
:
def
on_llm_end
(
self
,
response
:
LLMResult
,
**
kwargs
:
Any
)
->
None
:
end_at
=
time
.
perf_counter
()
self
.
llm_message
.
latency
=
end_at
-
self
.
start_at
if
not
self
.
conversation_message_task
.
streaming
:
if
not
self
.
conversation_message_task
.
streaming
:
self
.
conversation_message_task
.
append_message_text
(
response
.
generations
[
0
][
0
]
.
text
)
self
.
conversation_message_task
.
append_message_text
(
response
.
generations
[
0
][
0
]
.
text
)
self
.
llm_message
.
completion
=
response
.
generations
[
0
][
0
]
.
text
self
.
llm_message
.
completion
=
response
.
generations
[
0
][
0
]
.
text
...
@@ -89,8 +82,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
...
@@ -89,8 +82,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
"""Do nothing."""
"""Do nothing."""
if
isinstance
(
error
,
ConversationTaskStoppedException
):
if
isinstance
(
error
,
ConversationTaskStoppedException
):
if
self
.
conversation_message_task
.
streaming
:
if
self
.
conversation_message_task
.
streaming
:
end_at
=
time
.
perf_counter
()
self
.
llm_message
.
latency
=
end_at
-
self
.
start_at
self
.
llm_message
.
completion_tokens
=
self
.
model_instance
.
get_num_tokens
(
self
.
llm_message
.
completion_tokens
=
self
.
model_instance
.
get_num_tokens
(
[
PromptMessage
(
content
=
self
.
llm_message
.
completion
)]
[
PromptMessage
(
content
=
self
.
llm_message
.
completion
)]
)
)
...
...
api/core/chain/sensitive_word_avoidance_chain.py
View file @
f9082104
import
enum
import
logging
from
typing
import
List
,
Dict
,
Optional
,
Any
from
typing
import
List
,
Dict
,
Optional
,
Any
import
openai
from
flask
import
current_app
from
langchain.callbacks.manager
import
CallbackManagerForChainRun
from
langchain.callbacks.manager
import
CallbackManagerForChainRun
from
langchain.chains.base
import
Chain
from
langchain.chains.base
import
Chain
from
openai
import
InvalidRequestError
from
openai.error
import
APIConnectionError
,
APIError
,
ServiceUnavailableError
,
Timeout
,
RateLimitError
,
\
AuthenticationError
,
OpenAIError
from
pydantic
import
BaseModel
from
core.model_providers.error
import
LLMBadRequestError
from
core.model_providers.model_factory
import
ModelFactory
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.moderation
import
openai_moderation
class
SensitiveWordAvoidanceRule
(
BaseModel
):
class
Type
(
enum
.
Enum
):
MODERATION
=
"moderation"
KEYWORDS
=
"keywords"
type
:
Type
canned_response
:
str
=
'Your content violates our usage policy. Please revise and try again.'
extra_params
:
dict
=
{}
class
SensitiveWordAvoidanceChain
(
Chain
):
class
SensitiveWordAvoidanceChain
(
Chain
):
input_key
:
str
=
"input"
#: :meta private:
input_key
:
str
=
"input"
#: :meta private:
output_key
:
str
=
"output"
#: :meta private:
output_key
:
str
=
"output"
#: :meta private:
sensitive_words
:
List
[
str
]
=
[]
model_instance
:
BaseLLM
canned_response
:
str
=
Non
e
sensitive_word_avoidance_rule
:
SensitiveWordAvoidanceRul
e
@
property
@
property
def
_chain_type
(
self
)
->
str
:
def
_chain_type
(
self
)
->
str
:
...
@@ -31,11 +54,24 @@ class SensitiveWordAvoidanceChain(Chain):
...
@@ -31,11 +54,24 @@ class SensitiveWordAvoidanceChain(Chain):
"""
"""
return
[
self
.
output_key
]
return
[
self
.
output_key
]
def
_check_sensitive_word
(
self
,
text
:
str
)
->
str
:
def
_check_sensitive_word
(
self
,
text
:
str
)
->
bool
:
for
word
in
self
.
sensitive_word
s
:
for
word
in
self
.
sensitive_word
_avoidance_rule
.
extra_params
.
get
(
'sensitive_words'
,
[])
:
if
word
in
text
:
if
word
in
text
:
return
self
.
canned_response
return
False
return
text
return
True
def
_check_moderation
(
self
,
text
:
str
)
->
bool
:
moderation_model_instance
=
ModelFactory
.
get_moderation_model
(
tenant_id
=
self
.
model_instance
.
model_provider
.
provider
.
tenant_id
,
model_provider_name
=
'openai'
,
model_name
=
openai_moderation
.
DEFAULT_MODEL
)
try
:
return
moderation_model_instance
.
run
(
text
=
text
)
except
Exception
as
ex
:
logging
.
exception
(
ex
)
raise
LLMBadRequestError
(
'Rate limit exceeded, please try again later.'
)
def
_call
(
def
_call
(
self
,
self
,
...
@@ -43,5 +79,13 @@ class SensitiveWordAvoidanceChain(Chain):
...
@@ -43,5 +79,13 @@ class SensitiveWordAvoidanceChain(Chain):
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
run_manager
:
Optional
[
CallbackManagerForChainRun
]
=
None
,
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
text
=
inputs
[
self
.
input_key
]
text
=
inputs
[
self
.
input_key
]
output
=
self
.
_check_sensitive_word
(
text
)
return
{
self
.
output_key
:
output
}
if
self
.
sensitive_word_avoidance_rule
.
type
==
SensitiveWordAvoidanceRule
.
Type
.
KEYWORDS
:
result
=
self
.
_check_sensitive_word
(
text
)
else
:
result
=
self
.
_check_moderation
(
text
)
if
not
result
:
raise
LLMBadRequestError
(
self
.
sensitive_word_avoidance_rule
.
canned_response
)
return
{
self
.
output_key
:
text
}
api/core/completion.py
View file @
f9082104
import
json
import
json
import
logging
import
logging
import
re
from
typing
import
Optional
,
List
,
Union
from
typing
import
Optional
,
List
,
Union
,
Tuple
from
langchain.schema
import
BaseMessage
from
requests.exceptions
import
ChunkedEncodingError
from
requests.exceptions
import
ChunkedEncodingError
from
core.agent.agent_executor
import
AgentExecuteResult
,
PlanningStrategy
from
core.agent.agent_executor
import
AgentExecuteResult
,
PlanningStrategy
...
@@ -14,11 +12,10 @@ from core.model_providers.error import LLMBadRequestError
...
@@ -14,11 +12,10 @@ from core.model_providers.error import LLMBadRequestError
from
core.memory.read_only_conversation_token_db_buffer_shared_memory
import
\
from
core.memory.read_only_conversation_token_db_buffer_shared_memory
import
\
ReadOnlyConversationTokenDBBufferSharedMemory
ReadOnlyConversationTokenDBBufferSharedMemory
from
core.model_providers.model_factory
import
ModelFactory
from
core.model_providers.model_factory
import
ModelFactory
from
core.model_providers.models.entity.message
import
PromptMessage
,
to_prompt_messages
from
core.model_providers.models.entity.message
import
PromptMessage
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.orchestrator_rule_parser
import
OrchestratorRuleParser
from
core.orchestrator_rule_parser
import
OrchestratorRuleParser
from
core.prompt.prompt_builder
import
PromptBuilder
from
core.prompt.prompt_builder
import
PromptBuilder
from
core.prompt.prompt_template
import
JinjaPromptTemplate
from
core.prompt.prompts
import
MORE_LIKE_THIS_GENERATE_PROMPT
from
core.prompt.prompts
import
MORE_LIKE_THIS_GENERATE_PROMPT
from
models.dataset
import
DocumentSegment
,
Dataset
,
Document
from
models.dataset
import
DocumentSegment
,
Dataset
,
Document
from
models.model
import
App
,
AppModelConfig
,
Account
,
Conversation
,
Message
,
EndUser
from
models.model
import
App
,
AppModelConfig
,
Account
,
Conversation
,
Message
,
EndUser
...
@@ -81,7 +78,7 @@ class Completion:
...
@@ -81,7 +78,7 @@ class Completion:
# parse sensitive_word_avoidance_chain
# parse sensitive_word_avoidance_chain
chain_callback
=
MainChainGatherCallbackHandler
(
conversation_message_task
)
chain_callback
=
MainChainGatherCallbackHandler
(
conversation_message_task
)
sensitive_word_avoidance_chain
=
orchestrator_rule_parser
.
to_sensitive_word_avoidance_chain
([
chain_callback
])
sensitive_word_avoidance_chain
=
orchestrator_rule_parser
.
to_sensitive_word_avoidance_chain
(
final_model_instance
,
[
chain_callback
])
if
sensitive_word_avoidance_chain
:
if
sensitive_word_avoidance_chain
:
query
=
sensitive_word_avoidance_chain
.
run
(
query
)
query
=
sensitive_word_avoidance_chain
.
run
(
query
)
...
...
api/core/conversation_message_task.py
View file @
f9082104
import
decimal
import
json
import
json
import
time
from
typing
import
Optional
,
Union
,
List
from
typing
import
Optional
,
Union
,
List
from
core.callback_handler.entity.agent_loop
import
AgentLoop
from
core.callback_handler.entity.agent_loop
import
AgentLoop
...
@@ -23,6 +23,8 @@ class ConversationMessageTask:
...
@@ -23,6 +23,8 @@ class ConversationMessageTask:
def
__init__
(
self
,
task_id
:
str
,
app
:
App
,
app_model_config
:
AppModelConfig
,
user
:
Account
,
def
__init__
(
self
,
task_id
:
str
,
app
:
App
,
app_model_config
:
AppModelConfig
,
user
:
Account
,
inputs
:
dict
,
query
:
str
,
streaming
:
bool
,
model_instance
:
BaseLLM
,
inputs
:
dict
,
query
:
str
,
streaming
:
bool
,
model_instance
:
BaseLLM
,
conversation
:
Optional
[
Conversation
]
=
None
,
is_override
:
bool
=
False
):
conversation
:
Optional
[
Conversation
]
=
None
,
is_override
:
bool
=
False
):
self
.
start_at
=
time
.
perf_counter
()
self
.
task_id
=
task_id
self
.
task_id
=
task_id
self
.
app
=
app
self
.
app
=
app
...
@@ -61,6 +63,7 @@ class ConversationMessageTask:
...
@@ -61,6 +63,7 @@ class ConversationMessageTask:
)
)
def
init
(
self
):
def
init
(
self
):
override_model_configs
=
None
override_model_configs
=
None
if
self
.
is_override
:
if
self
.
is_override
:
override_model_configs
=
self
.
app_model_config
.
to_dict
()
override_model_configs
=
self
.
app_model_config
.
to_dict
()
...
@@ -165,7 +168,7 @@ class ConversationMessageTask:
...
@@ -165,7 +168,7 @@ class ConversationMessageTask:
self
.
message
.
answer_tokens
=
answer_tokens
self
.
message
.
answer_tokens
=
answer_tokens
self
.
message
.
answer_unit_price
=
answer_unit_price
self
.
message
.
answer_unit_price
=
answer_unit_price
self
.
message
.
answer_price_unit
=
answer_price_unit
self
.
message
.
answer_price_unit
=
answer_price_unit
self
.
message
.
provider_response_latency
=
llm_message
.
latency
self
.
message
.
provider_response_latency
=
time
.
perf_counter
()
-
self
.
start_at
self
.
message
.
total_price
=
total_price
self
.
message
.
total_price
=
total_price
db
.
session
.
commit
()
db
.
session
.
commit
()
...
@@ -220,18 +223,18 @@ class ConversationMessageTask:
...
@@ -220,18 +223,18 @@ class ConversationMessageTask:
return
message_agent_thought
return
message_agent_thought
def
on_agent_end
(
self
,
message_agent_thought
:
MessageAgentThought
,
agent_model_instan
t
:
BaseLLM
,
def
on_agent_end
(
self
,
message_agent_thought
:
MessageAgentThought
,
agent_model_instan
ce
:
BaseLLM
,
agent_loop
:
AgentLoop
):
agent_loop
:
AgentLoop
):
agent_message_unit_price
=
agent_model_instan
t
.
get_tokens_unit_price
(
MessageType
.
HUMAN
)
agent_message_unit_price
=
agent_model_instan
ce
.
get_tokens_unit_price
(
MessageType
.
HUMAN
)
agent_message_price_unit
=
agent_model_instan
t
.
get_price_unit
(
MessageType
.
HUMAN
)
agent_message_price_unit
=
agent_model_instan
ce
.
get_price_unit
(
MessageType
.
HUMAN
)
agent_answer_unit_price
=
agent_model_instan
t
.
get_tokens_unit_price
(
MessageType
.
ASSISTANT
)
agent_answer_unit_price
=
agent_model_instan
ce
.
get_tokens_unit_price
(
MessageType
.
ASSISTANT
)
agent_answer_price_unit
=
agent_model_instan
t
.
get_price_unit
(
MessageType
.
ASSISTANT
)
agent_answer_price_unit
=
agent_model_instan
ce
.
get_price_unit
(
MessageType
.
ASSISTANT
)
loop_message_tokens
=
agent_loop
.
prompt_tokens
loop_message_tokens
=
agent_loop
.
prompt_tokens
loop_answer_tokens
=
agent_loop
.
completion_tokens
loop_answer_tokens
=
agent_loop
.
completion_tokens
loop_message_total_price
=
agent_model_instan
t
.
calc_tokens_price
(
loop_message_tokens
,
MessageType
.
HUMAN
)
loop_message_total_price
=
agent_model_instan
ce
.
calc_tokens_price
(
loop_message_tokens
,
MessageType
.
HUMAN
)
loop_answer_total_price
=
agent_model_instan
t
.
calc_tokens_price
(
loop_answer_tokens
,
MessageType
.
ASSISTANT
)
loop_answer_total_price
=
agent_model_instan
ce
.
calc_tokens_price
(
loop_answer_tokens
,
MessageType
.
ASSISTANT
)
loop_total_price
=
loop_message_total_price
+
loop_answer_total_price
loop_total_price
=
loop_message_total_price
+
loop_answer_total_price
message_agent_thought
.
observation
=
agent_loop
.
tool_output
message_agent_thought
.
observation
=
agent_loop
.
tool_output
...
@@ -245,7 +248,7 @@ class ConversationMessageTask:
...
@@ -245,7 +248,7 @@ class ConversationMessageTask:
message_agent_thought
.
latency
=
agent_loop
.
latency
message_agent_thought
.
latency
=
agent_loop
.
latency
message_agent_thought
.
tokens
=
agent_loop
.
prompt_tokens
+
agent_loop
.
completion_tokens
message_agent_thought
.
tokens
=
agent_loop
.
prompt_tokens
+
agent_loop
.
completion_tokens
message_agent_thought
.
total_price
=
loop_total_price
message_agent_thought
.
total_price
=
loop_total_price
message_agent_thought
.
currency
=
agent_model_instan
t
.
get_currency
()
message_agent_thought
.
currency
=
agent_model_instan
ce
.
get_currency
()
db
.
session
.
flush
()
db
.
session
.
flush
()
def
on_dataset_query_end
(
self
,
dataset_query_obj
:
DatasetQueryObj
):
def
on_dataset_query_end
(
self
,
dataset_query_obj
:
DatasetQueryObj
):
...
...
api/core/helper/moderation.py
0 → 100644
View file @
f9082104
import
logging
import
openai
from
flask
import
current_app
from
core.model_providers.error
import
LLMBadRequestError
from
core.model_providers.providers.base
import
BaseModelProvider
from
models.provider
import
ProviderType
def
check_moderation
(
model_provider
:
BaseModelProvider
,
text
:
str
)
->
bool
:
if
current_app
.
config
[
'HOSTED_MODERATION_ENABLED'
]
and
current_app
.
config
[
'HOSTED_MODERATION_PROVIDERS'
]:
moderation_providers
=
current_app
.
config
[
'HOSTED_MODERATION_PROVIDERS'
]
.
split
(
','
)
if
model_provider
.
provider
.
provider_type
==
ProviderType
.
SYSTEM
.
value
\
and
model_provider
.
provider_name
in
moderation_providers
:
# 2000 text per chunk
length
=
2000
chunks
=
[
text
[
i
:
i
+
length
]
for
i
in
range
(
0
,
len
(
text
),
length
)]
try
:
moderation_result
=
openai
.
Moderation
.
create
(
input
=
chunks
,
api_key
=
current_app
.
config
[
'HOSTED_OPENAI_API_KEY'
])
except
Exception
as
ex
:
logging
.
exception
(
ex
)
raise
LLMBadRequestError
(
'Rate limit exceeded, please try again later.'
)
for
result
in
moderation_result
.
results
:
if
result
[
'flagged'
]
is
True
:
return
False
return
True
api/core/model_providers/model_factory.py
View file @
f9082104
...
@@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel
...
@@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel
from
core.model_providers.models.embedding.base
import
BaseEmbedding
from
core.model_providers.models.embedding.base
import
BaseEmbedding
from
core.model_providers.models.entity.model_params
import
ModelKwargs
,
ModelType
from
core.model_providers.models.entity.model_params
import
ModelKwargs
,
ModelType
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.llm.base
import
BaseLLM
from
core.model_providers.models.moderation.base
import
BaseModeration
from
core.model_providers.models.speech2text.base
import
BaseSpeech2Text
from
core.model_providers.models.speech2text.base
import
BaseSpeech2Text
from
extensions.ext_database
import
db
from
extensions.ext_database
import
db
from
models.provider
import
TenantDefaultModel
from
models.provider
import
TenantDefaultModel
...
@@ -180,7 +181,7 @@ class ModelFactory:
...
@@ -180,7 +181,7 @@ class ModelFactory:
def
get_moderation_model
(
cls
,
def
get_moderation_model
(
cls
,
tenant_id
:
str
,
tenant_id
:
str
,
model_provider_name
:
str
,
model_provider_name
:
str
,
model_name
:
str
)
->
Optional
[
Base
ProviderModel
]:
model_name
:
str
)
->
Optional
[
Base
Moderation
]:
"""
"""
get moderation model.
get moderation model.
...
...
api/core/model_providers/models/llm/base.py
View file @
f9082104
...
@@ -10,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory
...
@@ -10,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory
from
langchain.schema
import
LLMResult
,
SystemMessage
,
AIMessage
,
HumanMessage
,
BaseMessage
,
ChatGeneration
from
langchain.schema
import
LLMResult
,
SystemMessage
,
AIMessage
,
HumanMessage
,
BaseMessage
,
ChatGeneration
from
core.callback_handler.std_out_callback_handler
import
DifyStreamingStdOutCallbackHandler
,
DifyStdOutCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStreamingStdOutCallbackHandler
,
DifyStdOutCallbackHandler
from
core.helper
import
moderation
from
core.model_providers.models.base
import
BaseProviderModel
from
core.model_providers.models.base
import
BaseProviderModel
from
core.model_providers.models.entity.message
import
PromptMessage
,
MessageType
,
LLMRunResult
,
to_prompt_messages
from
core.model_providers.models.entity.message
import
PromptMessage
,
MessageType
,
LLMRunResult
,
to_prompt_messages
from
core.model_providers.models.entity.model_params
import
ModelType
,
ModelKwargs
,
ModelMode
,
ModelKwargsRules
from
core.model_providers.models.entity.model_params
import
ModelType
,
ModelKwargs
,
ModelMode
,
ModelKwargsRules
...
@@ -116,6 +117,15 @@ class BaseLLM(BaseProviderModel):
...
@@ -116,6 +117,15 @@ class BaseLLM(BaseProviderModel):
:param callbacks:
:param callbacks:
:return:
:return:
"""
"""
moderation_result
=
moderation
.
check_moderation
(
self
.
model_provider
,
"
\n
"
.
join
([
message
.
content
for
message
in
messages
])
)
if
not
moderation_result
:
kwargs
[
'fake_response'
]
=
"I apologize for any confusion, "
\
"but I'm an AI assistant to be helpful, harmless, and honest."
if
self
.
deduct_quota
:
if
self
.
deduct_quota
:
self
.
model_provider
.
check_quota_over_limit
()
self
.
model_provider
.
check_quota_over_limit
()
...
...
api/core/model_providers/models/moderation/base.py
0 → 100644
View file @
f9082104
from
abc
import
abstractmethod
from
typing
import
Any
from
core.model_providers.models.base
import
BaseProviderModel
from
core.model_providers.models.entity.model_params
import
ModelType
from
core.model_providers.providers.base
import
BaseModelProvider
class
BaseModeration
(
BaseProviderModel
):
name
:
str
type
:
ModelType
=
ModelType
.
MODERATION
def
__init__
(
self
,
model_provider
:
BaseModelProvider
,
client
:
Any
,
name
:
str
):
super
()
.
__init__
(
model_provider
,
client
)
self
.
name
=
name
def
run
(
self
,
text
:
str
)
->
bool
:
try
:
return
self
.
_run
(
text
)
except
Exception
as
ex
:
raise
self
.
handle_exceptions
(
ex
)
@
abstractmethod
def
_run
(
self
,
text
:
str
)
->
bool
:
raise
NotImplementedError
@
abstractmethod
def
handle_exceptions
(
self
,
ex
:
Exception
)
->
Exception
:
raise
NotImplementedError
api/core/model_providers/models/moderation/openai_moderation.py
View file @
f9082104
...
@@ -4,29 +4,35 @@ import openai
...
@@ -4,29 +4,35 @@ import openai
from
core.model_providers.error
import
LLMBadRequestError
,
LLMAPIConnectionError
,
LLMAPIUnavailableError
,
\
from
core.model_providers.error
import
LLMBadRequestError
,
LLMAPIConnectionError
,
LLMAPIUnavailableError
,
\
LLMRateLimitError
,
LLMAuthorizationError
LLMRateLimitError
,
LLMAuthorizationError
from
core.model_providers.models.base
import
BaseProviderModel
from
core.model_providers.models.moderation.base
import
BaseModeration
from
core.model_providers.models.entity.model_params
import
ModelType
from
core.model_providers.providers.base
import
BaseModelProvider
from
core.model_providers.providers.base
import
BaseModelProvider
DEFAULT_
AUDIO_
MODEL
=
'whisper-1'
DEFAULT_MODEL
=
'whisper-1'
class
OpenAIModeration
(
BaseProviderModel
):
class
OpenAIModeration
(
BaseModeration
):
type
:
ModelType
=
ModelType
.
MODERATION
def
__init__
(
self
,
model_provider
:
BaseModelProvider
,
name
:
str
):
def
__init__
(
self
,
model_provider
:
BaseModelProvider
,
name
:
str
):
super
()
.
__init__
(
model_provider
,
openai
.
Moderation
)
super
()
.
__init__
(
model_provider
,
openai
.
Moderation
,
name
)
def
run
(
self
,
text
)
:
def
_run
(
self
,
text
:
str
)
->
bool
:
credentials
=
self
.
model_provider
.
get_model_credentials
(
credentials
=
self
.
model_provider
.
get_model_credentials
(
model_name
=
DEFAULT_AUDIO_MODEL
,
model_name
=
self
.
name
,
model_type
=
self
.
type
model_type
=
self
.
type
)
)
try
:
# 2000 text per chunk
return
self
.
_client
.
create
(
input
=
text
,
api_key
=
credentials
[
'openai_api_key'
])
length
=
2000
except
Exception
as
ex
:
chunks
=
[
text
[
i
:
i
+
length
]
for
i
in
range
(
0
,
len
(
text
),
length
)]
raise
self
.
handle_exceptions
(
ex
)
moderation_result
=
self
.
_client
.
create
(
input
=
chunks
,
api_key
=
credentials
[
'openai_api_key'
])
for
result
in
moderation_result
.
results
:
if
result
[
'flagged'
]
is
True
:
return
False
return
True
def
handle_exceptions
(
self
,
ex
:
Exception
)
->
Exception
:
def
handle_exceptions
(
self
,
ex
:
Exception
)
->
Exception
:
if
isinstance
(
ex
,
openai
.
error
.
InvalidRequestError
):
if
isinstance
(
ex
,
openai
.
error
.
InvalidRequestError
):
...
...
api/core/orchestrator_rule_parser.py
View file @
f9082104
import
math
import
math
from
typing
import
Optional
from
typing
import
Optional
from
flask
import
current_app
from
langchain
import
WikipediaAPIWrapper
from
langchain
import
WikipediaAPIWrapper
from
langchain.callbacks.manager
import
Callbacks
from
langchain.callbacks.manager
import
Callbacks
from
langchain.memory.chat_memory
import
BaseChatMemory
from
langchain.memory.chat_memory
import
BaseChatMemory
...
@@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
...
@@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
from
core.callback_handler.dataset_tool_callback_handler
import
DatasetToolCallbackHandler
from
core.callback_handler.main_chain_gather_callback_handler
import
MainChainGatherCallbackHandler
from
core.callback_handler.main_chain_gather_callback_handler
import
MainChainGatherCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.callback_handler.std_out_callback_handler
import
DifyStdOutCallbackHandler
from
core.chain.sensitive_word_avoidance_chain
import
SensitiveWordAvoidanceChain
from
core.chain.sensitive_word_avoidance_chain
import
SensitiveWordAvoidanceChain
,
SensitiveWordAvoidanceRule
from
core.conversation_message_task
import
ConversationMessageTask
from
core.conversation_message_task
import
ConversationMessageTask
from
core.model_providers.error
import
ProviderTokenNotInitError
from
core.model_providers.error
import
ProviderTokenNotInitError
from
core.model_providers.model_factory
import
ModelFactory
from
core.model_providers.model_factory
import
ModelFactory
...
@@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool
...
@@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool
from
extensions.ext_database
import
db
from
extensions.ext_database
import
db
from
models.dataset
import
Dataset
,
DatasetProcessRule
from
models.dataset
import
Dataset
,
DatasetProcessRule
from
models.model
import
AppModelConfig
from
models.model
import
AppModelConfig
from
models.provider
import
ProviderType
class
OrchestratorRuleParser
:
class
OrchestratorRuleParser
:
...
@@ -63,7 +65,7 @@ class OrchestratorRuleParser:
...
@@ -63,7 +65,7 @@ class OrchestratorRuleParser:
# add agent callback to record agent thoughts
# add agent callback to record agent thoughts
agent_callback
=
AgentLoopGatherCallbackHandler
(
agent_callback
=
AgentLoopGatherCallbackHandler
(
model_instan
t
=
agent_model_instance
,
model_instan
ce
=
agent_model_instance
,
conversation_message_task
=
conversation_message_task
conversation_message_task
=
conversation_message_task
)
)
...
@@ -123,23 +125,45 @@ class OrchestratorRuleParser:
...
@@ -123,23 +125,45 @@ class OrchestratorRuleParser:
return
chain
return
chain
def
to_sensitive_word_avoidance_chain
(
self
,
callbacks
:
Callbacks
=
None
,
**
kwargs
)
\
def
to_sensitive_word_avoidance_chain
(
self
,
model_instance
:
BaseLLM
,
callbacks
:
Callbacks
=
None
,
**
kwargs
)
\
->
Optional
[
SensitiveWordAvoidanceChain
]:
->
Optional
[
SensitiveWordAvoidanceChain
]:
"""
"""
Convert app sensitive word avoidance config to chain
Convert app sensitive word avoidance config to chain
:param model_instance: model instance
:param callbacks: callbacks for the chain
:param kwargs:
:param kwargs:
:return:
:return:
"""
"""
if
not
self
.
app_model_config
.
sensitive_word_avoidance_dict
:
sensitive_word_avoidance_rule
=
None
return
None
if
self
.
app_model_config
.
sensitive_word_avoidance_dict
:
sensitive_word_avoidance_config
=
self
.
app_model_config
.
sensitive_word_avoidance_dict
sensitive_word_avoidance_config
=
self
.
app_model_config
.
sensitive_word_avoidance_dict
if
sensitive_word_avoidance_config
.
get
(
"enabled"
,
False
):
if
sensitive_word_avoidance_config
.
get
(
'type'
)
==
'moderation'
:
sensitive_word_avoidance_rule
=
SensitiveWordAvoidanceRule
(
type
=
SensitiveWordAvoidanceRule
.
Type
.
MODERATION
,
canned_response
=
sensitive_word_avoidance_config
.
get
(
"canned_response"
)
if
sensitive_word_avoidance_config
.
get
(
"canned_response"
)
else
'Your content violates our usage policy. Please revise and try again.'
,
)
else
:
sensitive_words
=
sensitive_word_avoidance_config
.
get
(
"words"
,
""
)
sensitive_words
=
sensitive_word_avoidance_config
.
get
(
"words"
,
""
)
if
sensitive_word_avoidance_config
.
get
(
"enabled"
,
False
)
and
sensitive_words
:
if
sensitive_words
:
sensitive_word_avoidance_rule
=
SensitiveWordAvoidanceRule
(
type
=
SensitiveWordAvoidanceRule
.
Type
.
KEYWORDS
,
canned_response
=
sensitive_word_avoidance_config
.
get
(
"canned_response"
)
if
sensitive_word_avoidance_config
.
get
(
"canned_response"
)
else
'Your content violates our usage policy. Please revise and try again.'
,
extra_params
=
{
'sensitive_words'
:
sensitive_words
.
split
(
','
),
}
)
if
sensitive_word_avoidance_rule
:
return
SensitiveWordAvoidanceChain
(
return
SensitiveWordAvoidanceChain
(
sensitive_words
=
sensitive_words
.
split
(
","
)
,
model_instance
=
model_instance
,
canned_response
=
sensitive_word_avoidance_config
.
get
(
"canned_response"
,
''
)
,
sensitive_word_avoidance_rule
=
sensitive_word_avoidance_rule
,
output_key
=
"sensitive_word_avoidance_output"
,
output_key
=
"sensitive_word_avoidance_output"
,
callbacks
=
callbacks
,
callbacks
=
callbacks
,
**
kwargs
**
kwargs
...
...
api/tests/integration_tests/models/moderation/test_openai_moderation.py
View file @
f9082104
...
@@ -2,7 +2,7 @@ import json
...
@@ -2,7 +2,7 @@ import json
import
os
import
os
from
unittest.mock
import
patch
from
unittest.mock
import
patch
from
core.model_providers.models.moderation.openai_moderation
import
OpenAIModeration
,
DEFAULT_
AUDIO_
MODEL
from
core.model_providers.models.moderation.openai_moderation
import
OpenAIModeration
,
DEFAULT_MODEL
from
core.model_providers.providers.openai_provider
import
OpenAIProvider
from
core.model_providers.providers.openai_provider
import
OpenAIProvider
from
models.provider
import
Provider
,
ProviderType
from
models.provider
import
Provider
,
ProviderType
...
@@ -23,7 +23,7 @@ def get_mock_openai_moderation_model():
...
@@ -23,7 +23,7 @@ def get_mock_openai_moderation_model():
openai_provider
=
OpenAIProvider
(
provider
=
get_mock_provider
(
valid_openai_api_key
))
openai_provider
=
OpenAIProvider
(
provider
=
get_mock_provider
(
valid_openai_api_key
))
return
OpenAIModeration
(
return
OpenAIModeration
(
model_provider
=
openai_provider
,
model_provider
=
openai_provider
,
name
=
DEFAULT_
AUDIO_
MODEL
name
=
DEFAULT_MODEL
)
)
...
@@ -36,5 +36,4 @@ def test_run(mock_decrypt):
...
@@ -36,5 +36,4 @@ def test_run(mock_decrypt):
model
=
get_mock_openai_moderation_model
()
model
=
get_mock_openai_moderation_model
()
rst
=
model
.
run
(
'hello'
)
rst
=
model
.
run
(
'hello'
)
assert
isinstance
(
rst
,
dict
)
assert
rst
is
True
assert
'id'
in
rst
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment