#!/usr/bin/env python3


# SAT: a jabber client
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import dataclasses
from functools import partial
from pathlib import Path
from twisted.spread import jelly, pb
from twisted.internet import reactor
from libervia.backend.core.log import getLogger
from libervia.backend.tools import config

# module-level logger for the PB bridge
log = getLogger(__name__)


## jelly hack
# we monkey patch jelly to handle namedtuple
ori_jelly = jelly._Jellier.jelly


def fixed_jelly(self, obj):
    """Jelly ``obj``, turning tuple subclasses (e.g. namedtuple) into plain tuples.

    This replaces ``jelly._Jellier.jelly`` so that namedtuple instances can be
    serialised: they are converted to a regular ``tuple`` before being handed to
    the original implementation. Every other object is passed through unchanged.
    """
    # only tuple *subclasses* need the conversion; an exact tuple is already fine
    # (the previous ``not obj is tuple`` test was always true, since an instance
    # is never the ``tuple`` type itself)
    if isinstance(obj, tuple) and type(obj) is not tuple:
        obj = tuple(obj)
    return ori_jelly(self, obj)


jelly._Jellier.jelly = fixed_jelly


@dataclasses.dataclass(eq=False)
class HandlerWrapper:
    """Container for a frontend signals handler, so it can be stored in a list.

    With ``eq=False`` no ``__eq__`` is generated and ``__hash__`` is left
    untouched, so instances keep ``object``'s identity-based equality and hash:
    ``list.remove(wrapper)`` removes that exact wrapper instance.
    """

    # we use a wrapper to keep signals handlers because RemoteReference doesn't
    # support comparison (other than equality), making it unusable with a list
    handler: pb.RemoteReference


class PBRoot(pb.Root):
    """Root object published to frontends through Perspective Broker.

    It keeps the frontends' signals handlers and forwards backend signals to
    them. Bridge methods are attached to it dynamically as ``remote_<name>``
    attributes (see ``bridge.register_method``).
    """

    def __init__(self):
        # HandlerWrapper instances of every frontend currently receiving signals
        self.signals_handlers = []

    def remote_init_bridge(self, signals_handler):
        """Register a frontend signals handler (exposed to frontends via PB).

        @param signals_handler: remote reference on which signals will be called
        """
        self.signals_handlers.append(HandlerWrapper(signals_handler))
        log.info("registered signal handler")

    def send_signal_eb(self, failure_, signal_name):
        """Log a failure to deliver a signal, ignoring lost connections.

        @param failure_: failure from the remote call
        @param signal_name: name of the signal which could not be sent
        """
        if not failure_.check(pb.PBConnectionLost):
            log.error(
                f"Error while sending signal {signal_name}: {failure_}",
            )

    def send_signal(self, name, args, kwargs):
        """Call signal ``name`` on every registered handler.

        Handlers whose reference is dead are unregistered.

        @param name: name of the signal (remote method to call)
        @param args: positional arguments of the signal
        @param kwargs: keyword arguments of the signal
        """
        to_remove = []
        for wrapper in self.signals_handlers:
            handler = wrapper.handler
            try:
                d = handler.callRemote(name, *args, **kwargs)
            except pb.DeadReferenceError:
                to_remove.append(wrapper)
            else:
                d.addErrback(self.send_signal_eb, name)
        # removal is done after the loop to avoid mutating the list while
        # iterating it
        for wrapper in to_remove:
            log.debug("Removing signal handler for dead frontend")
            self.signals_handlers.remove(wrapper)

    def _bridge_deactivate_signals(self):
        """Stop sending signals, putting current handlers aside.

        If signals are already deactivated, handlers registered since then are
        paused too.
        """
        if hasattr(self, "signals_paused"):
            log.warning("bridge signals already deactivated")
            if self.signals_handlers:
                self.signals_paused.extend(self.signals_handlers)
        else:
            self.signals_paused = self.signals_handlers
        self.signals_handlers = []
        log.debug("bridge signals have been deactivated")

    def _bridge_reactivate_signals(self):
        """Send signals again to the handlers paused by _bridge_deactivate_signals.

        Handlers registered while signals were paused are kept.
        """
        try:
            paused = self.signals_paused
        except AttributeError:
            log.debug("signals were already activated")
        else:
            # merge instead of overwriting, otherwise frontends which registered
            # during the pause would silently stop receiving signals
            self.signals_handlers = paused + self.signals_handlers
            del self.signals_paused
            log.debug("bridge signals have been reactivated")

##METHODS_PART##


class bridge(object):
    """Backend side of the Perspective Broker (PB) bridge.

    Listens for frontends (UNIX socket or TCP, according to the ``bridge_pb``
    configuration section), exposes registered methods on a PBRoot instance and
    forwards signals to the registered frontends.
    """

    def __init__(self):
        """Create the PB root and start listening.

        @raise ValueError: the configured connection_type is neither
            "unix_socket" nor "socket"
        """
        log.info("Init Perspective Broker...")
        self.root = PBRoot()
        conf = config.parse_main_conf()
        get_conf = partial(config.get_conf, conf, "bridge_pb", "")
        conn_type = get_conf("connection_type", "unix_socket")
        if conn_type == "unix_socket":
            local_dir = Path(config.config_get(conf, "", "local_dir")).resolve()
            socket_path = local_dir / "bridge_pb"
            log.info(f"using UNIX Socket at {socket_path}")
            # socket file permissions restricted to its owner
            reactor.listenUNIX(
                str(socket_path), pb.PBServerFactory(self.root), mode=0o600
            )
        elif conn_type == "socket":
            port = int(get_conf("port", 8789))
            log.info(f"using TCP Socket at port {port}")
            reactor.listenTCP(port, pb.PBServerFactory(self.root))
        else:
            raise ValueError(f"Unknown pb connection type: {conn_type!r}")

    def send_signal(self, name, *args, **kwargs):
        """Send signal ``name`` with given arguments to all registered frontends."""
        self.root.send_signal(name, args, kwargs)

    def remote_init_bridge(self, signals_handler):
        """Register a frontend signals handler.

        The handlers list is kept by ``self.root`` (``bridge`` has no
        ``signals_handlers`` attribute), so registration is delegated to it.

        @param signals_handler: remote reference on which signals will be called
        """
        self.root.remote_init_bridge(signals_handler)

    def register_method(self, name, callback):
        """Expose ``callback`` to frontends as PB method ``name``."""
        log.debug(f"registering PB bridge method [{name}]")
        setattr(self.root, "remote_" + name, callback)

    def add_method(
            self, name, int_suffix, in_sign, out_sign, method, async_=False, doc=None
    ):
        """Dynamically add a method to PB bridge

        Only ``name`` and ``method`` are used; other arguments are accepted for
        compatibility with other bridges.
        """
        # FIXME: doc parameter is kept only temporary, the time to remove it from calls
        log.debug(f"Adding method {name} to PB bridge")
        self.register_method(name, method)

    def add_signal(self, name, int_suffix, signature, doc=None):
        """Dynamically add a signal: ``self.<name>(...)`` sends signal ``name``.

        Only ``name`` is used; other arguments are accepted for compatibility
        with other bridges.
        """
        log.debug(f"Adding signal {name} to PB bridge")
        setattr(self, name, partial(self.send_signal, name))

    def bridge_deactivate_signals(self):
        """Stop sending signals to bridge

        Mainly used for mobile frontends, when the frontend is paused
        """
        self.root._bridge_deactivate_signals()

    def bridge_reactivate_signals(self):
        """Send again signals to bridge

        Should only be used after bridge_deactivate_signals has been called
        """
        self.root._bridge_reactivate_signals()

    # signals

    def _debug(self, action, params, profile):
        self.send_signal("_debug", action, params, profile)

    def action_new(self, action_data, id, security_limit, profile):
        self.send_signal("action_new", action_data, id, security_limit, profile)

    def connected(self, jid_s, profile):
        self.send_signal("connected", jid_s, profile)

    def contact_deleted(self, entity_jid, profile):
        self.send_signal("contact_deleted", entity_jid, profile)

    def contact_new(self, contact_jid, attributes, groups, profile):
        self.send_signal("contact_new", contact_jid, attributes, groups, profile)

    def disconnected(self, profile):
        self.send_signal("disconnected", profile)

    def entity_data_updated(self, jid, name, value, profile):
        self.send_signal("entity_data_updated", jid, name, value, profile)

    def message_encryption_started(self, to_jid, encryption_data, profile_key):
        self.send_signal(
            "message_encryption_started", to_jid, encryption_data, profile_key
        )

    def message_encryption_stopped(self, to_jid, encryption_data, profile_key):
        self.send_signal(
            "message_encryption_stopped", to_jid, encryption_data, profile_key
        )

    def message_new(
        self, uid, timestamp, from_jid, to_jid, message, subject, mess_type, extra,
        profile
    ):
        self.send_signal(
            "message_new", uid, timestamp, from_jid, to_jid, message, subject,
            mess_type, extra, profile
        )

    def message_update(self, uid, message_type, message_data, profile):
        self.send_signal("message_update", uid, message_type, message_data, profile)

    def notification_deleted(self, id, profile):
        self.send_signal("notification_deleted", id, profile)

    def notification_new(
        self, id, timestamp, type, body_plain, body_rich, title, requires_action,
        priority, expire_at, extra, profile
    ):
        self.send_signal(
            "notification_new", id, timestamp, type, body_plain, body_rich, title,
            requires_action, priority, expire_at, extra, profile
        )

    def param_update(self, name, value, category, profile):
        self.send_signal("param_update", name, value, category, profile)

    def presence_update(self, entity_jid, show, priority, statuses, profile):
        self.send_signal(
            "presence_update", entity_jid, show, priority, statuses, profile
        )

    def progress_error(self, id, error, profile):
        self.send_signal("progress_error", id, error, profile)

    def progress_finished(self, id, metadata, profile):
        self.send_signal("progress_finished", id, metadata, profile)

    def progress_started(self, id, metadata, profile):
        self.send_signal("progress_started", id, metadata, profile)

    def subscribe(self, sub_type, entity_jid, profile):
        self.send_signal("subscribe", sub_type, entity_jid, profile)
