Unverified Commit f5b2271c authored by John Wang's avatar John Wang Committed by GitHub

fix: import wrong user (#32)

parent a8155cba
...@@ -2,8 +2,6 @@ import decimal ...@@ -2,8 +2,6 @@ import decimal
import json import json
from typing import Optional, Union from typing import Optional, Union
from gunicorn.config import User
from core.callback_handler.entity.agent_loop import AgentLoop from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage from core.callback_handler.entity.llm_message import LLMMessage
...@@ -269,7 +267,7 @@ class ConversationMessageTask: ...@@ -269,7 +267,7 @@ class ConversationMessageTask:
class PubHandler: class PubHandler:
def __init__(self, user: Union[Account | User], task_id: str, def __init__(self, user: Union[Account | EndUser], task_id: str,
message: Message, conversation: Conversation, message: Message, conversation: Conversation,
chain_pub: bool = False, agent_thought_pub: bool = False): chain_pub: bool = False, agent_thought_pub: bool = False):
self._channel = PubHandler.generate_channel_name(user, task_id) self._channel = PubHandler.generate_channel_name(user, task_id)
...@@ -282,12 +280,12 @@ class PubHandler: ...@@ -282,12 +280,12 @@ class PubHandler:
self._agent_thought_pub = agent_thought_pub self._agent_thought_pub = agent_thought_pub
@classmethod @classmethod
def generate_channel_name(cls, user: Union[Account | User], task_id: str): def generate_channel_name(cls, user: Union[Account | EndUser], task_id: str):
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
return "generate_result:{}-{}".format(user_str, task_id) return "generate_result:{}-{}".format(user_str, task_id)
@classmethod @classmethod
def generate_stopped_cache_key(cls, user: Union[Account | User], task_id: str): def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
return "generate_result_stopped:{}-{}".format(user_str, task_id) return "generate_result_stopped:{}-{}".format(user_str, task_id)
...@@ -366,7 +364,7 @@ class PubHandler: ...@@ -366,7 +364,7 @@ class PubHandler:
redis_client.publish(self._channel, json.dumps(content)) redis_client.publish(self._channel, json.dumps(content))
@classmethod @classmethod
def pub_error(cls, user: Union[Account | User], task_id: str, e): def pub_error(cls, user: Union[Account | EndUser], task_id: str, e):
content = { content = {
'error': type(e).__name__, 'error': type(e).__name__,
'description': e.description if getattr(e, 'description', None) is not None else str(e) 'description': e.description if getattr(e, 'description', None) is not None else str(e)
...@@ -379,7 +377,7 @@ class PubHandler: ...@@ -379,7 +377,7 @@ class PubHandler:
return redis_client.get(self._stopped_cache_key) is not None return redis_client.get(self._stopped_cache_key) is not None
@classmethod @classmethod
def stop(cls, user: Union[Account | User], task_id: str): def stop(cls, user: Union[Account | EndUser], task_id: str):
stopped_cache_key = cls.generate_stopped_cache_key(user, task_id) stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
redis_client.setex(stopped_cache_key, 600, 1) redis_client.setex(stopped_cache_key, 600, 1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment