Testing asyncio code

Foreword: This is part 5 of a 7-part series titled “asyncio: We Did It Wrong.” Take a look at Part 1: True Concurrency, Part 2: Graceful Shutdowns, Part 3: Exception Handling, and Part 4: Synchronous & threaded code in asyncio for where we are in the tutorial now. Once done, follow along with Part 6: Debugging asyncio Code, and Part 7: Profiling asyncio Code.

Example code can be found on GitHub. All code on this post is licensed under MIT.


Mayhem Mandrill Recap

The goal for this 7-part series is to build a mock chaos monkey-like service called “Mayhem Mandrill”. This is an event-driven service that consumes from a pub/sub, and initiates a mock restart of a host. We could get thousands of messages in seconds, so as we get a message, we shouldn’t block the handling of the next message we receive.

For a simpler starting point, we’re going to test asyncio code that doesn’t have to deal with threading. Here’s the starting point of what we’re going to test:

# contents of mayhem.py
#!/usr/bin/env python3.7

"""
Notice! This requires:
 - attrs==19.1.0
"""

import asyncio
import functools
import logging
import random
import signal
import string
import uuid

import attr

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s,%(msecs)d %(levelname)s: %(message)s",
    datefmt="%H:%M:%S",
)

@attr.s
class PubSubMessage:
    instance_name = attr.ib()
    message_id    = attr.ib(repr=False)
    hostname      = attr.ib(repr=False, init=False)
    restarted     = attr.ib(repr=False, default=False)
    saved         = attr.ib(repr=False, default=False)
    acked         = attr.ib(repr=False, default=False)
    extended_cnt  = attr.ib(repr=False, default=0)

    def __attrs_post_init__(self):
        self.hostname = f"{self.instance_name}.example.net"

class RestartFailed(Exception):
    pass

async def publish(queue):
    choices = string.ascii_lowercase + string.digits

    while True:
        msg_id = str(uuid.uuid4())
        host_id = "".join(random.choices(choices, k=4))
        instance_name = f"cattle-{host_id}"
        msg = PubSubMessage(message_id=msg_id, instance_name=instance_name)
        asyncio.create_task(queue.put(msg))
        logging.debug(f"Published message {msg}")
        await asyncio.sleep(random.random())

async def restart_host(msg):
    await asyncio.sleep(random.random())
    if random.randrange(1, 5) == 3:
        raise RestartFailed(f"Could not restart {msg.hostname}")
    msg.restarted = True
    logging.info(f"Restarted {msg.hostname}")

async def save(msg):
    await asyncio.sleep(random.random())
    if random.randrange(1, 5) == 3:
        raise Exception(f"Could not save {msg}")
    msg.saved = True
    logging.info(f"Saved {msg} into database")

async def cleanup(msg, event):
    await event.wait()
    await asyncio.sleep(random.random())
    msg.acked = True
    logging.info(f"Done. Acked {msg}")

async def extend(msg, event):
    while not event.is_set():
        msg.extended_cnt += 1
        logging.info(f"Extended deadline by 3 seconds for {msg}")
        await asyncio.sleep(2)

def handle_results(results, msg):
    for result in results:
        if isinstance(result, RestartFailed):
            logging.error(f"Retrying for failure to restart: {msg.hostname}")
        elif isinstance(result, Exception):
            logging.error(f"Handling general error: {result}")

async def handle_message(msg):
    event = asyncio.Event()

    asyncio.create_task(extend(msg, event))
    asyncio.create_task(cleanup(msg, event))

    results = await asyncio.gather(
        save(msg), restart_host(msg), return_exceptions=True
    )
    handle_results(results, msg)
    event.set()

async def consume(queue):
    while True:
        msg = await queue.get()
        logging.info(f"Pulled {msg}")
        asyncio.create_task(handle_message(msg))

def handle_exception(loop, context):
    msg = context.get("exception", context["message"])
    logging.error(f"Caught exception: {msg}")
    logging.info("Shutting down...")
    asyncio.create_task(shutdown(loop))

async def shutdown(loop, signal=None):
    if signal:
        logging.info(f"Received exit signal {signal.name}...")
    logging.info("Closing database connections")
    logging.info("Nacking outstanding messages")
    tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]

    for task in tasks:
        task.cancel()

    logging.info("Cancelling outstanding tasks")
    await asyncio.gather(*tasks, return_exceptions=True)
    logging.info("Flushing metrics")
    loop.stop()

def main():
    loop = asyncio.get_event_loop()
    signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
    for s in signals:
        loop.add_signal_handler(
            s, lambda s=s: asyncio.create_task(shutdown(loop, signal=s))
        )
    loop.set_exception_handler(handle_exception)
    queue = asyncio.Queue()

    try:
        loop.create_task(publish(queue))
        loop.create_task(consume(queue))
        loop.run_forever()
    finally:
        loop.close()
        logging.info("Successfully shutdown the Mayhem service.")


if __name__ == "__main__":
    main()

Simple Testing with pytest

We will be using pytest since I prefer writing simple assert statements for my tests.

We will start simple, and test the “happy path” of the save function (i.e. not the code path that raises an exception):

async def save(msg):
    # unhelpful simulation of i/o work
    await asyncio.sleep(random.random()) 
    msg.saved = True
    logging.info(f"Saved {msg} into database")

Since save is a coroutine function, we’ll need to run it on the loop:

# contents of test_mayhem.py
#!/usr/bin/env python3.7
"""
Notice! This requires pytest==4.3.1
"""

import asyncio
import pytest

import mayhem


@pytest.fixture
def message():
    return mayhem.PubSubMessage(message_id="1234", instance_name="mayhem_test")

def test_save(message):
    assert not message.saved  # sanity check
    asyncio.run(mayhem.save(message))
    assert message.saved

Running this via pytest:

$ pytest -v test_mayhem_1.py
test_mayhem_1.py::test_save PASSED                                                           [100%]

Sweet! However, if you’re not on 3.7+ yet, you’ll have to set up and tear down the loop yourself rather than making use of asyncio.run:

def test_save(message):
    assert not message.saved  # sanity check
    loop = asyncio.get_event_loop()
    loop.run_until_complete(mayhem.save(message))
    loop.close()
    assert message.saved
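
One caveat with the manual approach: loop.close() closes the loop returned by asyncio.get_event_loop(), so a second test that fetches that same loop would find it already closed. Here is a minimal sketch of one way around that, assuming a hypothetical helper named run_coro (not part of mayhem.py) and a local save that stands in for mayhem.save:

```python
import asyncio


def run_coro(coro):
    """Run `coro` to completion on a fresh event loop, then close that loop."""
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()


async def save(msg):
    # stand-in for mayhem.save; `msg` is a plain dict to keep this self-contained
    await asyncio.sleep(0)
    msg["saved"] = True
    return msg


def test_save():
    message = {"saved": False}
    assert not message["saved"]  # sanity check
    run_coro(save(message))
    assert message["saved"]


def test_save_again():
    # a second test gets its own loop, so the earlier close() does not affect it
    message = {"saved": False}
    run_coro(save(message))
    assert message["saved"]
```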

This can get annoying, especially when you have many coroutine functions to test. Thankfully, there is a plugin for pytest called pytest-asyncio. This plugin allows you to define your tests themselves as coroutine functions, and manages the event loop for you:

#!/usr/bin/env python3.7
"""
Notice! This requires:
- pytest==4.3.1
- pytest-asyncio==0.10.0
"""
# <-- snip -->

@pytest.mark.asyncio
async def test_save(message):  # <-- now a coroutine!
    assert not message.saved  # sanity check
    await mayhem.save(message)
    assert message.saved

Much cleaner! Using pytest-asyncio can get you pretty far.

Mocking Coroutines

When writing unit tests, you’ll often need to mock out coroutine functions that are called within your tested function.

For instance, our save coroutine function calls another coroutine function, asyncio.sleep (a stand-in for a network I/O call to a database):

async def save(msg):
    # unhelpful simulation of i/o work
    await asyncio.sleep(random.random())  # <-- let's mock this out
    msg.saved = True
    logging.info(f"Saved {msg} into database")

You don’t actually want to wait for asyncio.sleep to complete while running your tests, nor do you want an actual call to a database to happen.

On Python 3.7, neither unittest.mock nor pytest-mock supports asynchronous mocks (unittest.mock.AsyncMock only arrived in Python 3.8), so we’ll have to work around this.
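
If you can run your tests on Python 3.8 or newer, the standard library’s unittest.mock.AsyncMock may remove the need for a workaround. This is a minimal sketch, assuming a local save that stands in for mayhem.save and hypothetical check_save and run_check drivers. The rest of this section sticks with the Python 3.7 approach.

```python
import asyncio
from unittest import mock


async def save(msg):
    # stand-in for mayhem.save, with a fixed delay so the call can be asserted on
    await asyncio.sleep(0.5)
    msg["saved"] = True


async def check_save():
    msg = {"saved": False}
    # patch asyncio.sleep with an AsyncMock so awaiting it returns immediately
    with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) as mock_sleep:
        await save(msg)
    return msg, mock_sleep


def run_check():
    return asyncio.run(check_save())
```

After run_check() returns, the mock can be checked with mock_sleep.assert_awaited_once_with(0.5).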

First, in making use of the pytest-mock library, we’ll create a pytest fixture that will return a function:

#!/usr/bin/env python3.7
"""
Notice! This requires:
- pytest==4.3.1
- pytest-asyncio==0.10.0
- pytest-mock==1.10.3
"""
# <-- snip -->

@pytest.fixture
def create_mock_coro(mocker, monkeypatch):
    def _create_mock_patch_coro(to_patch=None):
        mock = mocker.Mock()

        async def _coro(*args, **kwargs):
            return mock(*args, **kwargs)

        if to_patch:  # <-- may not need/want to patch anything
            monkeypatch.setattr(to_patch, _coro)
        return mock, _coro

    return _create_mock_patch_coro

Then, we’ll create another pytest fixture that will use the create_mock_coro fixture to mock and patch asyncio.sleep:

@pytest.fixture
def mock_sleep(create_mock_coro):
    # won't need the returned coroutine here
    mock, _ = create_mock_coro(to_patch="mayhem.asyncio.sleep")
    return mock

Now let’s use the mock_sleep fixture in our test_save:

@pytest.mark.asyncio
async def test_save(mock_sleep, message):
    assert not message.saved  # sanity check
    await mayhem.save(message)
    assert message.saved
    assert 1 == mock_sleep.call_count

What we’ve done here is patch asyncio.sleep in our mayhem module with a coroutine function that calls a mock object and returns its result. Then, we assert that the mock standing in for asyncio.sleep is called once when mayhem.save is called. Because we now have a mock object instead of the actual coroutine, we can do anything that’s supported with unittest.mock.Mock objects, e.g. our_mocked_object.assert_called_once_with(...), our_mocked_object.return_value = "foo", etc.
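
To make that concrete, here is a minimal standard-library sketch of the same wrapper pattern outside of pytest. The names make_mock_coro, fetch_status, and demo are hypothetical and only serve the illustration:

```python
import asyncio
from unittest import mock


def make_mock_coro():
    """Return a plain Mock plus a coroutine function that delegates to it."""
    inner = mock.Mock()

    async def _coro(*args, **kwargs):
        return inner(*args, **kwargs)

    return inner, _coro


async def fetch_status(get, host):
    # hypothetical code under test: awaits whatever coroutine function it is given
    return await get(host)


async def demo():
    mock_get, coro_get = make_mock_coro()

    # the value configured on the Mock comes back through the awaited coroutine
    mock_get.return_value = "foo"
    first = await fetch_status(coro_get, "cattle-1234")
    mock_get.assert_called_once_with("cattle-1234")

    # side_effect works too: the exception surfaces at the await
    mock_get.side_effect = RuntimeError("boom")
    try:
        await fetch_status(coro_get, "cattle-5678")
        raised = False
    except RuntimeError:
        raised = True

    return first, raised, mock_get.call_count
```

The assertions and configuration all happen on the plain Mock, while the code under test only ever sees the coroutine function.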

Testing create_task

For testing coroutine functions that call create_task, we can’t simply use the create_mock_coro fixture. For instance, let’s try to test our consume coroutine function:

async def consume(queue):
    while True:
        msg = await queue.get()
        logging.info(f"Pulled {msg}")
        asyncio.create_task(handle_message(msg))

I have the following fixtures for asyncio.Queue:

@pytest.fixture
def mock_queue(mocker, monkeypatch):
    queue = mocker.Mock()
    monkeypatch.setattr(mayhem.asyncio, "Queue", queue)
    return queue.return_value


@pytest.fixture
def mock_get(mock_queue, create_mock_coro):
    mock_get, coro_get = create_mock_coro()
    mock_queue.get = coro_get
    return mock_get

So let’s try to use create_mock_coro to mock and patch the handle_message coroutine function that consume schedules via create_task.

Note: we’re setting mock_get.side_effect with one “real” value, and one Exception to make sure we’re not permanently stuck within the while True loop that consume has.

@pytest.mark.asyncio
async def test_consume(mock_get, mock_queue, message, create_mock_coro):
    mock_get.side_effect = [message, Exception("break while loop")]
    mock_handle_message, _ = create_mock_coro("mayhem.handle_message")

    with pytest.raises(Exception, match="break while loop"):
        await mayhem.consume(mock_queue)

    mock_handle_message.assert_called_once_with(message)

When running this, we see that mock_handle_message does not actually get called, like we’re expecting:

$ pytest -v test_mayhem_4.py
test_mayhem_4.py::test_consume FAILED                                                        [100%]

=========================================== FAILURES ============================================
_________________________________________ test_consume __________________________________________

mock_get = <Mock id='4477488824'>, mock_queue = <Mock name='mock()' id='4477488880'>
message = Message(instance_name='cattle-1234')
create_mock_coro = <function create_mock_coro.<locals>._create_mock_patch_coro at 0x10add9840>

    @pytest.mark.asyncio
    async def test_consume(mock_get, mock_queue, message, create_mock_coro):
        mock_get.side_effect = [message, Exception("break while loop")]
        mock_handle_message = create_mock_coro("mayhem.handle_message")

        with pytest.raises(Exception, match="break while loop"):
            await mayhem.consume(mock_queue)

>       mock_handle_message.assert_called_once_with(message)
E       AssertionError: Expected 'mock' to be called once. Called 0 times.

test_mayhem_4.py:230: AssertionError
------------------------------------- Captured stderr call --------------------------------------
15:30:37,721 INFO: Pulled Message(instance_name='cattle-1234')
============================== 1 failed, 1 passed in 0.10 seconds ===============================

This is because the tasks created via create_task are only scheduled at this point and are still pending; we need to nudge them along. We do this by collecting all outstanding tasks (other than the test itself) and awaiting them explicitly:

@pytest.mark.asyncio
async def test_consume(mock_get, mock_queue, message, create_mock_coro):
    mock_get.side_effect = [message, Exception("break while loop")]
    mock_handle_message, _ = create_mock_coro("mayhem.handle_message")

    with pytest.raises(Exception, match="break while loop"):
        await mayhem.consume(mock_queue)

    ret_tasks = [
        t for t in asyncio.all_tasks() if t is not asyncio.current_task()
    ]
    # should be 1 per side effect minus the Exception (i.e. messages consumed)
    assert 1 == len(ret_tasks)
    mock_handle_message.assert_not_called()  # <-- sanity check

    # explicitly await tasks scheduled by `asyncio.create_task`
    await asyncio.gather(*ret_tasks)

    mock_handle_message.assert_called_once_with(message)

Now pytest is happy:

$ pytest -v test_mayhem_5.py
test_mayhem_5.py::test_consume PASSED                                                       [100%]
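
The scheduling behavior behind that fix can be reproduced without pytest. In this minimal sketch, handle_message, consume_once, and demo are simplified stand-ins rather than the functions from mayhem.py:

```python
import asyncio

calls = []


async def handle_message(msg):
    calls.append(msg)


async def consume_once(msg):
    # schedules the handler but returns without yielding to the event loop
    asyncio.create_task(handle_message(msg))


async def demo():
    calls.clear()
    await consume_once("msg-1")
    before = list(calls)  # the task is scheduled but has not run yet

    pending = [
        t for t in asyncio.all_tasks() if t is not asyncio.current_task()
    ]
    # explicitly await the scheduled task so it gets a chance to run
    await asyncio.gather(*pending)
    after = list(calls)

    return before, len(pending), after
```

Running asyncio.run(demo()) shows an empty call list before the gather and the handled message after it, which mirrors the two assertions on mock_handle_message in the test.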

Non-async testing of the event loop

I hear you wanting to get to 100% test coverage, which may seem difficult for our main function. We’ll make use of pytest-asyncio’s event_loop fixture, with a slight modification.

First, we’ll create our own fixture that overrides pytest-asyncio’s event_loop fixture:

@pytest.fixture
def event_loop(event_loop, mocker):
    new_loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(new_loop)
    new_loop._close = new_loop.close
    new_loop.close = mocker.Mock()

    yield new_loop

    new_loop._close()

We’re essentially swapping in a different event loop for pytest-asyncio to inject into the tested code. We want to control the testing event loop so we can override its close() behavior, since close() gets called in our main function. If we close the loop during the test, we’ll lose access to the signal handlers that we set up within the main function. Replacing the close() method with a mock object still lets us assert that it has been called, and the real method is kept as _close so the fixture can close the loop during teardown.
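
The close-swapping trick works on a plain loop object too, independent of pytest-asyncio. This is a minimal sketch using unittest.mock.Mock in place of mocker.Mock; the function name demo_mocked_close is hypothetical:

```python
import asyncio
from unittest import mock


def demo_mocked_close():
    loop = asyncio.new_event_loop()
    real_close = loop.close   # keep a handle on the real method
    loop.close = mock.Mock()  # code under test now calls the mock instead

    loop.close()              # what main() does in its `finally` block
    still_open = not loop.is_closed()
    loop.close.assert_called_once_with()

    real_close()              # the teardown step: actually close the loop
    return still_open, loop.is_closed()
```

The first returned value shows the loop surviving the mocked close(), and the second shows it closed once the real method runs.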

So now, we’ll write a test_main function that actually borders on an integration or functional test. We want to make sure – in addition to the expected calls to publish and consume – that shutdown gets called when expected.

We can’t exactly mock out shutdown with create_mock_coro, since that would patch it with just another coroutine function: the mocked coroutine would run each time the loop receives a signal, and it would never call loop.stop(), so main would not return. Instead, we’ll mock out the asyncio.gather call that shutdown awaits.

And finally, in order to see if the loop actually responds to signals, we need to send a signal to it. We do this by starting a separate thread from which we’ll send a signal to the process itself.
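
The thread-plus-signal technique can be tried in isolation first. This is a minimal sketch assuming a Unix platform and code running in the main thread (both required by loop.add_signal_handler). It uses SIGUSR1 rather than SIGTERM, and run_until_signal is a hypothetical helper:

```python
import asyncio
import os
import signal
import threading
import time


def run_until_signal(sig=signal.SIGUSR1, delay=0.1):
    """Run a fresh loop until `sig` arrives; return the signals that were handled."""
    received = []
    loop = asyncio.new_event_loop()

    def _on_signal():
        received.append(sig)
        loop.stop()

    # register the handler before the signal can possibly be sent
    loop.add_signal_handler(sig, _on_signal)

    def _send_signal():
        # allow the loop to start...
        time.sleep(delay)
        # ...then signal our own process
        os.kill(os.getpid(), sig)

    thread = threading.Thread(target=_send_signal, daemon=True)
    thread.start()

    try:
        loop.run_forever()
    finally:
        thread.join()
        loop.close()  # on Unix this also removes the loop's signal handlers
    return received
```

The handler is registered before the thread starts, so the signal cannot arrive while the process still has the default SIGUSR1 behavior (which would terminate it).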

# <-- snip -->
import os
import signal
import time
import threading

# <-- snip -->
def test_main(create_mock_coro, event_loop, mock_queue):
    mock_consume, _ = create_mock_coro("mayhem.consume")
    mock_publish, _ = create_mock_coro("mayhem.publish")
    # mock out `asyncio.gather` that `shutdown` calls instead 
    # of `shutdown` itself
    mock_asyncio_gather, _ = create_mock_coro("mayhem.asyncio.gather")

    def _send_signal():
        # allow the loop to start and work a little bit...
        time.sleep(0.1)
        # ...then send a signal
        os.kill(os.getpid(), signal.SIGTERM)

    thread = threading.Thread(target=_send_signal, daemon=True)
    thread.start()

    mayhem.main()

    assert signal.SIGTERM in event_loop._signal_handlers
    assert mayhem.handle_exception == event_loop.get_exception_handler()

    mock_asyncio_gather.assert_called_once_with(return_exceptions=True)
    mock_consume.assert_called_once_with(mock_queue)
    mock_publish.assert_called_once_with(mock_queue)

    # asserting the loop is stopped but not closed
    assert not event_loop.is_running()
    assert not event_loop.is_closed()
    event_loop.close.assert_called_once_with()

$ pytest -v test_mayhem_6.py
test_mayhem_6.py::test_main PASSED                                                          [100%]

We can further parametrize this, as well as test for behavior when the loop receives a signal other than SIGINT, SIGTERM, and SIGHUP. This requires us to add another signal since we want to make sure that other signals do not invoke the defined shutdown behavior. We’ll make use of SIGUSR1 and add a different shutdown mock to the test event loop:

@pytest.mark.parametrize(
    "tested_signal", ("SIGINT", "SIGTERM", "SIGHUP", "SIGUSR1")
)
def test_main(tested_signal, create_mock_coro, event_loop, mock_queue, mocker):
    tested_signal = getattr(signal, tested_signal)
    mock_asyncio_gather, _ = create_mock_coro("mayhem.asyncio.gather")
    mock_consume, _ = create_mock_coro("mayhem.consume")
    mock_publish, _ = create_mock_coro("mayhem.publish")

    mock_shutdown = mocker.Mock()
    def _shutdown():
        mock_shutdown()
        event_loop.stop()

    event_loop.add_signal_handler(signal.SIGUSR1, _shutdown)

    def _send_signal():
        time.sleep(0.1)
        os.kill(os.getpid(), tested_signal)

    thread = threading.Thread(target=_send_signal, daemon=True)
    thread.start()

    mayhem.main()

    assert tested_signal in event_loop._signal_handlers
    assert mayhem.handle_exception == event_loop.get_exception_handler()

    mock_consume.assert_called_once_with(mock_queue)
    mock_publish.assert_called_once_with(mock_queue)

    if tested_signal is not signal.SIGUSR1:
        mock_asyncio_gather.assert_called_once_with(return_exceptions=True)
        mock_shutdown.assert_not_called()
    else:
        mock_asyncio_gather.assert_not_called()
        mock_shutdown.assert_called_once_with()

    # asserting the loop is stopped but not closed
    assert not event_loop.is_running()
    assert not event_loop.is_closed()
    event_loop.close.assert_called_once_with()

Look at those happy tests:

$ pytest -v part-5/test_mayhem_7.py

test_mayhem_7.py::test_main[SIGINT] PASSED                                           [ 25%]
test_mayhem_7.py::test_main[SIGTERM] PASSED                                          [ 50%]
test_mayhem_7.py::test_main[SIGHUP] PASSED                                           [ 75%]
test_mayhem_7.py::test_main[SIGUSR1] PASSED                                          [100%]

To see what near-100% test coverage looks like for mayhem.py, check out part-5/test_mayhem_full.py.

Third-Party Libraries

  • The aforementioned pytest-asyncio has other helpful things too, like the event_loop, unused_tcp_port, and unused_tcp_port_factory fixtures, and the ability to create your own asynchronous fixtures.
  • asynctest has a lot of helpful tooling, including coroutine mocks and exhausting callbacks so we don’t have to manually await tasks made by create_task. It does require the use of unittest, with tests defined as asynctest.TestCase subclasses. I’m not much of a fan of the unittest style, so perhaps someday someone will create asynctest for pytest :).
  • aiohttp has some really nice built-in test utilities supporting both pytest and unittest.

Recap

Basically, by using pytest-asyncio, testing asyncio code isn’t much different from testing non-asynchronous code. There is still the clunkiness of needing to manually mock coroutine functions and exhaust the event loop when testing code that uses create_task (an open source contribution opportunity, maybe?).


Follow the next parts of this series for debugging and profiling asyncio code.



