import os
import base64
import logging
import pickle
import uuid
import import_string
from datetime import timedelta
from django.core.management import call_command, get_commands
from django.core.exceptions import ImproperlyConfigured
from django.core.cache import caches
from django.db import close_old_connections, transaction
from django.db.utils import InterfaceError, OperationalError
from django.utils.timezone import now, localtime
try:
from celery import Task, shared_task, current_app
from celery.result import AsyncResult
from celery.exceptions import CeleryError, TimeoutError
from celery.utils.time import maybe_iso8601
from kombu.utils import uuid as task_uuid
from kombu import serialization
except ImportError:
raise ImproperlyConfigured('Missing celery library, please install it')
from .celery import CeleryQueueEnum
from .config import settings
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
def default_unique_key_generator(task, prefix, task_args, task_kwargs):
    """
    Build a deterministic key identifying one task invocation.

    The task args/kwargs are serialized with the app's configured task serializer and combined with
    ``prefix`` and the task name into a name-based UUID (UUID5), so equal inputs give equal keys.

    :param task: celery task instance (provides ``name`` and the app configuration)
    :param prefix: key namespace prefix
    :param task_args: positional task arguments or ``None``
    :param task_kwargs: keyword task arguments or ``None``
    :return: UUID string
    """
    task_args = task_args or ()
    task_kwargs = task_kwargs or {}
    _, _, data = serialization.dumps(
        (list(task_args), task_kwargs), task._get_app().conf.task_serializer,
    )
    if isinstance(data, bytes):
        # Binary serializers (e.g. pickle, msgpack) return bytes, which str.join rejects. The hex form is
        # deterministic and lossless for any byte content. Text payloads (e.g. json) are left untouched,
        # so their keys are unchanged.
        data = data.hex()
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, ':'.join((prefix, task.name, data))))
class NotTriggeredCeleryError(CeleryError):
    """Raised when the result of a task invocation is requested before the celery task was triggered."""
class IgnoredResult:
    """
    Stand-in result for an invocation that was skipped and never sent to celery.

    It mimics the parts of the celery result interface used by the wrapper: it has no id, reports the
    ``IGNORED`` state, yields ``None`` and is neither successful nor failed.
    """

    state = 'IGNORED'
    id = None

    @property
    def task_id(self):
        """Alias of ``id``; always ``None`` here."""
        return self.id

    def get(self, *args, **kwargs):
        """Return ``None`` immediately regardless of the given arguments."""
        return None

    def failed(self):
        """A skipped invocation is never reported as failed."""
        return False

    def successful(self):
        """A skipped invocation is never reported as successful."""
        return False
class ResultWrapper:
    """
    Result handed to the requester of a task invocation.

    The wrapper is created before the celery task is triggered (triggering can be deferred, e.g. until a
    transaction commits). The real result is attached later with ``set_result``. Until then the state is
    ``WAITING``, ``id`` is ``None``, and ``get``/``then`` raise ``NotTriggeredCeleryError``.
    The ``on_*`` methods are no-op extension points for subclasses.
    """

    def __init__(self, invocation_id, task, args, kwargs, options, result=None):
        self._invocation_id = invocation_id
        self._result = result
        self._task = task
        self._args = args
        self._kwargs = kwargs
        self._options = options

    def on_apply(self):
        """Extension point: the task was applied by the requester."""

    def on_trigger(self):
        """Extension point: the task was triggered (placed in the queue)."""

    def on_unique(self):
        """Extension point: the invocation was joined to an already active unique task."""

    def on_ignored(self):
        """Extension point: the invocation was ignored because the task recently succeeded."""

    def on_timeout(self):
        """Extension point: waiting for the result timed out."""

    def set_result(self, result):
        """Attach the underlying result object (celery result or ``IgnoredResult``)."""
        self._result = result

    def set_options(self, options):
        """Replace the stored invocation options."""
        self._options = options

    def _require_result(self):
        # Shared guard so every accessor fails the same way before the task is triggered.
        if self._result is None:
            raise NotTriggeredCeleryError('Celery task has not been triggered yet')
        return self._result

    def then(self, *args, **kwargs):
        """
        Delegate ``then`` to the underlying result.

        :raises NotTriggeredCeleryError: if the task has not been triggered yet
        """
        return self._require_result().then(*args, **kwargs)

    def get(self, *args, **kwargs):
        """
        Wait for and return the task result (arguments are forwarded to the underlying ``get``).

        On ``TimeoutError`` the task's ``on_invocation_timeout`` hook is notified and the error is re-raised.

        :raises NotTriggeredCeleryError: if the task has not been triggered yet
        """
        result = self._require_result()
        try:
            return result.get(*args, **kwargs)
        except TimeoutError as ex:
            self.timeout(ex)
            raise

    def timeout(self, ex):
        """Notify the task that waiting for this invocation timed out with ``ex``."""
        self._task.on_invocation_timeout(
            self._invocation_id, self._args, self._kwargs, self.task_id, ex, self._options, self
        )

    @property
    def state(self):
        """``WAITING`` before the task is triggered, otherwise the underlying result state."""
        return 'WAITING' if self._result is None else self._result.state

    def successful(self):
        return self._result is not None and self._result.successful()

    def failed(self):
        return self._result is not None and self._result.failed()

    @property
    def id(self):
        """Celery task id, or ``None`` before the task is triggered."""
        return None if self._result is None else self._result.task_id

    @property
    def task_id(self):
        return self.id
class DjangoTask(Task):
    """
    Celery task base class with Django integration.

    On top of the plain celery ``Task`` it provides:

    * invocation hooks (``on_invocation_*``) called on the requester side and task hooks (``on_task_*``)
      called when the task runs,
    * unique tasks (``unique = True``): while a task with the same input is pending, new invocations are
      joined to it instead of creating a new task (the pending task id is stored in the Django cache),
    * ignoring invocations for ``ignore_task_timedelta`` after a successful run,
    * triggering after a DB transaction commit (``apply_async_on_commit`` / ``delay_on_commit``),
    * retry countdowns taken from the ``default_retry_delays`` list,
    * computed ``stale_time_limit`` and ``expires`` values.
    """

    abstract = True

    # Retry delays (seconds) as a list: the countdown of a retry is the item whose index is the attempt
    # number (request.retries).
    default_retry_delays = None

    # Unique task: if a task with the same input already exists, no extra task is created and the existing
    # task result is returned.
    unique = False
    unique_key_generator = default_unique_key_generator
    _stackprotected = True

    ignore_task_after_success = False
    ignore_task_timedelta = None
    result_wrapper_class = ResultWrapper
    max_queue_waiting_time = None
    stale_time_limit = None

    def __new__(cls, *args, **kwargs):
        # A CeleryQueueEnum member may be used as ``queue``: it is replaced by its queue name and its default
        # task kwargs fill class attributes that are still unset (None).
        queue = getattr(cls, 'queue', None)
        if isinstance(queue, CeleryQueueEnum):
            cls.queue = queue.queue_name
            for k, v in queue.default_task_kwargs.items():
                if getattr(cls, k, None) is None:
                    setattr(cls, k, v)
        return super().__new__(cls, *args, **kwargs)

    def on_invocation_apply(self, invocation_id, args, kwargs, options, result):
        """
        Method is called when task was applied with the requester.

        :param invocation_id: UUID of the requester invocation
        :param args: input task args
        :param kwargs: input task kwargs
        :param options: input task options
        :param result: result which will be finally returned
        """
        result.on_apply()

    def on_invocation_trigger(self, invocation_id, args, kwargs, task_id, options, result):
        """
        Task has been triggered and placed in the queue.

        :param invocation_id: UUID of the requester invocation
        :param args: input task args
        :param kwargs: input task kwargs
        :param task_id: UUID of the celery task
        :param options: input task options
        :param result: result which will be finally returned
        """
        result.on_trigger()

    def on_invocation_unique(self, invocation_id, args, kwargs, task_id, options, result):
        """
        Task has been triggered but the same task is already active.
        Therefore only pointer to the active task is returned.

        :param invocation_id: UUID of the requester invocation
        :param args: input task args
        :param kwargs: input task kwargs
        :param task_id: UUID of the celery task
        :param options: input task options
        :param result: result which will be finally returned
        """
        result.on_unique()

    def on_invocation_ignored(self, invocation_id, args, kwargs, task_id, options, result):
        """
        Task has been triggered but the task has set ignore_task_timedelta
        and task was successfully completed in this timeout.
        Therefore no new task is invoked.

        :param invocation_id: UUID of the requester invocation
        :param args: input task args
        :param kwargs: input task kwargs
        :param task_id: UUID of the celery task
        :param options: input task options
        :param result: result which will be finally returned
        """
        result.on_ignored()

    def on_invocation_timeout(self, invocation_id, args, kwargs, task_id, ex, options, result):
        """
        Waiting for the task result timed out.

        :param invocation_id: UUID of the requester invocation
        :param args: input task args
        :param kwargs: input task kwargs
        :param task_id: UUID of the celery task
        :param ex: celery TimeoutError
        :param options: input task options
        :param result: result which will be finally returned
        """
        result.on_timeout()

    def on_task_start(self, task_id, args, kwargs):
        """
        Task has been started with worker.

        :param task_id: UUID of the celery task
        :param args: input task args
        :param kwargs: input task kwargs
        """
        pass

    def on_task_retry(self, task_id, args, kwargs, exc, eta):
        """
        Task failed but will be retried.

        :param task_id: UUID of the celery task
        :param args: task args
        :param kwargs: task kwargs
        :param exc: raised exception which caused retry
        :param eta: time to next retry
        """
        pass

    def on_task_failure(self, task_id, args, kwargs, exc, einfo):
        """
        Task failed and will not be retried.

        :param task_id: UUID of the celery task
        :param args: task args
        :param kwargs: task kwargs
        :param exc: raised exception
        :param einfo: exception traceback
        """
        pass

    def on_task_success(self, task_id, args, kwargs, retval):
        """
        Task was successful.

        :param task_id: UUID of the celery task
        :param args: task args
        :param kwargs: task kwargs
        :param retval: task result
        """
        pass

    def on_failure(self, exc, task_id, args, kwargs, einfo):
        super().on_failure(exc, task_id, args, kwargs, einfo)
        self.on_task_failure(task_id, args, kwargs, exc, einfo)
        # The task is finished, so a new unique invocation with the same input may be created.
        self._clear_unique_key(args, kwargs)

    def on_success(self, retval, task_id, args, kwargs):
        super().on_success(retval, task_id, args, kwargs)
        self.on_task_success(task_id, args, kwargs, retval)
        self._clear_unique_key(args, kwargs)
        if self.ignore_task_after_success:
            assert self.ignore_task_timedelta is not None, (
                'ignore_task_timedelta must be set for ignore_task_after_success'
            )
            self._set_ignore_task_after_success(args, kwargs)

    def __call__(self, *args, **kwargs):
        """
        Overrides parent which works with thread stack. We don't want to allow changing the context which was
        generated in one of the apply methods. Calling the task directly is disallowed.
        """
        req = self.request_stack.top
        if not req or req.called_directly:
            raise CeleryError(
                'Task cannot be called directly. Please use apply, apply_async or apply_async_on_commit methods'
            )

        if req._protected:
            raise CeleryError('Request is protected')
        # request is protected (no usage in celery but get from function _install_stack_protection in
        # celery library)
        req._protected = 1

        # Every set attr is sent here
        self.on_task_start(req.id, args, kwargs)
        return self._start(*args, **kwargs)

    def _start(self, *args, **kwargs):
        return self.run(*args, **kwargs)

    def _get_unique_key(self, task_args, task_kwargs):
        """Return the cache key marking a pending unique task, or ``None`` when the task is not unique."""
        return self.unique_key_generator(
            settings.UNIQUE_TASK_KEY_PREFIX, task_args, task_kwargs
        ) if self.unique else None

    def _get_ignore_task_key(self, task_args, task_kwargs):
        """Return the cache key marking a recently succeeded task (used by the ignore-after-success logic)."""
        return self.unique_key_generator(
            settings.IGNORE_TASK_KEY_PREFIX, task_args, task_kwargs
        )

    def is_processing(self, args=None, kwargs=None):
        """
        Return True if a unique task with the given input is currently pending.

        :raises CeleryError: if the task is not unique
        """
        unique_key = self._get_unique_key(args, kwargs)
        if unique_key is None:
            raise CeleryError('Process check can be performed for only unique tasks')
        return caches[settings.CACHE_NAME].get(unique_key) is not None

    def _ignore_task_after_success(self, key):
        return key and caches[settings.CACHE_NAME].get(key)

    def _set_ignore_task_after_success(self, task_args, task_kwargs, ignore_task_timedelta=None):
        """Store the ignore-task cache key for ``ignore_task_timedelta`` (defaults to the class attribute)."""
        ignore_task_timedelta = self.ignore_task_timedelta if ignore_task_timedelta is None else ignore_task_timedelta
        if ignore_task_timedelta:
            ignore_task_key = self._get_ignore_task_key(task_args, task_kwargs)
            if ignore_task_key:
                current_time = localtime()
                # Adding and subtracting the current time converts the delta to a plain timedelta, so
                # calendar-relative delta objects can be used as well.
                caches[settings.CACHE_NAME].add(
                    ignore_task_key,
                    True,
                    (current_time + ignore_task_timedelta - current_time).total_seconds()
                )

    def _clear_unique_key(self, task_args, task_kwargs):
        unique_key = self._get_unique_key(task_args, task_kwargs)
        if unique_key:
            caches[settings.CACHE_NAME].delete(unique_key)

    def _get_unique_task_id(self, unique_key, task_id, stale_time_limit):
        """
        Return the id of the task representing this invocation.

        ``task_id`` is returned when the unique key was stored (or uniqueness is not used / the app runs
        eagerly). Otherwise the id of the already pending task stored under ``unique_key`` is returned.
        """
        if unique_key and not stale_time_limit:
            raise CeleryError('For unique tasks is require set task stale_time_limit')

        if unique_key and not self._get_app().conf.task_always_eager:
            if caches[settings.CACHE_NAME].add(unique_key, task_id, stale_time_limit):
                return task_id
            else:
                unique_task_id = caches[settings.CACHE_NAME].get(unique_key)
                # The key may have expired between add and get; then try to store it again.
                return (
                    unique_task_id if unique_task_id
                    else self._get_unique_task_id(unique_key, task_id, stale_time_limit)
                )
        else:
            return task_id

    def _compute_eta(self, eta, countdown, trigger_time):
        if countdown is not None:
            return trigger_time + timedelta(seconds=countdown)
        elif eta:
            return eta
        else:
            return trigger_time

    def _compute_expires(self, expires, time_limit, stale_time_limit, trigger_time):
        """
        Return the absolute expiry of the task, or ``None``.

        Numeric ``expires`` (seconds) is anchored to ``trigger_time``. Without ``expires``, the window between
        stale time limit and time limit is used when both are known.
        """
        expires = self.expires if expires is None else expires
        if expires is not None:
            # Celery accepts int or float seconds; anchor both to the trigger time.
            return trigger_time + timedelta(seconds=expires) if isinstance(expires, (int, float)) else expires
        elif stale_time_limit is not None and time_limit is not None:
            return trigger_time + timedelta(seconds=stale_time_limit - time_limit)
        else:
            return None

    def _get_time_limit(self, time_limit):
        if time_limit is not None:
            return time_limit
        elif self.time_limit is not None:
            return self.time_limit
        else:
            return self._get_app().conf.task_time_limit

    def _get_soft_time_limit(self, soft_time_limit):
        if soft_time_limit is not None:
            return soft_time_limit
        elif self.soft_time_limit is not None:
            return self.soft_time_limit
        else:
            return self._get_app().conf.task_soft_time_limit

    def _get_stale_time_limit(self, expires, time_limit, stale_time_limit, trigger_time):
        """
        Return how long (seconds) the invocation may stay unfinished, or ``None`` when it cannot be bounded.

        Explicit values win (argument, class attribute, settings). Otherwise the limit is derived from the
        time limit, the maximal queue waiting time and the retry configuration.
        """
        max_queue_waiting_time = self.max_queue_waiting_time or settings.DEFAULT_TASK_MAX_QUEUE_WAITING_TIME

        if stale_time_limit is not None:
            return stale_time_limit
        elif self.stale_time_limit is not None:
            return self.stale_time_limit
        elif settings.DEFAULT_TASK_STALE_TIME_LIMIT is not None:
            return settings.DEFAULT_TASK_STALE_TIME_LIMIT
        elif time_limit is not None and max_queue_waiting_time:
            autoretry_for = getattr(self, 'autoretry_for', None)
            if autoretry_for and self.default_retry_delays:
                return (
                    (time_limit + max_queue_waiting_time) * (len(self.default_retry_delays) + 1)
                    + sum(self.default_retry_delays)
                )
            elif autoretry_for:
                if self.max_retries is None:
                    # Unlimited retries: the total lifetime cannot be bounded (previously a TypeError).
                    return None
                return (
                    (time_limit + max_queue_waiting_time + self.default_retry_delay) * self.max_retries
                    + time_limit + max_queue_waiting_time
                )
            else:
                return time_limit + max_queue_waiting_time
        else:
            return None

    def _apply_and_get_result(self, result, args, kwargs, invocation_id, is_async=False, **options):
        if is_async:
            return self._call_super_apply_async(
                args=args, kwargs=kwargs, is_async=is_async, invocation_id=invocation_id, **options
            )
        else:
            return super().apply(
                args=args, kwargs=kwargs, is_async=is_async, invocation_id=invocation_id, **options
            )

    def _trigger(self, result, args, kwargs, invocation_id, task_id=None, eta=None, countdown=None, expires=None,
                 time_limit=None, soft_time_limit=None, stale_time_limit=None, ignore_task_after_success=None,
                 is_async=True, headers=None, **options):
        """
        Compute the final options and either ignore, join (unique) or really apply the invocation.

        The outcome is stored in ``result`` and the matching ``on_invocation_*`` hook is called.
        """
        headers = {} if headers is None else headers
        app = self._get_app()
        task_id = task_id or task_uuid()
        time_limit = self._get_time_limit(time_limit)
        trigger_time = now()
        eta = self._compute_eta(eta, countdown, trigger_time)
        # The countdown was folded into eta and must not be applied a second time.
        countdown = None
        stale_time_limit = self._get_stale_time_limit(expires, time_limit, stale_time_limit, trigger_time)
        expires = self._compute_expires(expires, time_limit, stale_time_limit, trigger_time)

        options.update(dict(
            invocation_id=invocation_id,
            task_id=task_id,
            trigger_time=trigger_time,
            time_limit=time_limit,
            soft_time_limit=self._get_soft_time_limit(soft_time_limit),
            eta=eta,
            countdown=countdown,
            expires=expires,
            is_async=is_async,
            stale_time_limit=stale_time_limit
        ))

        headers.update(dict(
            apply_time=options['apply_time'].isoformat(),
            trigger_time=trigger_time.isoformat(),
            stale_time_limit=stale_time_limit
        ))

        ignore_task_after_success = (
            ignore_task_after_success if ignore_task_after_success is not None
            else self.ignore_task_timedelta is not None
        )
        ignore_task_after_success_key = (
            self._get_ignore_task_key(args, kwargs) if ignore_task_after_success else None
        )

        if self._ignore_task_after_success(ignore_task_after_success_key):
            result.set_options(options)
            result.set_result(IgnoredResult())
            self.on_invocation_ignored(invocation_id, args, kwargs, task_id, options, result)
        else:
            unique_key = self._get_unique_key(args, kwargs)
            unique_task_id = self._get_unique_task_id(unique_key, task_id, stale_time_limit)
            if is_async and unique_task_id != task_id:
                options['task_id'] = unique_task_id
                result.set_options(options)
                self.on_invocation_unique(invocation_id, args, kwargs, unique_task_id, options, result)
                result.set_result(AsyncResult(unique_task_id, app=app))
            else:
                result.set_options(options)
                self.on_invocation_trigger(invocation_id, args, kwargs, task_id, options, result)
                result.set_result(self._apply_and_get_result(result, args, kwargs, headers=headers, **options))

    def _first_apply(self, args=None, kwargs=None, invocation_id=None, is_async=True, is_on_commit=False, using=None,
                     **options):
        """
        Create the result wrapper for a new invocation and trigger it now or after the transaction commit.

        :return: instance of ``result_wrapper_class``
        """
        invocation_id = invocation_id or task_uuid()
        apply_time = now()
        app = self._get_app()
        queue = str(options.get('queue', getattr(self, 'queue', app.conf.task_default_queue)))

        options.update(dict(
            queue=queue,
            is_async=is_async,
            invocation_id=invocation_id,
            apply_time=apply_time,
            is_on_commit=is_on_commit,
            using=using,
        ))
        result = self.result_wrapper_class(invocation_id, self, args, kwargs, options)
        self.on_invocation_apply(invocation_id, args, kwargs, options, result)

        if is_on_commit:
            self_inst = self

            def _apply_on_commit():
                self_inst._trigger(result, args=args, kwargs=kwargs, **options)

            transaction.on_commit(_apply_on_commit, using=using)
        else:
            self._trigger(result, args=args, kwargs=kwargs, **options)
        return result

    def apply_async_on_commit(self, args=None, kwargs=None, using=None, **options):
        """Apply the task asynchronously once the transaction on the ``using`` connection commits."""
        return self._first_apply(args=args, kwargs=kwargs, is_async=True, is_on_commit=True, using=using, **options)

    def apply(self, args=None, kwargs=None, **options):
        """
        Execute the task synchronously.

        Calls that already carry internal options (``retries`` or ``is_async``) go straight to celery.
        """
        if 'retries' in options or 'is_async' in options:
            return super().apply(args=args, kwargs=kwargs, **options)
        else:
            return self._first_apply(args=args, kwargs=kwargs, is_async=False, **options)

    def _call_super_apply_async(self, args=None, kwargs=None, task_id=None, **options):
        """
        Apply async can be called from two sources: by hand from the executor or automatically via the retry
        function. If the retry function is used, the id is taken from the request. Sometimes we need to change
        options before it (some options are not transferred to the worker, for example properties; therefore
        not all options changed in the _first_apply method are available in the retry method).
        """
        # kombu's uuid generator: the ``uuid`` name in this module is the stdlib module, which is not callable.
        task_id = task_id or self.request.id or task_uuid()
        if settings.AUTO_SQS_MESSAGE_GROUP_ID and 'MessageGroupId' not in options:
            options['MessageGroupId'] = task_id
        return super().apply_async(args=args, kwargs=kwargs, task_id=task_id, **options)

    def apply_async(self, args=None, kwargs=None, **options):
        """
        Apply the task asynchronously.

        Inside a running request of this task (e.g. a retry) celery's ``apply_async`` is used directly;
        otherwise a new invocation is created. On DB interface/operational errors, old DB connections are
        closed before the error is re-raised.
        """
        try:
            if self.request.id:
                return self._call_super_apply_async(args=args, kwargs=kwargs, **options)
            else:
                return self._first_apply(
                    args=args, kwargs=kwargs, is_async=True, **options
                )
        except (InterfaceError, OperationalError) as ex:
            logger.warning('Closing old database connections, following exception thrown: %s', str(ex))
            close_old_connections()
            raise

    def delay_on_commit(self, *args, **kwargs):
        """
        Shortcut of ``apply_async_on_commit`` with star arguments (``options`` kwarg holds apply options).

        :return: result wrapper of the invocation
        """
        options = kwargs.pop('options', {})
        return self.apply_async_on_commit(args, kwargs, **options)

    def _get_header_from_request(self, name):
        """Read ``name`` from the current request attributes, falling back to the request headers."""
        if hasattr(self.request, name):
            return getattr(self.request, name)
        # Request headers may be None when the message carried none.
        return (self.request.headers or {}).get(name)

    def retry(self, args=None, kwargs=None, exc=None, throw=True,
              eta=None, countdown=None, max_retries=None, default_retry_delays=None, headers=None, **options):
        """
        Retry the task. Extends celery's ``retry`` with ``default_retry_delays`` (list of countdowns indexed
        by the retry number) and keeps the original apply time / stale time limit headers.
        """
        headers = {} if headers is None else headers
        trigger_time = now()
        headers.update(dict(
            apply_time=self._get_header_from_request('apply_time'),
            trigger_time=trigger_time.isoformat(),
            stale_time_limit=self._get_header_from_request('stale_time_limit')
        ))
        # An explicit max_retries=0 must not fall back to the class default.
        max_retries = self.max_retries if max_retries is None else max_retries
        if default_retry_delays or (eta is None and countdown is None and self.default_retry_delays):
            default_retry_delays = self.default_retry_delays if default_retry_delays is None else default_retry_delays
            max_retries = len(default_retry_delays)
            countdown = default_retry_delays[self.request.retries] if self.request.retries < max_retries else None

        if not eta and countdown is None:
            countdown = self.default_retry_delay

        if not eta:
            eta = trigger_time + timedelta(seconds=countdown)

        if max_retries is None or self.request.retries < max_retries:
            # Otherwise the task will fail instead of being retried.
            self.on_task_retry(self.request.id, args, kwargs, exc, eta)

        # Fix bug in celery==5.2.*
        request_execution_options = self.request.as_execution_options()
        if 'expires' in request_execution_options:
            options.setdefault('expires', maybe_iso8601(request_execution_options['expires']))

        return super().retry(
            args=args, kwargs=kwargs, exc=exc, throw=throw,
            eta=eta, max_retries=max_retries, headers=headers, **options
        )

    def apply_async_and_get_result(self, args=None, kwargs=None, timeout=None, propagate=True, **options):
        """
        Apply task in an asynchronous way, wait defined timeout and get AsyncResult or TimeoutError

        :param args: task args
        :param kwargs: task kwargs
        :param timeout: timeout in seconds to wait for result
        :param propagate: propagate or not exceptions from celery task
        :param options: apply_async method options
        :return: AsyncResult or TimeoutError
        """
        result = self.apply_async(args=args, kwargs=kwargs, **options)
        if timeout is None or timeout > 0:
            return result.get(timeout=timeout, propagate=propagate)
        else:
            ex = TimeoutError('The operation timed out.')
            result.timeout(ex)
            raise ex

    def get_command_kwargs(self):
        """Keyword options passed to ``call_command`` by tasks generated from management commands."""
        return {}

    def set_ignore_task(self, ignore_task_timedelta=None):
        """Mark the current request's input as ignored for ``ignore_task_timedelta`` (or the class default)."""
        assert ignore_task_timedelta is not None or self.ignore_task_timedelta is not None, (
            'ignore_task_timedelta must be set'
        )
        self._set_ignore_task_after_success(self.request.args, self.request.kwargs, ignore_task_timedelta)
def obj_to_string(obj):
    """
    Serialize ``obj`` with pickle and return the bytes as MIME-style base64 text.

    The text is produced by ``base64.encodebytes``, so it contains newline-terminated lines.
    """
    pickled = pickle.dumps(obj)
    encoded = base64.encodebytes(pickled)
    return encoded.decode('utf8')
def string_to_obj(obj_string):
    """
    Decode base64 text and unpickle the contained object.

    SECURITY: ``pickle.loads`` can execute arbitrary code; only pass strings produced by trusted code.
    """
    raw = base64.decodebytes(obj_string.encode('utf8'))
    return pickle.loads(raw)
def get_command_task_name(command_name):
    """
    Return the celery task name generated for the management command ``command_name``.

    The name has the form ``command.<app_name>.<command_name>``, where ``app_name`` is the value Django's
    ``get_commands()`` maps the command to.

    :raises ImproperlyConfigured: if the management command does not exist
    """
    available_commands = get_commands()
    if command_name not in available_commands:
        raise ImproperlyConfigured(f'Cannot generate celery task from command "{command_name}", command not found')
    return f'command.{available_commands[command_name]}.{command_name}'
def get_django_command_task(command_name):
    """
    Return the registered celery task generated for the management command ``command_name``.

    :raises ImproperlyConfigured: if the command does not exist or no task is registered for it
    """
    task_name = get_command_task_name(command_name)
    registered_tasks = current_app.tasks
    if task_name not in registered_tasks:
        raise ImproperlyConfigured(
            'Command was not found please check DJANGO_CELERY_EXTENSIONS_AUTO_GENERATE_TASKS_DJANGO_COMMANDS setting'
        )
    return registered_tasks[task_name]
def auto_convert_commands_to_tasks():
    """
    Register a celery task for every command listed in ``settings.AUTO_GENERATE_TASKS_DJANGO_COMMANDS``.

    Task options start as ``bind=True, name=<generated name>, ignore_result=True``. They are overridden by
    ``AUTO_GENERATE_TASKS_DEFAULT_CELERY_KWARGS`` and then by the per-command options from the setting.
    String entries of ``autoretry_for`` are imported as exception classes. The generated task runs the
    command via ``call_command`` with ``command_args`` as positional arguments, the task itself as
    ``celery_task``, and ``get_command_kwargs()`` as keyword options.
    """
    from django.core.management.base import BaseCommand

    # ``celery_task`` is not a declared command option; registering it as a stealth option lets
    # call_command accept it.
    BaseCommand.stealth_options = (*BaseCommand.stealth_options, 'celery_task')

    def register_command_task(command_name, task_name):
        # Defining the task inside this helper binds ``command_name`` per command (no late binding).
        task_options = {
            'bind': True,
            'name': task_name,
            'ignore_result': True,
            **(settings.AUTO_GENERATE_TASKS_DEFAULT_CELERY_KWARGS or {}),
            **(settings.AUTO_GENERATE_TASKS_DJANGO_COMMANDS[command_name] or {}),
        }
        if 'autoretry_for' in task_options:
            task_options['autoretry_for'] = [
                exc_class if not isinstance(exc_class, str) else import_string(exc_class)
                for exc_class in task_options['autoretry_for']
            ]

        @shared_task(**task_options)
        def command_task(self, command_args=None, **kwargs):
            # Extra keyword arguments are accepted but not forwarded; options come from get_command_kwargs().
            positional = [] if command_args is None else command_args
            call_command(
                command_name,
                *positional,
                settings=os.environ.get('DJANGO_SETTINGS_MODULE'),
                celery_task=self,
                **self.get_command_kwargs()
            )

    for command_name in settings.AUTO_GENERATE_TASKS_DJANGO_COMMANDS:
        register_command_task(command_name, get_command_task_name(command_name))
def init_celery_app():
    """
    Import the configured celery settings module and optionally autodiscover tasks.

    ``settings.CELERY_SETTINGS`` (when set) is imported by dotted path via ``import_string``. When
    ``settings.CELERY_AUTODISCOVER`` is set, tasks are autodiscovered on the app with ``force=True``.
    """
    celery_settings_path = settings.CELERY_SETTINGS
    if celery_settings_path:
        import_string(celery_settings_path)

    if not settings.CELERY_AUTODISCOVER:
        return

    # NOTE(review): the Celery app is imported from the project-specific ``common.celery`` module; confirm
    # this path exists in every project that enables CELERY_AUTODISCOVER.
    from common.celery import app
    app.autodiscover_tasks(force=True)