from typing import Optional

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
    QueueNodeFailedEvent,
    QueueNodeStartedEvent,
    QueueNodeSucceededEvent,
    QueueTextChunkEvent,
    QueueWorkflowFailedEvent,
    QueueWorkflowStartedEvent,
    QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow


class WorkflowEventTriggerCallback(BaseWorkflowCallback):
    """
    Workflow callback that forwards workflow and node lifecycle events to the
    app queue manager.

    Every event is published with ``PublishFrom.APPLICATION_MANAGER``. Text
    chunks are only forwarded for nodes computed as streamable at construction
    time (see ``_fetch_streamable_node_ids``).
    """

    def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
        """
        :param queue_manager: queue manager that receives every published event
        :param workflow: workflow whose ``graph_dict`` determines which nodes may
            stream text chunks
        """
        self._queue_manager = queue_manager
        self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)

    def _publish_event(self, event) -> None:
        """
        Publish ``event`` to the queue manager on behalf of the application manager.

        :param event: queue event instance to publish
        """
        self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

    def on_workflow_run_started(self) -> None:
        """
        Workflow run started: publish a ``QueueWorkflowStartedEvent``.
        """
        self._publish_event(QueueWorkflowStartedEvent())

    def on_workflow_run_succeeded(self) -> None:
        """
        Workflow run succeeded: publish a ``QueueWorkflowSucceededEvent``.
        """
        self._publish_event(QueueWorkflowSucceededEvent())

    def on_workflow_run_failed(self, error: str) -> None:
        """
        Workflow run failed: publish a ``QueueWorkflowFailedEvent``.

        :param error: error message describing the failure
        """
        self._publish_event(
            QueueWorkflowFailedEvent(
                error=error
            )
        )

    def on_workflow_node_execute_started(self, node_id: str,
                                         node_type: NodeType,
                                         node_data: BaseNodeData,
                                         node_run_index: int = 1,
                                         predecessor_node_id: Optional[str] = None) -> None:
        """
        Workflow node execute started: publish a ``QueueNodeStartedEvent``.

        :param node_id: id of the node that started
        :param node_type: type of the node
        :param node_data: the node's configuration data
        :param node_run_index: run index passed through to the event (default 1)
        :param predecessor_node_id: id of the preceding node, if any
        """
        self._publish_event(
            QueueNodeStartedEvent(
                node_id=node_id,
                node_type=node_type,
                node_data=node_data,
                node_run_index=node_run_index,
                predecessor_node_id=predecessor_node_id
            )
        )

    def on_workflow_node_execute_succeeded(self, node_id: str,
                                           node_type: NodeType,
                                           node_data: BaseNodeData,
                                           inputs: Optional[dict] = None,
                                           process_data: Optional[dict] = None,
                                           outputs: Optional[dict] = None,
                                           execution_metadata: Optional[dict] = None) -> None:
        """
        Workflow node execute succeeded: publish a ``QueueNodeSucceededEvent``.

        :param node_id: id of the node that succeeded
        :param node_type: type of the node
        :param node_data: the node's configuration data
        :param inputs: node inputs, if any
        :param process_data: intermediate process data, if any
        :param outputs: node outputs, if any
        :param execution_metadata: execution metadata, if any
        """
        self._publish_event(
            QueueNodeSucceededEvent(
                node_id=node_id,
                node_type=node_type,
                node_data=node_data,
                inputs=inputs,
                process_data=process_data,
                outputs=outputs,
                execution_metadata=execution_metadata
            )
        )

    def on_workflow_node_execute_failed(self, node_id: str,
                                        node_type: NodeType,
                                        node_data: BaseNodeData,
                                        error: str) -> None:
        """
        Workflow node execute failed: publish a ``QueueNodeFailedEvent``.

        :param node_id: id of the node that failed
        :param node_type: type of the node
        :param node_data: the node's configuration data
        :param error: error message describing the failure
        """
        self._publish_event(
            QueueNodeFailedEvent(
                node_id=node_id,
                node_type=node_type,
                node_data=node_data,
                error=error
            )
        )

    def on_node_text_chunk(self, node_id: str, text: str) -> None:
        """
        Publish a text chunk, but only if ``node_id`` is a streamable node.

        Chunks from any other node are silently dropped.

        :param node_id: id of the node that produced the chunk
        :param text: the text chunk
        """
        if node_id in self._streamable_node_ids:
            self._publish_event(
                QueueTextChunkEvent(
                    text=text
                )
            )

    def _fetch_streamable_node_ids(self, graph: Optional[dict]) -> list[str]:
        """
        Fetch the ids of nodes whose text chunks may be streamed.

        A node is considered streamable when the graph contains an edge from it
        directly into an END node whose ``data.outputs.type`` is ``'plain-text'``.
        The source node's own type is not inspected here.

        NOTE(review): an earlier docstring said only LLM / Direct Answer nodes
        (chat apps) or LLM nodes (workflow apps) may stream. This function
        filters by neither node type nor app mode. Confirm whether that
        restriction is enforced elsewhere.

        :param graph: workflow graph dict, expected to contain ``nodes`` and
            ``edges`` lists. A falsy graph or missing/null lists yield no
            streamable nodes instead of raising.
        :return: source node ids of edges that target plain-text END nodes
        """
        if not graph:
            return []

        # Set for O(1) membership checks when scanning the edges below.
        end_node_ids = set()
        for node_config in graph.get('nodes') or []:
            # `or {}` also guards against explicit JSON nulls, which `.get(k, {})` would not.
            node_data = node_config.get('data') or {}
            if node_data.get('type') != NodeType.END.value:
                continue
            outputs = node_data.get('outputs') or {}
            if outputs.get('type', '') == 'plain-text':
                end_node_ids.add(node_config.get('id'))

        return [
            edge_config.get('source')
            for edge_config in graph.get('edges') or []
            if edge_config.get('target') in end_node_ids
        ]
