How to add context to Celery WorkerLostError logs

For a while now, we've had semi-frequent WorkerLostErrors on one of our Celery queues. These can happen for a few reasons, but one of the most common is a Celery task using too much memory and being killed from outside, typically by a parent process or the operating system's out-of-memory killer. If you're logging errors across your stack, you might see logs that look something like this: WorkerLostError: Worker exited prematurely: signal 9 (SIGKILL).

This is only moderately useful: it doesn't tell you which task was consuming excess memory, so it's hard to track down the culprit. In our case, we had a task_failure signal handler that looked something like this:

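from celery.signals import task_failure

@task_failure.connect
def handle_task_failure(**kwargs):
    # Celery passes the raised exception (among other things) as a keyword argument
    exception = kwargs.get("exception", None)
    # do some logging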

Problem was, this did nothing for WorkerLostErrors: the task's process is killed abruptly rather than shut down gracefully, so it's no longer around to fire its task_failure signal. After quite a bit of digging around in the Celery documentation, I found a solution. Turns out the Celery Request class has failure handler methods of its own, including on_failure. From the docs:

When using the pre-forking worker, the methods on_timeout() and on_failure() are executed in the main worker process. An application may leverage such facility to detect failures which are not detected using celery.app.task.Task.on_failure().
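To see why the main process is the right place to hook in, here is a small stand-alone sketch that uses only the standard library, not Celery. It assumes a POSIX system, and the child script and variable names are made up for illustration. A parent starts a child that registers cleanup code, then kills it with SIGKILL, roughly what happens to a pool process that runs out of memory.

```python
import signal
import subprocess
import sys

# Hypothetical child program: registers cleanup code, then blocks until killed.
CHILD_SOURCE = """
import atexit, time
atexit.register(lambda: print("cleanup ran", flush=True))
print("ready", flush=True)
time.sleep(60)
"""

child = subprocess.Popen(
    [sys.executable, "-c", CHILD_SOURCE],
    stdout=subprocess.PIPE,
    text=True,
)
first_line = child.stdout.readline().strip()  # wait until the child is running
child.send_signal(signal.SIGKILL)             # cannot be caught or handled by the child
remaining_output, _ = child.communicate()

# The child never got a chance to run its own handler...
cleanup_ran = "cleanup ran" in remaining_output
# ...but the parent can still see how it died (negative return code = signal number).
killed_by = signal.Signals(-child.returncode).name
print(first_line, cleanup_ran, killed_by)
```

The child's cleanup never runs, yet the parent still learns that the child was killed by SIGKILL. The same asymmetry is why a handler that runs in the main worker process can report failures that the task itself cannot.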

Perfect! So the only question is how to implement it. I did something like this:

from celery import Celery, Task
from celery.exceptions import Retry
from celery.worker.request import Request

# `logger` is your application's logger. Ours accepts a `context` keyword;
# the standard library's logging module does not (it uses `extra=` instead).
class FailureLoggingRequest(Request):
    def on_failure(self, exc_info, send_failed_event=True, return_ok=False):
        super().on_failure(
            exc_info, send_failed_event=send_failed_event, return_ok=return_ok
        )  # very important to call super here!!
        if not isinstance(exc_info.exception, Retry):
            # we only want to log if the error is not fixed by retrying
            logger.error(
                "Celery task failed",
                context={
                    # these are just a few of the task context variables available
                    "task": self.task.name,
                    "task_id": self.id,
                    "args": self.args,
                    "kwargs": self.kwargs,
                    "exception": exc_info.exception,
                    "traceback": exc_info.traceback,
                },
            )


class FailureLoggingTask(Task):
    Request = FailureLoggingRequest

app = Celery(
    task_cls=FailureLoggingTask,
    # other config
)

And just like that, we had logging around WorkerLostErrors. We could see which task caused the error, the specific task ID that died, the args and kwargs that the task was called with, etc. Our visibility into Celery issues improved dramatically.
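One caveat if you log with the standard library's logging module: it has no context keyword, and its extra= argument rejects keys that clash with built-in LogRecord attributes such as "args". Here is a minimal sketch of one way around that. The helper failure_context, the ListHandler class, the logger name, and the fake request objects are all hypothetical stand-ins so the snippet runs without a Celery worker.

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger("celery.failures")  # hypothetical logger name


def failure_context(request, exc_info):
    """Build a dict that is safe to pass as `extra=` to stdlib logging.

    Keys are prefixed with "task_" because LogRecord reserves names like "args".
    """
    return {
        "task_name": request.task.name,
        "task_id": request.id,
        "task_args": repr(request.args),
        "task_kwargs": repr(request.kwargs),
        "task_exception": repr(exc_info.exception),
        "task_traceback": str(exc_info.traceback),
    }


class ListHandler(logging.Handler):
    """Collects records in memory so the example can be inspected."""

    def __init__(self):
        super().__init__()
        self.records = []

    def emit(self, record):
        self.records.append(record)


# Stand-ins for the objects available inside Request.on_failure.
fake_request = SimpleNamespace(
    task=SimpleNamespace(name="tasks.resize_image"),
    id="1234",
    args=(42,),
    kwargs={"width": 100},
)
fake_exc_info = SimpleNamespace(
    exception=RuntimeError("Worker exited prematurely: signal 9 (SIGKILL)"),
    traceback="(no traceback available)",
)

handler = ListHandler()
logger.addHandler(handler)
logger.error("Celery task failed", extra=failure_context(fake_request, fake_exc_info))

record = handler.records[0]
print(record.task_name, record.task_id, record.task_args)
```

Inside on_failure you would call failure_context(self, exc_info) instead of building the fake objects. Each key then shows up as an attribute on the log record, which most JSON formatters can pick up.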

There are a lot of other interesting methods on the Celery Request class, and diving into the source taught me a good deal about how Celery works and how I might solve other Celery issues in the future. I recommend diving in yourself if, like me, you've often found Celery a little opaque and mysterious.