# pylint: disable=too-many-instance-attributes
import asyncio
import collections
import json
import logging
import re
import socket
from functools import partial
import paho.mqtt.client as mqtt
from caspia.meadow import errors
from caspia.toolbox import monitor
# Module-wide logger, named after this module so handlers can target it.
logger = logging.getLogger(name=__name__)
class Subscription:
    """Bookkeeping for one MQTT topic filter shared by its local subscribers."""

    def __init__(self):
        # local subscriber objects interested in this topic filter
        self.subscribers = set()
        # whether the broker acknowledged our subscribe request
        self.subscribed = False
        # set by the connection once the subscription is acknowledged
        self.subscribed_evt = asyncio.Event()
        # in-flight subscribe / unsubscribe tasks (None when idle)
        self.subscribe_task = None
        self.unsubscribe_task = None

    @property
    def should_subscribe(self):
        """Return True when someone needs the topic but no subscribe is active/pending."""
        return bool(self.subscribers) and not self.subscribed and not self.subscribe_task

    @property
    def should_unsubscribe(self):
        """Return True when nobody needs the topic but it is still subscribed."""
        return not self.subscribers and self.subscribed and not self.unsubscribe_task
class Connection:
    """MQTT connection driven from an asyncio event loop.

    Wraps a paho ``mqtt.Client`` whose socket I/O is serviced by
    ``AsyncioHelper``; every paho callback is marshalled onto ``self.loop``
    with ``call_soon_threadsafe``. Messages are published JSON-encoded with
    QoS 2, and received JSON payloads are dispatched to attached subscribers.
    """

    def __init__(self, broker_url, name=None, loop=None):
        """Create the client without connecting it.

        :param broker_url: ``host`` or ``host:port`` (port defaults to 1883),
            optionally prefixed with ``mqtt://``.
        :param name: MQTT client id, also used in log messages.
        :param loop: event loop to use; defaults to ``asyncio.get_event_loop()``.
        """
        if broker_url.startswith('mqtt://'):
            broker_url = broker_url[len('mqtt://'):]
        self.broker_url = broker_url
        self.name = name
        self.loop = loop or asyncio.get_event_loop()
        # topic filter -> Subscription; entries are created on first access
        self._subscriptions = collections.defaultdict(Subscription)
        # topic filters on which an identical repeated payload is dropped
        self._prevent_duplicates = {'gateway/#'}
        # topic -> hash of the last payload received on it
        self._topic_last_hash = dict()
        self._connect_called = False
        self._client: mqtt.Client = self._create_client()
        self._connect_lock = asyncio.Lock(loop=self.loop)
        self._connected = asyncio.Event(loop=self.loop)
        self._connected.clear()
        self._disconnected = asyncio.Event(loop=self.loop)
        self._disconnected.set()
        # paho message id (mid) -> future resolved by on_publish/on_(un)subscribe
        self._message_futures = dict()
        # bound to the same loop as the other primitives above
        self._inflight_semaphore = asyncio.Semaphore(20, loop=self.loop)

    @property
    def connected(self):
        """Whether a successful CONNACK was received and no disconnect since."""
        return self._connected.is_set()

    def _create_client(self):
        """Build the paho client and route its callbacks onto ``self.loop``."""
        client = mqtt.Client(self.name, clean_session=True)
        # paho may invoke these from a non-loop thread (connect runs in an executor)
        client.on_connect = partial(self.loop.call_soon_threadsafe, self.on_connect)
        client.on_disconnect = partial(self.loop.call_soon_threadsafe, self.on_disconnect)
        client.on_message = partial(self.loop.call_soon_threadsafe, self.on_message)
        client.on_publish = partial(self.loop.call_soon_threadsafe, self.on_publish)
        client.on_subscribe = partial(self.loop.call_soon_threadsafe, self.on_subscribe)
        client.on_unsubscribe = partial(self.loop.call_soon_threadsafe, self.on_unsubscribe)
        client.enable_logger(logger)
        self._helper = AsyncioHelper(self.loop, client, self.name)
        return client

    def on_connect(self, client, userdata, flags, rc):
        """Handle the broker's CONNACK.

        ``rc`` is a CONNACK return code, so it is rendered with
        ``mqtt.connack_string`` (``error_string`` is for ``MQTT_ERR_*`` codes).
        """
        self._cancel_all_message_futures()
        if rc == mqtt.CONNACK_ACCEPTED:
            self._connected.set()
            self._disconnected.clear()
            logger.info('Client %s successfully connected to %s', self.name, self.broker_url)
        else:
            self._connected.clear()
            self._disconnected.set()
            reason = mqtt.connack_string(rc)
            logger.error('Client %s failed to connect to %s, reason: %r', self.name,
                         self.broker_url, reason)

    def on_disconnect(self, client, userdata, rc):
        """Handle a disconnect reported by paho (``rc`` is an ``MQTT_ERR_*`` code)."""
        self._connected.clear()
        self._disconnected.set()
        reason = mqtt.error_string(rc)
        args = 'disconnected from %s [%s]', self.broker_url, reason
        if rc == mqtt.MQTT_ERR_SUCCESS:
            logger.info(*args)
        else:
            logger.error(*args)
        self._cancel_all_message_futures()

    def _cancel_all_message_futures(self, error=None):
        """Fail every pending acknowledgement future.

        :param error: exception to set on all futures; by default each gets a
            fresh ``CommunicationError('connection interrupted')``.
        """
        while self._message_futures:
            _, future = self._message_futures.popitem()
            # the waiter may already have been cancelled (e.g. wait_for timeout)
            if future.done():
                continue
            future.set_exception(
                error if error is not None
                else errors.CommunicationError('connection interrupted'))

    def _resolve_message_future(self, mid):
        """Resolve the future waiting for acknowledgement of ``mid``, if any."""
        future = self._message_futures.pop(mid, None)
        # a cancelled/timed-out future must not be resolved (InvalidStateError)
        if future is not None and not future.done():
            future.set_result(None)

    def on_message(self, client, userdata, message):
        """Schedule processing of a received message as a task on ``self.loop``."""
        asyncio.ensure_future(self._process_message(message), loop=self.loop)

    def on_publish(self, client, userdata, mid):
        """Mark the publish with id ``mid`` as completed."""
        self._resolve_message_future(mid)

    def on_subscribe(self, client, userdata, mid, granted_qos):
        """Mark the subscribe request with id ``mid`` as completed."""
        self._resolve_message_future(mid)

    def on_unsubscribe(self, client, userdata, mid):
        """Mark the unsubscribe request with id ``mid`` as completed."""
        self._resolve_message_future(mid)

    async def _do_connect(self, timeout=10.0):
        """Connect the mqtt client and wait for a successful CONNACK.

        :param timeout: seconds to wait for the connection to be accepted
            (``None`` waits forever).
        :raises errors.CommunicationError: if no successful CONNACK arrives in
            time, e.g. because the broker refused the connection.
        """
        def _connect():
            # blocking DNS/TCP work, run in an executor thread
            if ':' in self.broker_url:
                host, port = self.broker_url.split(':')
            else:
                host, port = self.broker_url, 1883
            self._client.connect(host, int(port))
            self._client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)

        async with self._connect_lock:
            if self._disconnected.is_set():
                await self.loop.run_in_executor(None, _connect)
            try:
                # Unbounded, a refused CONNACK would leave us waiting here forever
                # while holding the lock, blocking every later reconnect attempt.
                await asyncio.wait_for(self._connected.wait(), timeout)
            except asyncio.TimeoutError:
                raise errors.CommunicationError(
                    'no successful CONNACK from %s within %ss' % (self.broker_url, timeout)
                ) from None

    async def _do_subscribe(self, topic, qos=2):
        """Subscribe to a mqtt topic and wait for the broker's acknowledgement."""
        async with self._inflight_semaphore:
            result, mid = self._client.subscribe(topic, qos=qos)
            if result == mqtt.MQTT_ERR_SUCCESS:
                await self._wait_for_completion(mid)
            else:
                raise errors.CommunicationError.from_mqtt(result)

    async def _do_unsubscribe(self, topic):
        """Unsubscribe from a mqtt topic and wait for the broker's acknowledgement."""
        async with self._inflight_semaphore:
            result, mid = self._client.unsubscribe(topic)
            if result == mqtt.MQTT_ERR_SUCCESS:
                await self._wait_for_completion(mid)
            else:
                raise errors.CommunicationError.from_mqtt(result)

    async def _publish(self, topic, message, retain=False):
        """Publish ``message`` JSON-encoded with QoS 2 and wait until paho confirms it."""
        await self.ensure_connected()
        async with self._inflight_semaphore:
            logger.debug('Publishing %s to %s', message, topic)
            payload_data = json.dumps(message).encode('utf-8')
            result, mid = self._client.publish(topic, payload_data, qos=2, retain=retain)
            if result == mqtt.MQTT_ERR_SUCCESS:
                await self._wait_for_completion(mid)
            else:
                raise errors.CommunicationError.from_mqtt(result)

    async def _wait_for_completion(self, mid, timeout=5.0):
        """Wait up to ``timeout`` seconds for the acknowledgement of ``mid``.

        :raises asyncio.TimeoutError: when no acknowledgement arrives in time.
        :raises errors.CommunicationError: when the connection is interrupted.
        """
        future = self.loop.create_future()
        self._message_futures[mid] = future
        try:
            return await asyncio.wait_for(future, timeout)
        finally:
            self._message_futures.pop(mid, None)

    def _maintain_topic_subscription_status(self, topic: str, subs: 'Subscription'):
        """Start a subscribe or unsubscribe request for ``topic`` if ``subs`` needs one."""
        if subs.should_subscribe:
            subscribe_task = asyncio.ensure_future(self._do_subscribe(topic), loop=self.loop)
            subs.subscribe_task = subscribe_task

            def on_subscribed(task):
                if subs.subscribe_task is not task:
                    # superseded (state was reset by a reconnect): result is stale
                    return
                subs.subscribe_task = None
                if task.cancelled():
                    return
                exc = task.exception()
                if exc is not None:
                    logger.error('Client %s failed to subscribe to %s', self.name, topic,
                                 exc_info=exc)
                else:
                    subs.subscribed = True
                    subs.subscribed_evt.set()

            subscribe_task.add_done_callback(on_subscribed)
        elif subs.should_unsubscribe:
            unsubscribe_task = asyncio.ensure_future(self._do_unsubscribe(topic), loop=self.loop)
            subs.unsubscribe_task = unsubscribe_task

            def on_unsubscribed(task):
                if subs.unsubscribe_task is not task:
                    # superseded (state was reset by a reconnect): result is stale
                    return
                subs.unsubscribe_task = None
                if task.cancelled():
                    return
                exc = task.exception()
                if exc is not None:
                    logger.error('Client %s failed to unsubscribe from %s', self.name, topic,
                                 exc_info=exc)
                    return
                subs.subscribed = False
                subs.subscribed_evt.clear()
                if not subs.should_subscribe:
                    # nobody needs the topic any more: forget it
                    if self._subscriptions.get(topic) is subs:
                        del self._subscriptions[topic]
                elif self.connected:
                    # a subscriber was attached while the unsubscribe was in flight
                    self._maintain_topic_subscription_status(topic, subs)

            unsubscribe_task.add_done_callback(on_unsubscribed)

    def _ignore_as_duplicate(self, topic, payload_str):
        """Return True if ``payload_str`` repeats the last payload seen on ``topic``.

        Only topics matching a filter in ``self._prevent_duplicates`` are
        tracked; for those the hash of the latest payload is remembered.
        """
        if not any(self._matches(topic, a_filter) for a_filter in self._prevent_duplicates):
            return False
        current_hash = hash(payload_str)
        if self._topic_last_hash.get(topic) == current_hash:
            return True
        self._topic_last_hash[topic] = current_hash
        return False

    async def _process_message(self, message: 'mqtt.MQTTMessage'):
        """Decode one received message and dispatch it to matching subscribers.

        Scheduled by on_message(). Payloads that are not UTF-8 JSON are logged
        and dropped; exceptions raised by subscribers are logged.
        """
        topic = message.topic
        try:
            payload_str = message.payload.decode('utf-8')
            payload = json.loads(payload_str)
        except ValueError as exc:
            # covers both UnicodeDecodeError and json.JSONDecodeError
            logger.error('Client %s dropped undecodable message on %s: %r',
                         self.name, topic, exc)
            return
        logger.debug("Received %s => %s bytes", topic, len(message.payload))
        if self._ignore_as_duplicate(topic, payload_str):
            logger.debug('Ignoring as duplicate')
            return
        if topic.startswith('notification/'):
            monitor.record_metric('meadow-connection:notifications-in', 1)
        subscribers = set()
        for key, subscription in list(self._subscriptions.items()):
            if self._matches(topic, key):
                subscribers |= subscription.subscribers
        if not subscribers:
            return
        tasks = [asyncio.ensure_future(s(payload, topic), loop=self.loop) for s in subscribers]
        done, _ = await asyncio.wait(tasks, loop=self.loop)
        for task in done:
            # retrieve exceptions so they are logged here instead of being lost
            if not task.cancelled() and task.exception() is not None:
                logger.error('Client %s: subscriber for %s failed', self.name, topic,
                             exc_info=task.exception())

    async def ensure_connected(self):
        """Connect now unless already connected.

        :raises errors.CommunicationError: if the connection is not accepted in time.
        """
        if not self.connected:
            await self._do_connect()

    def _reset_subscriptions(self):
        """Abort in-flight (un)subscribe requests and mark every topic unsubscribed."""
        for subscription in self._subscriptions.values():
            if subscription.subscribe_task:
                subscription.subscribe_task.cancel()
                subscription.subscribe_task = None
            if subscription.unsubscribe_task:
                subscription.unsubscribe_task.cancel()
                subscription.unsubscribe_task = None
            subscription.subscribed = False
            # subscriber_ready() must wait again for the re-subscription
            subscription.subscribed_evt.clear()

    async def run_forever(self):
        """Never return. Main runloop of the client managing its connection.

        Reconnects whenever the connection is down and, while connected, runs
        a monitor that (re)issues subscribe/unsubscribe requests every second.
        Cancelling it stops the monitor and disconnects the client.
        """
        async def monitor_subscriptions(after_connect=True):
            while True:
                issues = 0
                # snapshot: done-callbacks may remove entries between iterations
                for topic, subscription in list(self._subscriptions.items()):
                    if subscription.should_subscribe:
                        if not after_connect:
                            logger.error('Client %s is not subscribed to %s (but should be!).',
                                         self.name, topic)
                        self._maintain_topic_subscription_status(topic, subscription)
                        issues += 1
                    elif subscription.should_unsubscribe:
                        if not after_connect:
                            logger.error('Client %s is subscribed to %s (but should not be!).',
                                         self.name, topic)
                        self._maintain_topic_subscription_status(topic, subscription)
                        issues += 1
                if not issues:
                    logger.debug('Client %s has all subscriptions OK', self.name)
                await asyncio.sleep(1.0)
                after_connect = False

        subscriptions_monitor = None
        try:
            while True:
                await self._disconnected.wait()
                self._reset_subscriptions()
                if subscriptions_monitor is not None and not subscriptions_monitor.done():
                    subscriptions_monitor.cancel()
                # fail all awaited requests
                self._cancel_all_message_futures(
                    errors.CommunicationError.from_mqtt(mqtt.MQTT_ERR_CONN_LOST))
                logger.debug('Client %s is not connected. Connecting ...', self.name)
                try:
                    await self._do_connect()
                except asyncio.CancelledError:
                    # on Python < 3.8 CancelledError is an Exception subclass
                    raise
                except Exception as e:  # pylint: disable=broad-except
                    logger.error('Client %s failed to connect: %s', self.name, repr(e))
                    await asyncio.sleep(3)
                else:
                    logger.debug('Client %s connected.', self.name)
                    subscriptions_monitor = asyncio.ensure_future(
                        monitor_subscriptions(), loop=self.loop)
        except asyncio.CancelledError:
            if subscriptions_monitor is not None:
                subscriptions_monitor.cancel()
            self._client.disconnect()

    def attach_subscriber(self, subscriber):
        """Attach the subscriber and return immediately.

        The subscriber will be ready sometime in future. To wait for that,
        use `await self.subscriber_ready(s)`.
        """
        subscriber.client = self
        topic = subscriber.topic
        logger.debug('Attaching subscriber for %s', topic)
        subscription = self._subscriptions[topic]
        subscription.subscribers.add(subscriber)
        if self.connected:
            self._maintain_topic_subscription_status(topic, subscription)

    async def subscriber_ready(self, subscriber):
        """Await the `subscriber` to be ready (its topic acknowledged by the broker)."""
        await self.ensure_connected()
        subscription = self._subscriptions[subscriber.topic]
        await subscription.subscribed_evt.wait()

    def detach_subscriber(self, subscriber):
        """Detach the subscriber (with immediate effect).

        :raises KeyError: if ``subscriber`` is not attached to its topic.
        """
        topic = subscriber.topic
        logger.debug('Detaching subscriber for %s', topic)
        subscription = self._subscriptions[topic]
        subscription.subscribers.remove(subscriber)
        if self.connected:
            self._maintain_topic_subscription_status(topic, subscription)

    #
    # Topics
    #

    def _notification_topic(self, service, characteristic):
        return '/'.join(['notification', service, characteristic])

    def _response_topic(self, service, characteristic, req_id):
        return '/'.join(['response', service, characteristic, req_id])

    def _read_topic(self, service, characteristic):
        return '/'.join(['read', service, characteristic])

    def _write_topic(self, service, characteristic):
        return '/'.join(['write', service, characteristic])

    def _matches(self, topic, a_filter):
        """Return True when ``topic`` matches the MQTT topic filter ``a_filter``.

        Follows MQTT 3.1.1 section 4.7: ``+`` matches exactly one (possibly
        empty) level, ``#`` matches the parent level and any number of child
        levels, and a wildcard in the first level never matches a topic that
        starts with ``$``. Every other character is compared literally.
        """
        if "#" not in a_filter and "+" not in a_filter:
            # if filter doesn't contain wildcard, return exact match
            return a_filter == topic
        if topic.startswith('$') and a_filter[0] in '#+':
            return False
        filter_levels = a_filter.split('/')
        topic_levels = topic.split('/')
        for index, level in enumerate(filter_levels):
            if level == '#':
                return True
            if index >= len(topic_levels):
                return False
            if level != '+' and level != topic_levels[index]:
                return False
        return len(filter_levels) == len(topic_levels)
helper_logger = logger.getChild('helper')


class AsyncioHelper:
    """Service a paho client's socket from an asyncio event loop.

    Installs paho's socket callbacks so that ``loop_read``/``loop_write`` run
    from ``loop.add_reader``/``loop.add_writer`` callbacks, and ``loop_misc``
    runs about once per second in :meth:`misc_loop`. Loop interactions go
    through ``call_soon_threadsafe``/``run_coroutine_threadsafe`` because the
    callbacks may fire outside the loop's thread.
    """

    def __init__(self, loop, client: mqtt.Client, name):
        self.loop = loop
        self.name = name
        self.client = client
        # concurrent future of the current misc_loop(); None until a socket opens
        self.misc = None
        self.client.on_socket_open = self.on_socket_open
        self.client.on_socket_close = self.on_socket_close
        self.client.on_socket_register_write = self.on_socket_register_write
        self.client.on_socket_unregister_write = self.on_socket_unregister_write

    def on_socket_open(self, client, userdata, sock):
        """Watch ``sock`` for readability and start :meth:`misc_loop`."""
        helper_logger.debug('Socket of client %s did opened', self.name)

        def cb():
            helper_logger.debug('Socket of client %s is readable, calling loop_read', self.name)
            client.loop_read()

        self.loop.call_soon_threadsafe(self.loop.add_reader, sock.fileno(), cb)
        self.misc = asyncio.run_coroutine_threadsafe(self.misc_loop(), self.loop)

    def on_socket_close(self, client, userdata, sock):
        """Stop watching ``sock`` and stop the misc loop started for it."""
        helper_logger.debug('Socket of client %s did close', self.name)
        # Capture both now: the socket may be closed as soon as this returns, and
        # a reconnect may replace self.misc before the loop runs cb -- reading it
        # lazily would cancel the *new* misc loop.
        sockfileno = sock.fileno()
        misc = self.misc

        def cb():
            self.loop.remove_reader(sockfileno)
            if misc is not None:
                misc.cancel()

        self.loop.call_soon_threadsafe(cb)

    def on_socket_register_write(self, client, userdata, sock):
        """Call ``loop_write`` whenever ``sock`` becomes writable."""
        helper_logger.debug('Watching socket of client %s for writability', self.name)

        def cb():
            helper_logger.debug('Socket of client %s is writable, calling loop_write', self.name)
            client.loop_write()

        self.loop.call_soon_threadsafe(self.loop.add_writer, sock.fileno(), cb)

    def on_socket_unregister_write(self, client, userdata, sock):
        """Stop watching ``sock`` for writability."""
        helper_logger.debug('Stopping to watch socket of client %s for writability', self.name)
        self.loop.call_soon_threadsafe(self.loop.remove_writer, sock.fileno())

    async def misc_loop(self):
        """Call ``client.loop_misc()`` about once a second until it fails or is cancelled."""
        helper_logger.debug("misc_loop started")
        while self.client.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
            try:
                await asyncio.sleep(1)
            except asyncio.CancelledError:
                break
            self._drain_sockpair()
        helper_logger.debug("misc_loop finished")

    def _drain_sockpair(self):
        """Discard any bytes pending on paho's internal ``_sockpairR`` socket.

        Presumably paho writes there to wake its own network loop, which is
        not used here, so nothing else would read it.
        """
        # pylint: disable=protected-access
        sockpair = self.client._sockpairR
        if sockpair is None:
            return
        sockpair.setblocking(False)
        try:
            while sockpair.recv(4096):
                pass
        except BlockingIOError:
            pass  # nothing (more) to read
        except OSError as exc:
            # e.g. paho closed/replaced the pair; not worth killing misc_loop
            helper_logger.debug('Could not drain socket pair of client %s: %r', self.name, exc)