from collections.abc import Generator
from typing import Optional, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import LLMNodeData
from extensions.ext_database import db
from models.model import Conversation
from models.workflow import WorkflowNodeExecutionStatus


class LLMNode(BaseNode):
    """
    Workflow node that renders a prompt from the variable pool, streams it
    through a large language model and returns the generated text and usage.
    """
    _node_data_cls = LLMNodeData
    node_type = NodeType.LLM

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Run node
        :param variable_pool: variable pool
        :return: a SUCCEEDED result with ``text``/``usage`` outputs and token/price
                 metadata, or a FAILED result carrying the error message if any
                 step raises
        """
        node_data = cast(self._node_data_cls, self.node_data)

        # Stay None until populated so a failure still reports whatever
        # inputs / process data were collected before the exception.
        node_inputs = None
        process_data = None

        try:
            # fetch variables and fetch values from variable pool
            inputs = self._fetch_inputs(node_data, variable_pool)

            node_inputs = dict(inputs)

            # fetch files
            files: list[FileObj] = self._fetch_files(node_data, variable_pool)

            if files:
                node_inputs['#files#'] = [{
                    'type': file.type.value,
                    'transfer_method': file.transfer_method.value,
                    'url': file.url,
                    'upload_file_id': file.upload_file_id,
                } for file in files]

            # fetch context value
            context = self._fetch_context(node_data, variable_pool)

            if context:
                node_inputs['#context#'] = context

            # fetch model config
            model_instance, model_config = self._fetch_model_config(node_data)

            # fetch memory
            memory = self._fetch_memory(node_data, variable_pool, model_instance)

            # fetch prompt messages
            prompt_messages, stop = self._fetch_prompt_messages(
                node_data=node_data,
                inputs=inputs,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )

            process_data = {
                'model_mode': model_config.mode,
                'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                    model_mode=model_config.mode,
                    prompt_messages=prompt_messages
                )
            }

            # handle invoke result
            result_text, usage = self._invoke_llm(
                node_data=node_data,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop
            )
        except Exception as e:
            # Node boundary: any failure is reported as a FAILED run result
            # rather than propagated to the workflow engine.
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
                inputs=node_inputs,
                process_data=process_data
            )

        outputs = {
            'text': result_text,
            'usage': jsonable_encoder(usage)
        }

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=node_inputs,
            process_data=process_data,
            outputs=outputs,
            metadata={
                NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
                NodeRunMetadataKey.CURRENCY: usage.currency
            }
        )

    @staticmethod
    def _split_completion_params(completion_params: dict) -> tuple[dict, list[str]]:
        """
        Separate the ``stop`` sequences from the remaining completion parameters.

        The input dict is copied, never mutated, so the node's configuration is
        left intact for later reads.
        :param completion_params: completion parameters from the node config
        :return: (parameters without 'stop', stop value or [] when absent)
        """
        params = dict(completion_params)
        stop = params.pop('stop', [])
        return params, stop

    def _invoke_llm(self, node_data: LLMNodeData,
                    model_instance: ModelInstance,
                    prompt_messages: list[PromptMessage],
                    stop: list[str]) -> tuple[str, LLMUsage]:
        """
        Invoke large language model (streaming) and collect its full output
        :param node_data: node data
        :param model_instance: model instance
        :param prompt_messages: prompt messages
        :param stop: stop sequences, passed separately from the model parameters
        :return: (full generated text, usage)
        """
        # NOTE(review): presumably closes the session so no DB connection is
        # held during the potentially long streaming call — confirm.
        db.session.close()

        # 'stop' travels via the dedicated ``stop`` argument, so it is removed
        # from the model parameters here (on a copy, not on node_data).
        model_parameters, _ = self._split_completion_params(node_data.model.completion_params)

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            stop=stop,
            stream=True,
            user=self.user_id,
        )

        # handle invoke result
        return self._handle_invoke_result(
            invoke_result=invoke_result
        )

    def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
        """
        Consume the streamed invoke result, publishing each text chunk as it arrives
        :param invoke_result: generator of streamed result chunks
        :return: (concatenated text, first usage reported or an empty usage)
        """
        full_text = ''
        usage = None
        for result in invoke_result:
            text = result.delta.message.content
            full_text += text

            # forward each chunk immediately so callers can stream it
            self.publish_text_chunk(text=text)

            # keep the first usage the stream reports
            if not usage and result.delta.usage:
                usage = result.delta.usage

        if not usage:
            usage = LLMUsage.empty_usage()

        return full_text, usage

    def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]:
        """
        Fetch inputs
        :param node_data: node data
        :param variable_pool: variable pool
        :return: mapping of variable name to its value from the pool
        :raises ValueError: if a referenced variable is missing (None) in the pool
        """
        inputs = {}
        for variable_selector in node_data.variables:
            variable_value = variable_pool.get_variable_value(variable_selector.value_selector)
            if variable_value is None:
                raise ValueError(f'Variable {variable_selector.value_selector} not found')

            inputs[variable_selector.variable] = variable_value

        return inputs

    def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileObj]:
        """
        Fetch files
        :param node_data: node data
        :param variable_pool: variable pool
        :return: the ``sys.files`` value when vision is enabled, otherwise []
        """
        if not node_data.vision.enabled:
            return []

        files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
        if not files:
            return []

        return files

    def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
        """
        Fetch context
        :param node_data: node data
        :param variable_pool: variable pool
        :return: the context string, the newline-joined ``content`` of each list
                 item, or None when context is disabled, empty, or of another type
        :raises ValueError: if a list item is not a dict containing 'content'
        """
        if not node_data.context.enabled:
            return None

        context_value = variable_pool.get_variable_value(node_data.context.variable_selector)
        if not context_value:
            return None

        if isinstance(context_value, str):
            return context_value

        if isinstance(context_value, list):
            contents = []
            for item in context_value:
                # guard the type first: `'content' in item` on a str is a
                # substring test, and on other non-containers a TypeError
                if not isinstance(item, dict) or 'content' not in item:
                    raise ValueError(f'Invalid context structure: {item}')

                contents.append(item['content'])

            return '\n'.join(contents).strip()

        return None

    def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        """
        Fetch model config
        :param node_data: node data
        :return: (model instance, model config with credentials)
        :raises ValueError: if the model does not exist or no mode is configured
        :raises ProviderTokenNotInitError: if the model's credentials are not configured
        :raises ModelCurrentlyNotSupportError: if the model status is NO_PERMISSION
        :raises QuotaExceededError: if the provider quota is exceeded
        """
        model_name = node_data.model.name
        provider_name = node_data.model.provider

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id,
            model_type=ModelType.LLM,
            provider=provider_name,
            model=model_name
        )

        provider_model_bundle = model_instance.provider_model_bundle
        model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)

        model_credentials = model_instance.credentials

        # check model
        provider_model = provider_model_bundle.configuration.get_provider_model(
            model=model_name,
            model_type=ModelType.LLM
        )

        if provider_model is None:
            raise ValueError(f"Model {model_name} not exist.")

        if provider_model.status == ModelStatus.NO_CONFIGURE:
            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
        elif provider_model.status == ModelStatus.NO_PERMISSION:
            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

        # model config: split 'stop' out on a copy so node_data is not mutated
        completion_params, stop = self._split_completion_params(node_data.model.completion_params)

        # get model mode
        model_mode = node_data.model.mode
        if not model_mode:
            raise ValueError("LLM mode is required.")

        model_schema = model_type_instance.get_model_schema(
            model_name,
            model_credentials
        )

        if not model_schema:
            raise ValueError(f"Model {model_name} not exist.")

        return model_instance, ModelConfigWithCredentialsEntity(
            provider=provider_name,
            model=model_name,
            model_schema=model_schema,
            mode=model_mode,
            provider_model_bundle=provider_model_bundle,
            credentials=model_credentials,
            parameters=completion_params,
            stop=stop,
        )

    def _fetch_memory(self, node_data: LLMNodeData,
                      variable_pool: VariablePool,
                      model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
        """
        Fetch memory
        :param node_data: node data
        :param variable_pool: variable pool
        :param model_instance: model instance used by the token buffer memory
        :return: a TokenBufferMemory over the current conversation, or None when
                 memory is disabled or no matching conversation is found
        """
        if not node_data.memory:
            return None

        # get conversation id
        # Use the enum's .value, like every other sys lookup in this node.
        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
        if conversation_id is None:
            return None

        # get conversation (scoped to this tenant and app)
        conversation = db.session.query(Conversation).filter(
            Conversation.tenant_id == self.tenant_id,
            Conversation.app_id == self.app_id,
            Conversation.id == conversation_id
        ).first()

        if not conversation:
            return None

        return TokenBufferMemory(
            conversation=conversation,
            model_instance=model_instance
        )

    def _fetch_prompt_messages(self, node_data: LLMNodeData,
                               inputs: dict[str, str],
                               files: list[FileObj],
                               context: Optional[str],
                               memory: Optional[TokenBufferMemory],
                               model_config: ModelConfigWithCredentialsEntity) \
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Fetch prompt messages
        :param node_data: node data
        :param inputs: inputs
        :param files: files
        :param context: context
        :param memory: memory
        :param model_config: model config
        :return: (prompt messages, stop sequences taken from the model config)
        """
        prompt_transform = AdvancedPromptTransform()
        prompt_messages = prompt_transform.get_prompt(
            prompt_template=node_data.prompt_template,
            inputs=inputs,
            query='',
            files=files,
            context=context,
            memory_config=node_data.memory,
            memory=memory,
            model_config=model_config
        )
        stop = model_config.stop

        return prompt_messages, stop

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: mapping of variable name (plus '#context#' / '#files#' when
                 enabled) to its value selector
        """
        node_data = cast(cls._node_data_cls, node_data)

        variable_mapping = {}
        for variable_selector in node_data.variables:
            variable_mapping[variable_selector.variable] = variable_selector.value_selector

        if node_data.context.enabled:
            variable_mapping['#context#'] = node_data.context.variable_selector

        if node_data.vision.enabled:
            variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]

        return variable_mapping

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """
        Get default config of node.
        :param filters: filter by node config parameters (currently unused).
        :return: default node config with chat- and completion-model prompt templates
        """
        return {
            "type": "llm",
            "config": {
                "prompt_templates": {
                    "chat_model": {
                        "prompts": [
                            {
                                "role": "system",
                                "text": "You are a helpful AI assistant."
                            }
                        ]
                    },
                    "completion_model": {
                        "conversation_histories_role": {
                            "user_prefix": "Human",
                            "assistant_prefix": "Assistant"
                        },
                        "prompt": {
                            "text": "Here is the chat histories between human and assistant, inside "
                                    "<histories></histories> XML tags.\n\n<histories>\n{{"
                                    "#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant:"
                        },
                        "stop": ["Human:"]
                    }
                }
            }
        }
