From 8747b74aaf7cadbeca3dd4515e5a2dbf5fac6f1e Mon Sep 17 00:00:00 2001 From: Moises Hiraldo Date: Mon, 6 Feb 2017 16:07:00 +0000 Subject: [PATCH] Enable callbacks for the celery dispatcher #15 --- t/unit/dispatch/test_celery.py | 2 +- t/unit/test_request.py | 16 ++++++++++++---- t/unit/test_tasks.py | 9 ++++++--- thorn/dispatch/celery.py | 2 +- thorn/request.py | 3 +++ thorn/tasks.py | 4 ++-- 6 files changed, 25 insertions(+), 11 deletions(-) diff --git a/t/unit/dispatch/test_celery.py b/t/unit/dispatch/test_celery.py index d50debb..4edcd72 100644 --- a/t/unit/dispatch/test_celery.py +++ b/t/unit/dispatch/test_celery.py @@ -17,7 +17,7 @@ def test_send(self, patching): res = Dispatcher().send( event, payload, user, timeout=3.03, kw=9, context=context) send_event.s.assert_called_once_with( - event, payload, user.pk, 3.03, context, + event, payload, user.pk, 3.03, context, kw=9 ) send_event.s().apply_async.assert_called_once_with() assert res is send_event.s().apply_async() diff --git a/t/unit/test_request.py b/t/unit/test_request.py index 8d1a377..722f364 100644 --- a/t/unit/test_request.py +++ b/t/unit/test_request.py @@ -12,10 +12,15 @@ from conftest import DEFAULT_RECIPIENT_VALIDATORS +class PickableMock(Mock): + def __reduce__(self): + return (Mock, ()) + + def mock_req(event, url, **kwargs): - kwargs.setdefault('on_success', Mock(name='on_success')) - kwargs.setdefault('on_timeout', Mock(name='on_timeout')) - kwargs.setdefault('on_error', Mock(name='on_error')) + kwargs.setdefault('on_success', PickableMock(name='on_success')) + kwargs.setdefault('on_timeout', PickableMock(name='on_timeout')) + kwargs.setdefault('on_error', PickableMock(name='on_error')) subscriber = Mock(name='subscriber') subscriber.url = url subscriber.content_type = MIME_JSON @@ -176,6 +181,9 @@ def test_as_dict(self): 'retry_max': self.req.retry_max, 'recipient_validators': DEFAULT_RECIPIENT_VALIDATORS, 'allow_keepalive': self.req.allow_keepalive, + 'on_success': self.req.on_success, + 'on_error': self.req.on_error, + 'on_timeout': self.req.on_timeout, } def test_urlident(self): @@ -199,7 +207,7 @@ def as_dict(self): return {'value': 808} self.req._dispatcher = '' self.req.subscriber = Subscriber() - r2 = pickle.loads(pickle.dumps(self.req)) + r2 = pickle.loads(pickle.dumps(self.req, -1)) assert r2.app is self.app def test_repr(self): diff --git a/t/unit/test_tasks.py b/t/unit/test_tasks.py index 7e49f24..35f7a31 100644 --- a/t/unit/test_tasks.py +++ b/t/unit/test_tasks.py @@ -98,7 +98,8 @@ def test_success(self, app_or_default): id=self.req.id, timeout=self.req.timeout, retry=self.req.retry, retry_max=self.req.retry_max, retry_delay=self.req.retry_delay, recipient_validators=DEFAULT_RECIPIENT_VALIDATORS, - allow_keepalive=True, + allow_keepalive=True, on_error=None, on_success=None, + on_timeout=None ) _Request().dispatch.assert_called_once_with( session=self.session, propagate=_Request().retry) @@ -116,7 +117,8 @@ def test_when_keepalive_disabled(self, app_or_default): id=self.req.id, timeout=self.req.timeout, retry=self.req.retry, retry_max=self.req.retry_max, retry_delay=self.req.retry_delay, recipient_validators=DEFAULT_RECIPIENT_VALIDATORS, - allow_keepalive=False, + allow_keepalive=False, on_error=None, on_success=None, + on_timeout=None ) _Request().dispatch.assert_called_once_with( session=self.session, propagate=_Request().retry) @@ -133,7 +135,8 @@ def test_success__with_user(self, app_or_default): id=self.req2.id, timeout=self.req2.timeout, retry=self.req2.retry, retry_max=self.req2.retry_max, retry_delay=self.req2.retry_delay, recipient_validators=DEFAULT_RECIPIENT_VALIDATORS, - allow_keepalive=True, + allow_keepalive=True, on_error=None, on_success=None, + on_timeout=None ) _Request().dispatch.assert_called_once_with( session=self.session, propagate=_Request().retry) diff --git a/thorn/dispatch/celery.py b/thorn/dispatch/celery.py index 1ad53eb..a514d98 100644 --- a/thorn/dispatch/celery.py +++ b/thorn/dispatch/celery.py @@ -40,7 +40,7 @@ def send(self, event, payload, sender, timeout=None, context=None, **kwargs): return send_event.s( event, payload, - sender.pk if sender else sender, timeout, context, + sender.pk if sender else sender, timeout, context, **kwargs ).apply_async() def flush_buffer(self): diff --git a/thorn/request.py b/thorn/request.py index 7e5b0f7..611c402 100644 --- a/thorn/request.py +++ b/thorn/request.py @@ -207,6 +207,9 @@ def as_dict(self): self._recipient_validators, ), 'allow_keepalive': self.allow_keepalive, + 'on_success': self.on_success, + 'on_error': self.on_error, + 'on_timeout': self.on_timeout, } def annotate_headers(self, extra_headers): diff --git a/thorn/tasks.py b/thorn/tasks.py index 1455ebc..d37217e 100644 --- a/thorn/tasks.py +++ b/thorn/tasks.py @@ -17,7 +17,7 @@ def _worker_dispatcher(): @shared_task(ignore_result=True) -def send_event(event, payload, sender, timeout, context={}): +def send_event(event, payload, sender, timeout, context={}, **kwargs): # type: (str, Dict, Any, float, Dict) -> None """Task called by process dispatching the event. @@ -26,7 +26,7 @@ def send_event(event, payload, sender, timeout, context={}): HTTP requests in batches (``dispatch_requests -> dispatch_request``). """ _worker_dispatcher().send( - event, payload, sender, timeout=timeout, context=context) + event, payload, sender, timeout=timeout, context=context, **kwargs) @shared_task(ignore_result=True)