'''
This module contains all the Celery tasks for the Script Audit module.
'''

from datetime import datetime
import logging

from celery import states
from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult

from MNF import MNF_celery, settings
from scriptAudit.mnf_script_audit import NeutralAudit
from scriptAudit.models import ScriptAuditModel
from scriptAudit import exceptions as na_exc
from utils.abs_classes import GenericCeleryTask
from centralisedFileSystem.models import Script


# Module-level logger named after this module; used by the task hooks below.
logger = logging.getLogger(__name__)


class NeutralAuditTask(GenericCeleryTask):
    '''
    A celery task for running Script Audit.

    Invoked with a ``script_id`` keyword argument. The task records its
    progress in a ``ScriptAuditModel`` row keyed by the script and the
    celery task id, runs ``NeutralAudit`` on the script, and keeps that row
    up to date from the celery hooks (``on_success``, ``on_failure``,
    ``on_retry``).

    Celery registers one shared instance of a task class, so the per-run
    attributes ``script_id`` and ``audit_model`` are assigned in ``run`` and
    removed again in ``after_return``.
    '''

    name = "ScriptAudit.tasks.NeutralAuditTask"
    soft_time_limit = 300

    throws = (na_exc.ScriptAuditTaskException,)

    def __init__(self) -> None:
        super().__init__()
        # Declarations only: these annotations do not create the attributes.
        # They are assigned in `run` and deleted in `after_return`.
        self.script_id: str
        self.audit_model: ScriptAuditModel

    def to_run(self) -> bool:
        '''
        Check whether the audit for ``self.audit_model.script`` may proceed.

        Returns:
            True when the audit should run (every other outcome raises).

        Raises:
            na_exc.AlreadyAudited: ``self.audit_model.status`` is ``"SUCCESS"``.
            na_exc.AlreadyRunning: another ``ScriptAuditModel`` row for the
                same script, belonging to a different celery task, has
                status ``"STARTED"``.
        '''
        script_id = self.audit_model.script.id

        # NOTE(review): `audit_model` is the row created/updated for *this*
        # task in `run` just before this call, so its status is this task's
        # own current state; this check looks unable to detect an earlier
        # successful audit of the same script. Confirm intended semantics.
        if self.audit_model.status == "SUCCESS":
            raise na_exc.AlreadyAudited(script_id)

        # `.exists()` lets the database stop at the first matching row
        # instead of counting all of them.
        audit_started = ScriptAuditModel.objects.filter(
            script=self.audit_model.script,
            status="STARTED",
        ).exclude(
            celery_id=self.request.id,
        ).exists()

        if audit_started:
            raise na_exc.AlreadyRunning(script_id)

        logger.info("Script %s exists but not Audited... ", script_id)
        return True

    def after_return(self, status, retval, task_id, args, kwargs, einfo) -> None:
        '''
        Celery hook run after every execution: clear the per-run attributes.

        Each attribute is removed independently; one that was never assigned
        during this run is only logged.
        '''
        try:
            for attr_name in ("script_id", "audit_model"):
                try:
                    delattr(self, attr_name)
                except AttributeError as att_err:
                    logger.info("`%s` not deleted as it was not decleared.", att_err)

        finally:
            logger.info("%s : %s exiting...", self.name, task_id)

    def on_failure(self, exc, task_id, args, kwargs, einfo) -> None:
        '''
        Celery hook run when ``run`` raises.

        Exceptions accepted by ``is_expected_exceptions`` set the task state
        to ``REJECTED``. In every case the resulting task state and the
        exception are then written to ``self.audit_model`` when it exists.
        '''
        if self.is_expected_exceptions(exc):
            self.update_state(task_id, states.REJECTED)
            logger.warning(
                "`%s` failed without logging in `ScriptAuditModel` with : %s",
                task_id,
                exc,
            )

        # Deliberately no early return: the row created in `run` still needs
        # its final state, otherwise it keeps the status it was created with
        # (e.g. "STARTED", which `to_run` treats as an audit in progress for
        # this script).
        try:
            self.audit_model.status = AsyncResult(task_id).status
            self.audit_model.results = exc
            self.audit_model.save()

        except AttributeError as att_err:
            # `audit_model` is unset when `run` raised before creating it
            # (missing script id, or no such script).
            logger.info("`%s` Model does not exist! It is not an error.", att_err)

        if settings.DEBUG:
            logger.debug(einfo)

    def on_success(self, retval, task_id, args, kwargs) -> None:
        '''
        Celery hook run when ``run`` returns normally.

        Stores the task state and return value on ``self.audit_model`` and
        updates the script's ``last_audited_on`` timestamp.
        '''
        self.audit_model.status = AsyncResult(task_id).status
        self.audit_model.results = retval

        self.audit_model.save()

        # NOTE(review): naive local time; if the project runs Django with
        # USE_TZ=True, `django.utils.timezone.now()` may be intended — confirm.
        self.audit_model.script.last_audited_on = datetime.now()
        self.audit_model.script.save()

        logger.info(
            "ScriptAuditTask %s executed successfuly for script %s.",
            task_id,
            self.script_id,
        )

    def on_retry(self, exc, task_id, args, kwargs, einfo) -> None:
        '''
        Celery hook run when the task is retried.

        Stores the task state and exception on ``self.audit_model`` and bumps
        its ``retries`` counter, if the model was created for this run.
        '''
        try:
            self.audit_model.status = AsyncResult(task_id).status
            self.audit_model.retries += 1
            self.audit_model.results = exc
            self.audit_model.save()

        except AttributeError as att_err:
            # Same guard as `on_failure`: `run` may have raised before the
            # model was created.
            logger.info("`%s` Model does not exist! It is not an error.", att_err)

        if settings.DEBUG:
            logger.debug(einfo)

    def run(self, *args, **kwargs) -> None:
        '''
        Audit the script identified by the ``script_id`` keyword argument.

        All ``args`` / ``kwargs`` (including ``script_id``) are forwarded to
        ``NeutralAudit``.

        Raises:
            na_exc.ScriptIdNotFound: ``script_id`` is missing or falsy.
            na_exc.ScriptDoesNotExist: no ``Script`` has that id.
            na_exc.AlreadyAudited, na_exc.AlreadyRunning: from ``to_run``.
            na_exc.TimeLimitExeded: the soft time limit was hit while auditing.
        '''
        self.script_id = kwargs.get("script_id", None)

        if not self.script_id:
            raise na_exc.ScriptIdNotFound()

        try:
            script = Script.objects.get(id=self.script_id)

        except Script.DoesNotExist as dne:
            raise na_exc.ScriptDoesNotExist(self.script_id) from dne

        # Look the row up by script + celery id only and pass the status via
        # `defaults`, so a retry of the same task updates its existing row
        # (keeping the `retries` count saved by `on_retry`) instead of
        # creating a new row whenever the task state has changed.
        self.audit_model, _ = ScriptAuditModel.objects.update_or_create(
            script=script,
            celery_id=self.request.id,
            defaults={"status": AsyncResult(self.request.id).status},
        )

        if not self.to_run():
            return

        logger.info("Auditing Script %s", self.script_id)

        try:
            audit = NeutralAudit(*args, **kwargs)
            audit.audit()

        except SoftTimeLimitExceeded as stle:
            raise na_exc.TimeLimitExeded(self.soft_time_limit) from stle


# Register the task class with `MNF_celery.app` (presumably the project's
# Celery application) at import time so it can be invoked by its `name`.
MNF_celery.app.register_task(NeutralAuditTask)