# Source code for caspia.meadow.client.connection.connection

# pylint: disable=too-many-instance-attributes
import asyncio
import collections
import json
import logging
import re
import socket
from functools import partial

import paho.mqtt.client as mqtt

from caspia.meadow import errors
from caspia.toolbox import monitor

# Logger shared by everything in this module.
logger = logging.getLogger(name=__name__)


class Subscription:
    """Book-keeping record for a single MQTT topic filter.

    Holds the local subscribers interested in the filter together with the
    state of the broker-side subscription, so the owning connection can decide
    whether it has to subscribe or unsubscribe.
    """

    def __init__(self):
        # Local subscriber objects attached to this topic filter.
        self.subscribers = set()
        # Flipped by the owning connection once a (un)subscribe request completes.
        self.subscribed = False
        # Mirrors ``subscribed`` so callers can await readiness.
        self.subscribed_evt = asyncio.Event()
        # In-flight subscribe / unsubscribe tasks, or None when idle.
        self.subscribe_task = None
        self.unsubscribe_task = None

    @property
    def should_subscribe(self):
        """True when there are subscribers but no subscription exists or is pending."""
        return bool(self.subscribers) and not self.subscribed and not self.subscribe_task

    @property
    def should_unsubscribe(self):
        """True when nobody listens any more but a subscription exists and no unsubscribe is pending."""
        return not self.subscribers and self.subscribed and not self.unsubscribe_task
class Connection:
    """Asyncio front-end for a paho-mqtt client.

    paho callbacks are re-dispatched onto ``self.loop`` with
    ``call_soon_threadsafe`` and the client's socket I/O is hooked into the
    loop by :class:`AsyncioHelper`.  Topic subscriptions are shared between
    local subscribers through :class:`Subscription` records and are
    (re)established by :meth:`run_forever`.
    """

    def __init__(self, broker_url, name=None, loop=None):
        """Create a client for ``broker_url`` without connecting yet.

        :param broker_url: ``host`` or ``host:port``, optionally prefixed with
            ``mqtt://``; the port defaults to 1883.
        :param name: MQTT client id, also used in log messages.
        :param loop: event loop used for callbacks and tasks; defaults to
            ``asyncio.get_event_loop()``.
        """
        if broker_url.startswith('mqtt://'):
            broker_url = broker_url[len('mqtt://'):]
        self.broker_url = broker_url
        self.name = name
        self.loop = loop or asyncio.get_event_loop()
        # topic filter -> Subscription, created on first access.
        self._subscriptions = collections.defaultdict(Subscription)
        # Messages on topics matching these filters are dropped when their
        # payload equals the previous payload received on the same topic.
        self._prevent_duplicates = {'gateway/#'}
        self._topic_last_hash = dict()
        self._connect_called = False
        self._client: mqtt.Client = self._create_client()
        # The ``loop=`` argument of these primitives was removed in Python 3.10.
        # NOTE(review): on Python < 3.10 they bind to asyncio.get_event_loop()
        # at construction time, so ``loop`` should be the current event loop.
        self._connect_lock = asyncio.Lock()
        self._connected = asyncio.Event()
        self._connected.clear()
        self._disconnected = asyncio.Event()
        self._disconnected.set()
        # mid -> Future resolved by on_publish / on_subscribe / on_unsubscribe.
        self._message_futures = dict()
        # Caps the number of unacknowledged publish/subscribe/unsubscribe requests.
        self._inflight_semaphore = asyncio.Semaphore(20)

    @property
    def connected(self):
        """True while the broker connection is established."""
        return self._connected.is_set()

    def _create_client(self):
        """Build the paho client and route all its callbacks onto ``self.loop``.

        Callbacks may fire on a non-loop thread (``connect`` runs in an
        executor), hence ``call_soon_threadsafe``.  Also creates the
        :class:`AsyncioHelper` that drives the client's socket I/O.
        """
        client = mqtt.Client(self.name, clean_session=True)
        client.on_connect = partial(self.loop.call_soon_threadsafe, self.on_connect)
        client.on_disconnect = partial(self.loop.call_soon_threadsafe, self.on_disconnect)
        client.on_message = partial(self.loop.call_soon_threadsafe, self.on_message)
        client.on_publish = partial(self.loop.call_soon_threadsafe, self.on_publish)
        client.on_subscribe = partial(self.loop.call_soon_threadsafe, self.on_subscribe)
        client.on_unsubscribe = partial(self.loop.call_soon_threadsafe, self.on_unsubscribe)
        client.enable_logger(logger)
        self._helper = AsyncioHelper(self.loop, client, self.name)
        return client

    def on_connect(self, client, userdata, flags, rc):
        """Handle the broker's CONNACK (runs on ``self.loop``).

        Fails all pending request futures, then updates the connected /
        disconnected events according to ``rc``.
        """
        self._cancel_all_message_futures()
        if rc == mqtt.MQTT_ERR_SUCCESS:
            self._connected.set()
            self._disconnected.clear()
            logger.info('Client %s successfully connected to %s', self.name, self.broker_url)
        else:
            self._connected.clear()
            self._disconnected.set()
            # ``rc`` is a CONNACK return code, not an MQTT_ERR_* value, so it
            # must be rendered with connack_string (error_string maps the same
            # numbers to unrelated errors).
            reason = mqtt.connack_string(rc)
            logger.error('Client %s failed to connect to %s, reason: %r', self.name, self.broker_url,
                         reason)

    def on_disconnect(self, client, userdata, rc):
        """Mark the connection as down and fail all pending request futures."""
        self._connected.clear()
        self._disconnected.set()
        reason = mqtt.error_string(rc)
        args = 'disconnected from %s [%s]', self.broker_url, reason
        if rc == mqtt.MQTT_ERR_SUCCESS:
            logger.info(*args)
        else:
            logger.error(*args)
        self._cancel_all_message_futures()

    def _cancel_all_message_futures(self, error_factory=None):
        """Fail every future still waiting for a broker acknowledgement.

        :param error_factory: zero-argument callable building the exception set
            on each future; defaults to a ``CommunicationError`` saying the
            connection was interrupted.  A fresh exception is made per future.
        """
        if error_factory is None:
            def error_factory():
                return errors.CommunicationError('connection interrupted')
        while self._message_futures:
            _, future = self._message_futures.popitem()
            # A timed-out wait_for cancels its future before removing it from
            # the dict; touching a finished future raises InvalidStateError.
            if not future.done():
                future.set_exception(error_factory())

    def _resolve_message_future(self, mid):
        """Complete the future awaiting acknowledgement of message ``mid``, if any."""
        future = self._message_futures.pop(mid, None)
        if future is not None and not future.done():
            future.set_result(None)

    def on_message(self, client, userdata, message):
        """Schedule processing of an incoming message on ``self.loop``."""
        asyncio.ensure_future(self._process_message(message), loop=self.loop)

    def on_publish(self, client, userdata, mid):
        """Resolve the future of the acknowledged publish ``mid``."""
        self._resolve_message_future(mid)

    def on_subscribe(self, client, userdata, mid, granted_qos):
        """Resolve the future of the acknowledged subscribe ``mid``."""
        self._resolve_message_future(mid)

    def on_unsubscribe(self, client, userdata, mid):
        """Resolve the future of the acknowledged unsubscribe ``mid``."""
        self._resolve_message_future(mid)

    async def _do_connect(self):
        """Connect the mqtt client, or just wait if it is already connected.

        The blocking paho ``connect`` call runs in the default executor; the
        coroutine then waits until ``on_connect`` reports success.
        """
        def _connect():
            host, sep, port = self.broker_url.rpartition(':')
            if not sep:
                host, port = self.broker_url, 1883
            self._client.connect(host, int(port))
            self._client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)

        # ``with await lock`` was removed in Python 3.9.
        async with self._connect_lock:
            if self._disconnected.is_set():
                await self.loop.run_in_executor(None, _connect)
            # NOTE(review): a refused CONNACK leaves ``_connected`` unset, so
            # this wait may block indefinitely; consider adding a timeout.
            await self._connected.wait()

    async def _do_subscribe(self, topic, qos=2):
        """Subscribe to a mqtt topic and wait for the broker's SUBACK."""
        async with self._inflight_semaphore:
            result, mid = self._client.subscribe(topic, qos=qos)
            if result == mqtt.MQTT_ERR_SUCCESS:
                await self._wait_for_completion(mid)
            else:
                raise errors.CommunicationError.from_mqtt(result)

    async def _do_unsubscribe(self, topic):
        """Unsubscribe from a mqtt topic and wait for the broker's UNSUBACK."""
        async with self._inflight_semaphore:
            result, mid = self._client.unsubscribe(topic)
            if result == mqtt.MQTT_ERR_SUCCESS:
                await self._wait_for_completion(mid)
            else:
                raise errors.CommunicationError.from_mqtt(result)

    async def _publish(self, topic, message, retain=False):
        """JSON-encode ``message`` and publish it with QoS 2, waiting for completion."""
        await self.ensure_connected()
        async with self._inflight_semaphore:
            logger.debug('Publishing %s to %s', message, topic)
            payload_data = json.dumps(message).encode('utf-8')
            result, mid = self._client.publish(topic, payload_data, qos=2, retain=retain)
            if result == mqtt.MQTT_ERR_SUCCESS:
                await self._wait_for_completion(mid)
            else:
                raise errors.CommunicationError.from_mqtt(result)

    async def _wait_for_completion(self, mid, timeout=5.0):
        """Wait (at most ``timeout`` seconds) until message ``mid`` is acknowledged."""
        future = self.loop.create_future()
        self._message_futures[mid] = future
        try:
            return await asyncio.wait_for(future, timeout)
        finally:
            self._message_futures.pop(mid, None)

    def _maintain_topic_subscription_status(self, topic: str, subs: 'Subscription'):
        """Start a subscribe or unsubscribe task for ``topic`` if ``subs`` needs one.

        The completion callbacks update ``subs``.  A callback whose task is no
        longer the one stored on ``subs`` (run_forever resets the slots after
        a disconnect) is ignored so it cannot overwrite newer state.
        """
        if subs.should_subscribe:
            subs.subscribe_task = asyncio.ensure_future(self._do_subscribe(topic))

            def on_subscribe_done(task):
                # Retrieve the exception even for superseded tasks so asyncio
                # does not warn about it never being retrieved.
                error = None if task.cancelled() else task.exception()
                if subs.subscribe_task is not task:
                    return
                subs.subscribe_task = None
                if task.cancelled():
                    return
                if error is not None:
                    logger.error('Client %s failed to subscribe to %s', self.name, topic,
                                 exc_info=error)
                else:
                    subs.subscribed = True
                    subs.subscribed_evt.set()

            subs.subscribe_task.add_done_callback(on_subscribe_done)
        elif subs.should_unsubscribe:
            subs.unsubscribe_task = asyncio.ensure_future(self._do_unsubscribe(topic))

            def on_unsubscribe_done(task):
                error = None if task.cancelled() else task.exception()
                if subs.unsubscribe_task is not task:
                    return
                subs.unsubscribe_task = None
                if task.cancelled():
                    return
                if error is not None:
                    logger.error('Client %s failed to unsubscribe from %s', self.name, topic,
                                 exc_info=error)
                    return
                subs.subscribed = False
                subs.subscribed_evt.clear()
                # Drop the record unless a subscriber re-attached meanwhile
                # (and only if it is still the record registered for topic).
                if not subs.should_subscribe and self._subscriptions.get(topic) is subs:
                    del self._subscriptions[topic]

            subs.unsubscribe_task.add_done_callback(on_unsubscribe_done)

    def _ignore_as_duplicate(self, topic, payload_str):
        """Return True if ``payload_str`` repeats the last payload seen on ``topic``.

        Only topics matching a filter in ``_prevent_duplicates`` are tracked;
        for those the hash of the latest payload is remembered per topic.
        """
        if not any(self._matches(topic, a_filter) for a_filter in self._prevent_duplicates):
            return False
        current_hash = hash(payload_str)
        if self._topic_last_hash.get(topic) == current_hash:
            return True
        self._topic_last_hash[topic] = current_hash
        return False

    async def _process_message(self, message: 'mqtt.MQTTMessage'):
        """Process a single mqtt message; scheduled by on_message.

        Decodes the JSON payload, drops duplicates, and awaits every
        subscriber whose topic filter matches.  Malformed payloads and
        subscriber failures are logged instead of escaping the task.
        """
        topic = message.topic
        logger.debug("Received %s => %s bytes", topic, len(message.payload))
        try:
            payload_str = message.payload.decode('utf-8')
            payload = json.loads(payload_str)
        except ValueError:
            # UnicodeDecodeError and json.JSONDecodeError are both ValueErrors.
            logger.exception('Client %s received malformed payload on %s', self.name, topic)
            return

        if self._ignore_as_duplicate(topic, payload_str):
            logger.debug('Ignoring as duplicate')
            return

        if topic.startswith('notification/'):
            monitor.record_metric('meadow-connection:notifications-in', 1)

        matching = set()
        for a_filter, subscription in self._subscriptions.items():
            if self._matches(topic, a_filter):
                matching |= subscription.subscribers
        if not matching:
            return

        # Fixed order so results can be paired with their subscribers.
        subscribers = list(matching)
        results = await asyncio.gather(*(s(payload, topic) for s in subscribers),
                                       return_exceptions=True)
        for subscriber, result in zip(subscribers, results):
            if isinstance(result, Exception):
                logger.error('Subscriber %r failed to handle message on %s', subscriber, topic,
                             exc_info=result)

    async def ensure_connected(self):
        """Connect to the broker unless already connected."""
        if not self.connected:
            await self._do_connect()

    def _reset_subscriptions(self):
        """Forget broker-side subscription state after the connection dropped.

        Pending (un)subscribe tasks are cancelled and detached, every record is
        marked unsubscribed, and records without subscribers are removed.
        """
        for topic, subscription in list(self._subscriptions.items()):
            if subscription.subscribe_task:
                subscription.subscribe_task.cancel()
                subscription.subscribe_task = None
            if subscription.unsubscribe_task:
                subscription.unsubscribe_task.cancel()
                subscription.unsubscribe_task = None
            subscription.subscribed = False
            subscription.subscribed_evt.clear()
            if not subscription.subscribers:
                del self._subscriptions[topic]

    async def run_forever(self):
        """Never return. Main runloop of the client managing its connection.

        Whenever the client is disconnected it resets subscription state,
        fails pending requests, reconnects (retrying every 3 s on error) and
        restarts the subscription monitor.  Cancelling the task stops the
        monitor and disconnects the client; the cancellation itself is
        swallowed.
        """
        async def monitor_subscriptions(after_connect=True):
            # Reconcile wanted vs. actual subscriptions once per second. Right
            # after a connect everything is expected to be missing; later
            # discrepancies are unexpected and logged as errors.
            while True:
                issues = 0
                for topic, subscription in self._subscriptions.items():
                    if subscription.should_subscribe:
                        if not after_connect:
                            logger.error('Client %s is not subscribed to %s (but should be!).',
                                         self.name, topic)
                        self._maintain_topic_subscription_status(topic, subscription)
                        issues += 1
                    elif subscription.should_unsubscribe:
                        if not after_connect:
                            logger.error('Client %s is subscribed to %s (but should not be!).',
                                         self.name, topic)
                        self._maintain_topic_subscription_status(topic, subscription)
                        issues += 1
                if not issues:
                    logger.debug('Client %s has all subscriptions OK', self.name)
                await asyncio.sleep(1.0)
                after_connect = False

        subscriptions_monitor = None
        try:
            while True:
                await self._disconnected.wait()
                # clean_session=True: nothing survives a reconnect.
                self._reset_subscriptions()
                if subscriptions_monitor and not subscriptions_monitor.done():
                    subscriptions_monitor.cancel()
                self._cancel_all_message_futures(
                    lambda: errors.CommunicationError.from_mqtt(mqtt.MQTT_ERR_CONN_LOST))
                logger.debug('Client %s is not connected. Conneting ...', self.name)
                try:
                    await self._do_connect()
                except Exception as e:  # pylint: disable=broad-except
                    logger.error('Client %s failed to connect: %s', self.name, repr(e))
                    await asyncio.sleep(3)
                else:
                    logger.debug('Client %s connected.', self.name)
                    subscriptions_monitor = asyncio.ensure_future(monitor_subscriptions())
        except asyncio.CancelledError:
            if subscriptions_monitor and not subscriptions_monitor.done():
                subscriptions_monitor.cancel()
            self._client.disconnect()

    def attach_subscriber(self, subscriber):
        """Attach the subscriber and return immediately.

        The subscriber will be ready sometime in future. To wait for that,
        use `await self.subscriber_ready(s)`.
        """
        subscriber.client = self
        topic = subscriber.topic
        logger.debug('Attaching subscriber for %s', topic)
        subscription = self._subscriptions[topic]
        subscription.subscribers.add(subscriber)
        if self.connected:
            self._maintain_topic_subscription_status(topic, subscription)

    async def subscriber_ready(self, subscriber):
        """Await the `subscriber` to be ready (its topic subscribed at the broker)."""
        await self.ensure_connected()
        subscription = self._subscriptions[subscriber.topic]
        await subscription.subscribed_evt.wait()

    def detach_subscriber(self, subscriber):
        """Detach the subscriber (with immediate effect).

        Raises KeyError if ``subscriber`` was not attached.
        """
        topic = subscriber.topic
        logger.debug('Detaching subscriber for %s', topic)
        subscription = self._subscriptions[topic]
        subscription.subscribers.remove(subscriber)
        if self.connected:
            self._maintain_topic_subscription_status(topic, subscription)

    # Topics

    def _notification_topic(self, service, characteristic):
        return '/'.join(['notification', service, characteristic])

    def _response_topic(self, service, characteristic, req_id):
        return '/'.join(['response', service, characteristic, req_id])

    def _read_topic(self, service, characteristic):
        return '/'.join(['read', service, characteristic])

    def _write_topic(self, service, characteristic):
        return '/'.join(['write', service, characteristic])

    def _matches(self, topic, a_filter):
        """Return True if ``topic`` matches the MQTT topic filter ``a_filter``.

        Level-by-level MQTT semantics: ``+`` matches exactly one (possibly
        empty) level, ``#`` matches the parent level and everything below it,
        filters starting with a wildcard never match topics starting with
        ``$``, and all other characters are compared literally.
        """
        if '#' not in a_filter and '+' not in a_filter:
            # No wildcard: plain string equality.
            return a_filter == topic
        if topic.startswith('$') and a_filter[0] in '#+':
            return False
        topic_levels = topic.split('/')
        filter_levels = a_filter.split('/')
        for index, level in enumerate(filter_levels):
            if level == '#':
                return True
            if index >= len(topic_levels):
                return False
            if level != '+' and level != topic_levels[index]:
                return False
        return len(topic_levels) == len(filter_levels)
# Child logger ("<module>.helper") used by AsyncioHelper, so its verbose
# socket-level debug output can be configured separately.
helper_logger = logging.getLogger('.'.join((__name__, 'helper')))
class AsyncioHelper:
    """Drive a paho client's socket I/O from an asyncio event loop.

    Installs paho's ``on_socket_*`` callbacks: reads and writes are performed
    by loop reader/writer callbacks, and ``loop_misc`` runs in the periodic
    :meth:`misc_loop` coroutine.  All loop access from these callbacks goes
    through ``call_soon_threadsafe`` / ``run_coroutine_threadsafe``, so they
    may be invoked from a thread other than the loop's.
    """

    def __init__(self, loop, client: 'mqtt.Client', name):
        self.loop = loop
        self.name = name
        self.client = client
        # concurrent.futures.Future of the running misc_loop(), or None.
        self.misc = None
        self.client.on_socket_open = self.on_socket_open
        self.client.on_socket_close = self.on_socket_close
        self.client.on_socket_register_write = self.on_socket_register_write
        self.client.on_socket_unregister_write = self.on_socket_unregister_write

    def on_socket_open(self, client, userdata, sock):
        """Watch the new socket for readability and start the misc loop."""
        helper_logger.debug('Socket of client %s did opened', self.name)

        def cb():
            helper_logger.debug('Socket of client %s is readable, calling loop_read', self.name)
            client.loop_read()

        self.loop.call_soon_threadsafe(self.loop.add_reader, sock.fileno(), cb)
        self.misc = asyncio.run_coroutine_threadsafe(self.misc_loop(), self.loop)

    def on_socket_close(self, client, userdata, sock):
        """Stop watching the socket and cancel the misc loop started for it."""
        helper_logger.debug('Socket of client %s did close', self.name)
        fileno = sock.fileno()
        # Capture the misc future now: by the time the loop runs ``cb`` a new
        # socket may already have been opened and replaced ``self.misc``.
        misc, self.misc = self.misc, None

        def cb():
            self.loop.remove_reader(fileno)
            if misc is not None:
                misc.cancel()

        self.loop.call_soon_threadsafe(cb)

    def on_socket_register_write(self, client, userdata, sock):
        """Call ``loop_write`` whenever the socket becomes writable."""
        helper_logger.debug('Watching socket of client %s for writability', self.name)

        def cb():
            helper_logger.debug('Socket of client %s is writable, calling loop_write', self.name)
            client.loop_write()

        self.loop.call_soon_threadsafe(self.loop.add_writer, sock.fileno(), cb)

    def on_socket_unregister_write(self, client, userdata, sock):
        """Stop watching the socket for writability."""
        helper_logger.debug('Stopping to watch socket of client %s for writability', self.name)
        self.loop.call_soon_threadsafe(self.loop.remove_writer, sock.fileno())

    async def misc_loop(self):
        """Call ``loop_misc`` about once per second until it fails or the task is cancelled.

        Also drains paho's internal wake-up socketpair, which is never read
        because paho's own network loop is not used.
        """
        # pylint: disable=protected-access
        helper_logger.debug("misc_loop started")
        self.client._sockpairR.setblocking(False)
        while self.client.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
            try:
                await asyncio.sleep(1)
                try:
                    # Drain in chunks; recv returns b'' once the peer is closed.
                    while self.client._sockpairR.recv(4096):
                        pass
                except BlockingIOError:
                    pass
            except asyncio.CancelledError:
                break
        helper_logger.debug("misc_loop finished")