Unverified commit f9082104, authored by takatost, committed by GitHub

feat: add hosted moderation (#1158)

parent 983834cd
@@ -61,6 +61,8 @@ DEFAULTS = {
     'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
     'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
     'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
+    'HOSTED_MODERATION_ENABLED': 'False',
+    'HOSTED_MODERATION_PROVIDERS': '',
     'TENANT_DOCUMENT_COUNT': 100,
     'CLEAN_DAY_SETTING': 30,
     'UPLOAD_FILE_SIZE_LIMIT': 15,
@@ -230,6 +232,9 @@ class Config:
         self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
         self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
+        self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
+        self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
+
         self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
         self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
...
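Note: both settings default to off. For a hosted deployment they are plain environment variables read at startup; a minimal sketch with illustrative values (how `get_bool_env` parses the string is outside this diff, `'true'` is assumed):

```python
import os

# Illustrative values, set before the Flask Config above is constructed.
os.environ['HOSTED_MODERATION_ENABLED'] = 'true'
# Comma-separated provider names whose hosted (system) usage should be moderated.
os.environ['HOSTED_MODERATION_PROVIDERS'] = 'openai'
```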
@@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
 from langchain.agents import AgentExecutor as LCAgentExecutor

+from core.helper import moderation
+from core.model_providers.error import LLMError
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -116,6 +118,18 @@ class AgentExecutor:
         return self.agent.should_use_agent(query)

     def run(self, query: str) -> AgentExecuteResult:
+        moderation_result = moderation.check_moderation(
+            self.configuration.model_instance.model_provider,
+            query
+        )
+
+        if not moderation_result:
+            return AgentExecuteResult(
+                output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
+                strategy=self.configuration.strategy,
+                configuration=self.configuration
+            )
+
         agent_executor = LCAgentExecutor.from_agent_and_tools(
             agent=self.agent,
             tools=self.configuration.tools,
@@ -128,7 +142,9 @@ class AgentExecutor:
         try:
             output = agent_executor.run(query)
-        except Exception:
+        except LLMError as ex:
+            raise ex
+        except Exception as ex:
             logging.exception("agent_executor run failed")
             output = None
...
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional
 from langchain.agents import openai_functions_agent, openai_functions_multi_agent
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
+from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage

 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
@@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
     raise_error: bool = True

-    def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
+    def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
-        self.model_instant = model_instant
+        self.model_instance = model_instance
         self.conversation_message_task = conversation_message_task
         self._agent_loops = []
         self._current_loop = None
@@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         """Whether to ignore chain callbacks."""
         return True

+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        if not self._current_loop:
+            # Agent starts with an LLM query
+            self._current_loop = AgentLoop(
+                position=len(self._agent_loops) + 1,
+                prompt="\n".join([message.content for message in messages[0]]),
+                status='llm_started',
+                started_at=time.perf_counter()
+            )
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
@@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         if response.llm_output:
             self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
         else:
-            self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
+            self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
                 [PromptMessage(content=self._current_loop.prompt)]
             )
         completion_generation = response.generations[0][0]
@@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         if response.llm_output:
             self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
         else:
-            self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
+            self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
                 [PromptMessage(content=self._current_loop.completion)]
             )
@@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at

         self.conversation_message_task.on_agent_end(
-            self._message_agent_thought, self.model_instant, self._current_loop
+            self._message_agent_thought, self.model_instance, self._current_loop
         )
         self._agent_loops.append(self._current_loop)
@@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         )

         self.conversation_message_task.on_agent_end(
-            self._message_agent_thought, self.model_instant, self._current_loop
+            self._message_agent_thought, self.model_instance, self._current_loop
         )
         self._agent_loops.append(self._current_loop)
...
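Worth noting in the new `on_chat_model_start`: the recorded loop prompt is just the contents of the first message batch joined with newlines. A quick illustration of that join, using langchain message types:

```python
from langchain.schema import HumanMessage, SystemMessage

messages = [[SystemMessage(content="You are a helpful assistant."),
             HumanMessage(content="What's the weather?")]]

prompt = "\n".join([message.content for message in messages[0]])
# -> "You are a helpful assistant.\nWhat's the weather?"
```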
@@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
     prompt_tokens: int = 0
     completion: str = ''
     completion_tokens: int = 0
-    latency: float = 0.0
 import logging
-import time
 from typing import Any, Dict, List, Union

 from langchain.callbacks.base import BaseCallbackHandler
@@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         messages: List[List[BaseMessage]],
         **kwargs: Any
     ) -> Any:
-        self.start_at = time.perf_counter()
         real_prompts = []
         for message in messages[0]:
             if message.type == 'human':
@@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
-        self.start_at = time.perf_counter()
-
         self.llm_message.prompt = [{
             "role": 'user',
             "text": prompts[0]
@@ -63,9 +59,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        end_at = time.perf_counter()
-        self.llm_message.latency = end_at - self.start_at
-
         if not self.conversation_message_task.streaming:
             self.conversation_message_task.append_message_text(response.generations[0][0].text)

         self.llm_message.completion = response.generations[0][0].text
@@ -89,8 +82,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         """Do nothing."""
         if isinstance(error, ConversationTaskStoppedException):
             if self.conversation_message_task.streaming:
-                end_at = time.perf_counter()
-                self.llm_message.latency = end_at - self.start_at
                 self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self.llm_message.completion)]
                 )
...
+import enum
+import logging
 from typing import List, Dict, Optional, Any

+import openai
+from flask import current_app
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
+from openai import InvalidRequestError
+from openai.error import APIConnectionError, APIError, ServiceUnavailableError, Timeout, RateLimitError, \
+    AuthenticationError, OpenAIError
+from pydantic import BaseModel
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.moderation import openai_moderation
+
+
+class SensitiveWordAvoidanceRule(BaseModel):
+    class Type(enum.Enum):
+        MODERATION = "moderation"
+        KEYWORDS = "keywords"
+
+    type: Type
+    canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
+    extra_params: dict = {}


 class SensitiveWordAvoidanceChain(Chain):
     input_key: str = "input"  #: :meta private:
     output_key: str = "output"  #: :meta private:

-    sensitive_words: List[str] = []
-    canned_response: str = None
+    model_instance: BaseLLM
+    sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule

     @property
     def _chain_type(self) -> str:
@@ -31,11 +54,24 @@ class SensitiveWordAvoidanceChain(Chain):
         """
         return [self.output_key]

-    def _check_sensitive_word(self, text: str) -> str:
-        for word in self.sensitive_words:
+    def _check_sensitive_word(self, text: str) -> bool:
+        for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
             if word in text:
-                return self.canned_response
-        return text
+                return False
+        return True
+
+    def _check_moderation(self, text: str) -> bool:
+        moderation_model_instance = ModelFactory.get_moderation_model(
+            tenant_id=self.model_instance.model_provider.provider.tenant_id,
+            model_provider_name='openai',
+            model_name=openai_moderation.DEFAULT_MODEL
+        )
+
+        try:
+            return moderation_model_instance.run(text=text)
+        except Exception as ex:
+            logging.exception(ex)
+            raise LLMBadRequestError('Rate limit exceeded, please try again later.')

     def _call(
         self,
@@ -43,5 +79,13 @@ class SensitiveWordAvoidanceChain(Chain):
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         text = inputs[self.input_key]
-        output = self._check_sensitive_word(text)
-        return {self.output_key: output}
+
+        if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
+            result = self._check_sensitive_word(text)
+        else:
+            result = self._check_moderation(text)
+
+        if not result:
+            raise LLMBadRequestError(self.sensitive_word_avoidance_rule.canned_response)
+
+        return {self.output_key: text}
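The chain's contract changes here: instead of rewriting flagged text to the canned response, `_call` now raises `LLMBadRequestError(canned_response)` and passes clean text through untouched. A sketch of the two rule shapes the chain accepts (field values are illustrative):

```python
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceRule

# Keyword rule: flag the input if any listed word appears in it.
keywords_rule = SensitiveWordAvoidanceRule(
    type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
    canned_response='Please rephrase your question.',
    extra_params={'sensitive_words': ['badword1', 'badword2']},
)

# Moderation rule: defer to the OpenAI moderation model via ModelFactory.
moderation_rule = SensitiveWordAvoidanceRule(
    type=SensitiveWordAvoidanceRule.Type.MODERATION,
)
```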
 import json
 import logging
-import re
-from typing import Optional, List, Union, Tuple
+from typing import Optional, List, Union

-from langchain.schema import BaseMessage
 from requests.exceptions import ChunkedEncodingError

 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
@@ -14,11 +12,10 @@ from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
 from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
 from models.dataset import DocumentSegment, Dataset, Document
 from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
@@ -81,7 +78,7 @@ class Completion:
         # parse sensitive_word_avoidance_chain
         chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
+        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(final_model_instance, [chain_callback])
         if sensitive_word_avoidance_chain:
             query = sensitive_word_avoidance_chain.run(query)
...
-import decimal
 import json
+import time
 from typing import Optional, Union, List

 from core.callback_handler.entity.agent_loop import AgentLoop
@@ -23,6 +23,8 @@ class ConversationMessageTask:
     def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
                  inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
                  conversation: Optional[Conversation] = None, is_override: bool = False):
+        self.start_at = time.perf_counter()
+
         self.task_id = task_id
         self.app = app
@@ -61,6 +63,7 @@ class ConversationMessageTask:
         )

     def init(self):
         override_model_configs = None
         if self.is_override:
             override_model_configs = self.app_model_config.to_dict()
@@ -165,7 +168,7 @@ class ConversationMessageTask:
         self.message.answer_tokens = answer_tokens
         self.message.answer_unit_price = answer_unit_price
         self.message.answer_price_unit = answer_price_unit
-        self.message.provider_response_latency = llm_message.latency
+        self.message.provider_response_latency = time.perf_counter() - self.start_at
         self.message.total_price = total_price

         db.session.commit()
@@ -220,18 +223,18 @@ class ConversationMessageTask:
         return message_agent_thought

-    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
+    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
-        agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
-        agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
-        agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
+        agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+        agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)

         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens

-        loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
-        loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
+        loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+        loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
         loop_total_price = loop_message_total_price + loop_answer_total_price

         message_agent_thought.observation = agent_loop.tool_output
@@ -245,7 +248,7 @@ class ConversationMessageTask:
         message_agent_thought.latency = agent_loop.latency
         message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
         message_agent_thought.total_price = loop_total_price
-        message_agent_thought.currency = agent_model_instant.get_currency()
+        message_agent_thought.currency = agent_model_instance.get_currency()

         db.session.flush()

     def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
...
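With the per-handler timers removed (see the `LLMCallbackHandler` hunks above), `provider_response_latency` now measures elapsed wall time since the task object was created, which includes chain and tool time rather than just the provider call. The pattern in miniature:

```python
import time

start_at = time.perf_counter()
# ... orchestration, model calls, streaming ...
provider_response_latency = time.perf_counter() - start_at
```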
+import logging
+
+import openai
+from flask import current_app
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import BaseModelProvider
+from models.provider import ProviderType
+
+
+def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
+    if current_app.config['HOSTED_MODERATION_ENABLED'] and current_app.config['HOSTED_MODERATION_PROVIDERS']:
+        moderation_providers = current_app.config['HOSTED_MODERATION_PROVIDERS'].split(',')
+        if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
+                and model_provider.provider_name in moderation_providers:
+            # 2,000 characters per chunk
+            length = 2000
+            chunks = [text[i:i + length] for i in range(0, len(text), length)]
+
+            try:
+                moderation_result = openai.Moderation.create(input=chunks,
+                                                             api_key=current_app.config['HOSTED_OPENAI_API_KEY'])
+            except Exception as ex:
+                logging.exception(ex)
+                raise LLMBadRequestError('Rate limit exceeded, please try again later.')
+
+            for result in moderation_result.results:
+                if result['flagged'] is True:
+                    return False
+
+    return True
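Since the moderation endpoint takes a list of inputs, long prompts are split into fixed-width slices before the call. The standard slicing idiom behaves like this:

```python
text = "x" * 4500
length = 2000
chunks = [text[i:i + length] for i in range(0, len(text), length)]

assert len(chunks) == 3
assert [len(c) for c in chunks] == [2000, 2000, 500]
```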
@@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.embedding.base import BaseEmbedding
 from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
 from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.moderation.base import BaseModeration
 from core.model_providers.models.speech2text.base import BaseSpeech2Text
 from extensions.ext_database import db
 from models.provider import TenantDefaultModel
@@ -180,7 +181,7 @@ class ModelFactory:
     def get_moderation_model(cls,
                              tenant_id: str,
                              model_provider_name: str,
-                             model_name: str) -> Optional[BaseProviderModel]:
+                             model_name: str) -> Optional[BaseModeration]:
         """
         get moderation model.
...
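Call sites resolve a moderation model through the factory and now get the narrowed `BaseModeration` type back, as `SensitiveWordAvoidanceChain._check_moderation` above does. In sketch form (`tenant_id` is a placeholder):

```python
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.moderation import openai_moderation

moderation_model = ModelFactory.get_moderation_model(
    tenant_id='your-tenant-id',  # placeholder
    model_provider_name='openai',
    model_name=openai_moderation.DEFAULT_MODEL,
)
passed = moderation_model.run('some user input')  # True if nothing is flagged
```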
@@ -10,6 +10,7 @@ from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration

 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
+from core.helper import moderation
 from core.model_providers.models.base import BaseProviderModel
 from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
@@ -116,6 +117,15 @@ class BaseLLM(BaseProviderModel):
         :param callbacks:
         :return:
         """
+        moderation_result = moderation.check_moderation(
+            self.model_provider,
+            "\n".join([message.content for message in messages])
+        )
+
+        if not moderation_result:
+            kwargs['fake_response'] = "I apologize for any confusion, " \
+                                      "but I'm an AI assistant to be helpful, harmless, and honest."
+
         if self.deduct_quota:
             self.model_provider.check_quota_over_limit()
...
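This hunk sets `kwargs['fake_response']` rather than raising, which implies the generation path further down (not shown in this diff) substitutes the canned text for a real provider call. A hypothetical sketch of that presumed pattern, with invented helper names:

```python
# Hypothetical sketch of how a fake_response kwarg can short-circuit
# generation; the real handling lives outside this diff.
def generate(prompt: str, **kwargs) -> str:
    fake_response = kwargs.pop('fake_response', None)
    if fake_response is not None:
        return fake_response  # canned text, provider never called
    return call_provider(prompt)  # stand-in for the real provider call


def call_provider(prompt: str) -> str:
    return f"echo: {prompt}"


assert generate("hi", fake_response="blocked") == "blocked"
```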
+from abc import abstractmethod
+from typing import Any
+
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import BaseModelProvider
+
+
+class BaseModeration(BaseProviderModel):
+    name: str
+    type: ModelType = ModelType.MODERATION
+
+    def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
+        super().__init__(model_provider, client)
+        self.name = name
+
+    def run(self, text: str) -> bool:
+        try:
+            return self._run(text)
+        except Exception as ex:
+            raise self.handle_exceptions(ex)
+
+    @abstractmethod
+    def _run(self, text: str) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        raise NotImplementedError
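The abstract base keeps provider-specific moderation small: implement `_run` (True means the text passed) and `handle_exceptions` (map provider errors onto the app's error types). A hypothetical subclass to show the shape (class name and blocklist invented for illustration):

```python
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.moderation.base import BaseModeration


class BlocklistModeration(BaseModeration):
    """Hypothetical moderation backend driven by a static blocklist."""

    BLOCKLIST = {'badword1', 'badword2'}

    def _run(self, text: str) -> bool:
        # False means "flagged", matching OpenAIModeration's convention.
        return not any(word in text.lower() for word in self.BLOCKLIST)

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(str(ex))
```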
@@ -4,29 +4,35 @@ import openai
 from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
     LLMRateLimitError, LLMAuthorizationError
-from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.models.moderation.base import BaseModeration
 from core.model_providers.providers.base import BaseModelProvider

-DEFAULT_AUDIO_MODEL = 'whisper-1'
+DEFAULT_MODEL = 'whisper-1'


-class OpenAIModeration(BaseProviderModel):
-    type: ModelType = ModelType.MODERATION
+class OpenAIModeration(BaseModeration):

     def __init__(self, model_provider: BaseModelProvider, name: str):
-        super().__init__(model_provider, openai.Moderation)
+        super().__init__(model_provider, openai.Moderation, name)

-    def run(self, text):
+    def _run(self, text: str) -> bool:
         credentials = self.model_provider.get_model_credentials(
-            model_name=DEFAULT_AUDIO_MODEL,
+            model_name=self.name,
             model_type=self.type
         )

-        try:
-            return self._client.create(input=text, api_key=credentials['openai_api_key'])
-        except Exception as ex:
-            raise self.handle_exceptions(ex)
+        # 2,000 characters per chunk
+        length = 2000
+        chunks = [text[i:i + length] for i in range(0, len(text), length)]
+
+        moderation_result = self._client.create(input=chunks,
+                                                api_key=credentials['openai_api_key'])
+
+        for result in moderation_result.results:
+            if result['flagged'] is True:
+                return False
+
+        return True

     def handle_exceptions(self, ex: Exception) -> Exception:
         if isinstance(ex, openai.error.InvalidRequestError):
...
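For reference, the loop above consumes the Moderation API's `results` list, where each entry carries a boolean `flagged`. A stubbed response shows the decision without a live API call:

```python
# Stubbed shape of an OpenAI moderation response; two chunks were sent.
moderation_result_results = [
    {'flagged': False},  # first 2,000-character chunk is clean
    {'flagged': True},   # second chunk tripped a category
]

passed = not any(result['flagged'] for result in moderation_result_results)
assert passed is False  # _run() would return False here
```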
 import math
 from typing import Optional

+from flask import current_app
 from langchain import WikipediaAPIWrapper
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
@@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
+from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
 from core.conversation_message_task import ConversationMessageTask
 from core.model_providers.error import ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
@@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool
 from extensions.ext_database import db
 from models.dataset import Dataset, DatasetProcessRule
 from models.model import AppModelConfig
+from models.provider import ProviderType


 class OrchestratorRuleParser:
@@ -63,7 +65,7 @@ class OrchestratorRuleParser:
         # add agent callback to record agent thoughts
         agent_callback = AgentLoopGatherCallbackHandler(
-            model_instant=agent_model_instance,
+            model_instance=agent_model_instance,
             conversation_message_task=conversation_message_task
         )
@@ -123,23 +125,45 @@ class OrchestratorRuleParser:
         return chain

-    def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
+    def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
             -> Optional[SensitiveWordAvoidanceChain]:
         """
         Convert app sensitive word avoidance config to chain

+        :param model_instance: model instance
+        :param callbacks: callbacks for the chain
         :param kwargs:
         :return:
         """
-        if not self.app_model_config.sensitive_word_avoidance_dict:
-            return None
-
-        sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
-
-        sensitive_words = sensitive_word_avoidance_config.get("words", "")
-        if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
+        sensitive_word_avoidance_rule = None
+
+        if self.app_model_config.sensitive_word_avoidance_dict:
+            sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
+            if sensitive_word_avoidance_config.get("enabled", False):
+                if sensitive_word_avoidance_config.get('type') == 'moderation':
+                    sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
+                        type=SensitiveWordAvoidanceRule.Type.MODERATION,
+                        canned_response=sensitive_word_avoidance_config.get("canned_response")
+                        if sensitive_word_avoidance_config.get("canned_response")
+                        else 'Your content violates our usage policy. Please revise and try again.',
+                    )
+                else:
+                    sensitive_words = sensitive_word_avoidance_config.get("words", "")
+                    if sensitive_words:
+                        sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
+                            type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
+                            canned_response=sensitive_word_avoidance_config.get("canned_response")
+                            if sensitive_word_avoidance_config.get("canned_response")
+                            else 'Your content violates our usage policy. Please revise and try again.',
+                            extra_params={
+                                'sensitive_words': sensitive_words.split(','),
+                            }
+                        )
+
+        if sensitive_word_avoidance_rule:
             return SensitiveWordAvoidanceChain(
-                sensitive_words=sensitive_words.split(","),
-                canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
+                model_instance=model_instance,
+                sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
                 output_key="sensitive_word_avoidance_output",
                 callbacks=callbacks,
                 **kwargs
...
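For context, the parser now distinguishes two shapes of `sensitive_word_avoidance_dict`. Illustrative configs (keys taken from the code above; the `'keywords'` type string is an assumption, since any non-`'moderation'` value falls through to the keywords branch):

```python
moderation_config = {
    'enabled': True,
    'type': 'moderation',
    'canned_response': '',  # empty -> default usage-policy message is used
}

keywords_config = {
    'enabled': True,
    'type': 'keywords',  # assumption: anything but 'moderation' works
    'words': 'badword1,badword2',
    'canned_response': 'Please rephrase your question.',
}
```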
@@ -2,7 +2,7 @@ import json
 import os
 from unittest.mock import patch

-from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_AUDIO_MODEL
+from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_MODEL
 from core.model_providers.providers.openai_provider import OpenAIProvider
 from models.provider import Provider, ProviderType
@@ -23,7 +23,7 @@ def get_mock_openai_moderation_model():
     openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
     return OpenAIModeration(
         model_provider=openai_provider,
-        name=DEFAULT_AUDIO_MODEL
+        name=DEFAULT_MODEL
     )
@@ -36,5 +36,4 @@ def test_run(mock_decrypt):
     model = get_mock_openai_moderation_model()
     rst = model.run('hello')
-    assert isinstance(rst, dict)
-    assert 'id' in rst
+    assert rst is True