From 3669652413703dbaed6bd6a47bb6527a9e2d6d8c Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Sun, 1 Dec 2019 16:41:28 +0100 Subject: [PATCH] [AIRFLOW-6140] Add missing types for some core classes --- .pre-commit-config.yaml | 2 +- airflow/__init__.py | 3 +- .../airflow_local_settings.py | 72 +++++---- airflow/configuration.py | 3 +- airflow/macros/__init__.py | 4 +- airflow/models/baseoperator.py | 146 ++++++++++-------- airflow/serialization/serialized_objects.py | 38 +++-- tests/plugins/test_plugins_manager.py | 1 - tests/test_local_settings.py | 11 +- tests/ti_deps/deps/test_trigger_rule_dep.py | 1 - 10 files changed, 148 insertions(+), 133 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 22643454c519b..8c198af9a44c7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -262,7 +262,7 @@ repos: language: system entry: "./scripts/ci/pre_commit_pylint_main.sh" files: \.py$ - exclude: ^tests/.*\.py$|^airflow/_vendor/.*$ + exclude: ^tests/.*\.py$|^airflow/_vendor/.*|^scripts/.*\.py$ pass_filenames: true require_serial: true # Pylint tests should be run in one chunk to detect all cycles - id: pylint-tests diff --git a/airflow/__init__.py b/airflow/__init__.py index 862c9bcdc51ab..abb9588778988 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -31,14 +31,13 @@ # pylint: disable=wrong-import-position from typing import Callable, Optional -# noinspection PyUnresolvedReferences from airflow import utils from airflow import settings from airflow import version from airflow.utils.log.logging_mixin import LoggingMixin from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.models import DAG +from airflow.models.dag import DAG __version__ = version.version diff --git a/airflow/config_templates/airflow_local_settings.py b/airflow/config_templates/airflow_local_settings.py index ae8df7977b120..f65dec8a43b68 100644 --- a/airflow/config_templates/airflow_local_settings.py +++ b/airflow/config_templates/airflow_local_settings.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -19,43 +18,42 @@ """Airflow logging settings""" import os -from typing import Any, Dict +from typing import Any, Dict, Union -from airflow import AirflowException -from airflow.configuration import conf +from airflow import AirflowException, conf from airflow.utils.file import mkdirs # TODO: Logging format and level should be configured # in this file instead of from airflow.cfg. Currently # there are other log format and level configurations in # settings.py and cli.py. Please see AIRFLOW-1455. -LOG_LEVEL = conf.get('core', 'LOGGING_LEVEL').upper() +LOG_LEVEL: str = conf.get('core', 'LOGGING_LEVEL').upper() # Flask appbuilder's info level log is very verbose, # so it's set to 'WARN' by default. -FAB_LOG_LEVEL = conf.get('core', 'FAB_LOGGING_LEVEL').upper() +FAB_LOG_LEVEL: str = conf.get('core', 'FAB_LOGGING_LEVEL').upper() -LOG_FORMAT = conf.get('core', 'LOG_FORMAT') +LOG_FORMAT: str = conf.get('core', 'LOG_FORMAT') -COLORED_LOG_FORMAT = conf.get('core', 'COLORED_LOG_FORMAT') +COLORED_LOG_FORMAT: str = conf.get('core', 'COLORED_LOG_FORMAT') -COLORED_LOG = conf.getboolean('core', 'COLORED_CONSOLE_LOG') +COLORED_LOG: bool = conf.getboolean('core', 'COLORED_CONSOLE_LOG') -COLORED_FORMATTER_CLASS = conf.get('core', 'COLORED_FORMATTER_CLASS') +COLORED_FORMATTER_CLASS: str = conf.get('core', 'COLORED_FORMATTER_CLASS') -BASE_LOG_FOLDER = conf.get('core', 'BASE_LOG_FOLDER') +BASE_LOG_FOLDER: str = conf.get('core', 'BASE_LOG_FOLDER') -PROCESSOR_LOG_FOLDER = conf.get('scheduler', 'CHILD_PROCESS_LOG_DIRECTORY') +PROCESSOR_LOG_FOLDER: str = conf.get('scheduler', 'CHILD_PROCESS_LOG_DIRECTORY') -DAG_PROCESSOR_MANAGER_LOG_LOCATION = \ +DAG_PROCESSOR_MANAGER_LOG_LOCATION: str = \ conf.get('core', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION') -FILENAME_TEMPLATE = conf.get('core', 'LOG_FILENAME_TEMPLATE') +FILENAME_TEMPLATE: str = conf.get('core', 'LOG_FILENAME_TEMPLATE') -PROCESSOR_FILENAME_TEMPLATE = conf.get('core', 'LOG_PROCESSOR_FILENAME_TEMPLATE') +PROCESSOR_FILENAME_TEMPLATE: str = conf.get('core', 'LOG_PROCESSOR_FILENAME_TEMPLATE') -DEFAULT_LOGGING_CONFIG = { +DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { 'version': 1, 'disable_existing_loggers': False, 'formatters': { @@ -107,9 +105,9 @@ 'handlers': ['console'], 'level': LOG_LEVEL, } -} # type: Dict[str, Any] +} -DEFAULT_DAG_PARSING_LOGGING_CONFIG = { +DEFAULT_DAG_PARSING_LOGGING_CONFIG: Dict[str, Dict[str, Dict[str, Any]]] = { 'handlers': { 'processor_manager': { 'class': 'logging.handlers.RotatingFileHandler', @@ -140,34 +138,34 @@ # Manually create log directory for processor_manager handler as RotatingFileHandler # will only create file but not the directory. - processor_manager_handler_config = DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers'][ - 'processor_manager'] - directory = os.path.dirname(processor_manager_handler_config['filename']) + processor_manager_handler_config: Dict[str, Any] = \ + DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers']['processor_manager'] + directory: str = os.path.dirname(processor_manager_handler_config['filename']) mkdirs(directory, 0o755) ################## # Remote logging # ################## -REMOTE_LOGGING = conf.getboolean('core', 'remote_logging') +REMOTE_LOGGING: bool = conf.getboolean('core', 'remote_logging') if REMOTE_LOGGING: - ELASTICSEARCH_HOST = conf.get('elasticsearch', 'HOST') + ELASTICSEARCH_HOST: str = conf.get('elasticsearch', 'HOST') # Storage bucket URL for remote logging # S3 buckets should start with "s3://" # GCS buckets should start with "gs://" # WASB buckets should start with "wasb" # just to help Airflow select correct handler - REMOTE_BASE_LOG_FOLDER = conf.get('core', 'REMOTE_BASE_LOG_FOLDER') + REMOTE_BASE_LOG_FOLDER: str = conf.get('core', 'REMOTE_BASE_LOG_FOLDER') if REMOTE_BASE_LOG_FOLDER.startswith('s3://'): - S3_REMOTE_HANDLERS = { + S3_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = { 'task': { 'class': 'airflow.utils.log.s3_task_handler.S3TaskHandler', 'formatter': 'airflow', - 'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER), + 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), 's3_log_folder': REMOTE_BASE_LOG_FOLDER, 'filename_template': FILENAME_TEMPLATE, }, @@ -175,11 +173,11 @@ DEFAULT_LOGGING_CONFIG['handlers'].update(S3_REMOTE_HANDLERS) elif REMOTE_BASE_LOG_FOLDER.startswith('gs://'): - GCS_REMOTE_HANDLERS = { + GCS_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = { 'task': { 'class': 'airflow.utils.log.gcs_task_handler.GCSTaskHandler', 'formatter': 'airflow', - 'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER), + 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), 'gcs_log_folder': REMOTE_BASE_LOG_FOLDER, 'filename_template': FILENAME_TEMPLATE, }, @@ -187,11 +185,11 @@ DEFAULT_LOGGING_CONFIG['handlers'].update(GCS_REMOTE_HANDLERS) elif REMOTE_BASE_LOG_FOLDER.startswith('wasb'): - WASB_REMOTE_HANDLERS = { + WASB_REMOTE_HANDLERS: Dict[str, Dict[str, Union[str, bool]]] = { 'task': { 'class': 'airflow.utils.log.wasb_task_handler.WasbTaskHandler', 'formatter': 'airflow', - 'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER), + 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), 'wasb_log_folder': REMOTE_BASE_LOG_FOLDER, 'wasb_container': 'airflow-logs', 'filename_template': FILENAME_TEMPLATE, @@ -201,17 +199,17 @@ DEFAULT_LOGGING_CONFIG['handlers'].update(WASB_REMOTE_HANDLERS) elif ELASTICSEARCH_HOST: - ELASTICSEARCH_LOG_ID_TEMPLATE = conf.get('elasticsearch', 'LOG_ID_TEMPLATE') - ELASTICSEARCH_END_OF_LOG_MARK = conf.get('elasticsearch', 'END_OF_LOG_MARK') - ELASTICSEARCH_WRITE_STDOUT = conf.get('elasticsearch', 'WRITE_STDOUT') - ELASTICSEARCH_JSON_FORMAT = conf.get('elasticsearch', 'JSON_FORMAT') - ELASTICSEARCH_JSON_FIELDS = conf.get('elasticsearch', 'JSON_FIELDS') + ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get('elasticsearch', 'LOG_ID_TEMPLATE') + ELASTICSEARCH_END_OF_LOG_MARK: str = conf.get('elasticsearch', 'END_OF_LOG_MARK') + ELASTICSEARCH_WRITE_STDOUT: str = conf.get('elasticsearch', 'WRITE_STDOUT') + ELASTICSEARCH_JSON_FORMAT: str = conf.get('elasticsearch', 'JSON_FORMAT') + ELASTICSEARCH_JSON_FIELDS: str = conf.get('elasticsearch', 'JSON_FIELDS') - ELASTIC_REMOTE_HANDLERS = { + ELASTIC_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = { 'task': { 'class': 'airflow.utils.log.es_task_handler.ElasticsearchTaskHandler', 'formatter': 'airflow', - 'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER), + 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)), 'log_id_template': ELASTICSEARCH_LOG_ID_TEMPLATE, 'filename_template': FILENAME_TEMPLATE, 'end_of_log_mark': ELASTICSEARCH_END_OF_LOG_MARK, diff --git a/airflow/configuration.py b/airflow/configuration.py index a8528cc7eb96b..817841a07b53f 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -1,5 +1,4 @@ -# -*- coding: utf-8 -*- -# + # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/airflow/macros/__init__.py b/airflow/macros/__init__.py index 6a9fa7e8fd3f6..abc2ab7d1a205 100644 --- a/airflow/macros/__init__.py +++ b/airflow/macros/__init__.py @@ -16,9 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# pylint: disable=missing-docstring - +"""Macros.""" import time # noqa import uuid # noqa from datetime import datetime, timedelta diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index dd36d5d31e3e1..69a708f506e41 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -31,6 +31,7 @@ import jinja2 from cached_property import cached_property from dateutil.relativedelta import relativedelta +from sqlalchemy.orm import Session from airflow.configuration import conf from airflow.exceptions import AirflowException, DuplicateTaskIdFound @@ -40,6 +41,7 @@ # noinspection PyPep8Naming from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.models.xcom import XCOM_RETURN_KEY +from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep @@ -500,7 +502,7 @@ def __rlshift__(self, other): # /Composing Operators --------------------------------------------- @property - def dag(self): + def dag(self) -> Any: """ Returns the Operator's DAG if set, otherwise raises an error """ @@ -511,7 +513,7 @@ def dag(self): 'Operator {} has not been assigned to a DAG yet'.format(self)) @dag.setter - def dag(self, dag): + def dag(self, dag: Any): """ Operators can be assigned to one DAG, one time. Repeat assignments to that same DAG are ok. @@ -541,7 +543,7 @@ def has_dag(self): return getattr(self, '_dag', None) is not None @property - def dag_id(self): + def dag_id(self) -> str: """Returns dag id if it has one or an adhoc + owner""" if self.has_dag(): return self.dag.dag_id @@ -549,9 +551,9 @@ def dag_id(self): return 'adhoc_' + self.owner @property - def deps(self): + def deps(self) -> Set[BaseTIDep]: """ - Returns the list of dependencies for the operator. These differ from execution + Returns the set of dependencies for the operator. These differ from execution context dependencies in that they are specific to tasks and can be extended/overridden by subclasses. """ @@ -562,7 +564,7 @@ def deps(self): } @property - def priority_weight_total(self): + def priority_weight_total(self) -> int: """ Total priority weight for the task. It might include all upstream or downstream tasks. depending on the weight rule. @@ -581,17 +583,21 @@ def priority_weight_total(self): else: upstream = False + if not self._dag: + return self.priority_weight + from airflow import DAG + dag: DAG = self._dag return self.priority_weight + sum( - map(lambda task_id: self._dag.task_dict[task_id].priority_weight, + map(lambda task_id: dag.task_dict[task_id].priority_weight, self.get_flat_relative_ids(upstream=upstream)) ) @cached_property - def operator_extra_link_dict(self): + def operator_extra_link_dict(self) -> Dict[str, Any]: """Returns dictionary of all extra links for the operator""" - from airflow.plugins_manager import operator_extra_links - op_extra_links_from_plugin = {} + op_extra_links_from_plugin: Dict[str, Any] = {} + from airflow.plugins_manager import operator_extra_links for ope in operator_extra_links: if ope.operators and self.__class__ in ope.operators: op_extra_links_from_plugin.update({ope.name: ope}) @@ -605,18 +611,18 @@ def operator_extra_link_dict(self): return operator_extra_links_all @cached_property - def global_operator_extra_link_dict(self): + def global_operator_extra_link_dict(self) -> Dict[str, Any]: """Returns dictionary of all global extra links""" from airflow.plugins_manager import global_operator_extra_links return {link.name: link for link in global_operator_extra_links} @prepare_lineage - def pre_execute(self, context): + def pre_execute(self, context: Any): """ This hook is triggered right before self.execute() is called. """ - def execute(self, context): + def execute(self, context: Any): """ This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates. @@ -626,14 +632,14 @@ def execute(self, context): raise NotImplementedError() @apply_lineage - def post_execute(self, context, result=None): + def post_execute(self, context: Any, result: Any = None): """ This hook is triggered right after self.execute() is called. It is passed the execution context and any results returned by the operator. """ - def on_kill(self): + def on_kill(self) -> None: """ Override this method to cleanup subprocesses when a task instance gets killed. Any use of the threading, subprocess or multiprocessing @@ -768,7 +774,7 @@ def get_template_env(self) -> jinja2.Environment: """Fetch a Jinja template environment from the DAG or instantiate empty environment if no DAG.""" return self.dag.get_template_env() if self.has_dag() else jinja2.Environment(cache_size=0) - def prepare_template(self): + def prepare_template(self) -> None: """ Hook that is triggered after the templated fields get replaced by their content. If you need your operator to alter the @@ -776,7 +782,7 @@ def prepare_template(self): it should override this method to do so. """ - def resolve_template_files(self): + def resolve_template_files(self) -> None: """Getting the content of files for template_field / template_ext""" if self.template_ext: # pylint: disable=too-many-nested-blocks for attr in self.template_fields: @@ -802,32 +808,32 @@ def resolve_template_files(self): self.prepare_template() @property - def upstream_list(self): + def upstream_list(self) -> List[str]: """@property: list of tasks directly upstream""" return [self.dag.get_task(tid) for tid in self._upstream_task_ids] @property - def upstream_task_ids(self): - """@property: list of ids of tasks directly upstream""" + def upstream_task_ids(self) -> Set[str]: + """@property: set of ids of tasks directly upstream""" return self._upstream_task_ids @property - def downstream_list(self): + def downstream_list(self) -> List[str]: """@property: list of tasks directly downstream""" return [self.dag.get_task(tid) for tid in self._downstream_task_ids] @property - def downstream_task_ids(self): - """@property: list of ids of tasks directly downstream""" + def downstream_task_ids(self) -> Set[str]: + """@property: set of ids of tasks directly downstream""" return self._downstream_task_ids @provide_session def clear(self, - start_date=None, - end_date=None, - upstream=False, - downstream=False, - session=None): + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + upstream: bool = False, + downstream: bool = False, + session: Session = None): """ Clears the state of task instances associated with the task, following the parameters specified. @@ -860,7 +866,9 @@ def clear(self, return count @provide_session - def get_task_instances(self, start_date=None, end_date=None, session=None): + def get_task_instances(self, start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + session: Session = None) -> List[TaskInstance]: """ Get a set of task instance related to this task for a specific date range. @@ -874,11 +882,16 @@ def get_task_instances(self, start_date=None, end_date=None, session=None): .order_by(TaskInstance.execution_date)\ .all() - def get_flat_relative_ids(self, upstream=False, found_descendants=None): + def get_flat_relative_ids(self, + upstream: bool = False, + found_descendants: Optional[Set[str]] = None) -> Set[str]: """ - Get a flat list of relatives' ids, either upstream or downstream. + Get a flat set of relatives' ids, either upstream or downstream. """ + if not self._dag: + return set() + if not found_descendants: found_descendants = set() relative_ids = self.get_direct_relative_ids(upstream) @@ -892,20 +905,24 @@ def get_flat_relative_ids(self, upstream=False, found_descendants=None): return found_descendants - def get_flat_relatives(self, upstream=False): + def get_flat_relatives(self, upstream: bool = False): """ Get a flat list of relatives, either upstream or downstream. """ - return list(map(lambda task_id: self._dag.task_dict[task_id], + if not self._dag: + return set() + from airflow import DAG + dag: DAG = self._dag + return list(map(lambda task_id: dag.task_dict[task_id], self.get_flat_relative_ids(upstream))) def run( self, - start_date=None, - end_date=None, - ignore_first_depends_on_past=False, - ignore_ti_state=False, - mark_success=False): + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + ignore_first_depends_on_past: bool = False, + ignore_ti_state: bool = False, + mark_success: bool = False) -> None: """ Run a set of task instances for a date range. """ @@ -919,7 +936,7 @@ def run( execution_date == start_date and ignore_first_depends_on_past), ignore_ti_state=ignore_ti_state) - def dry_run(self): + def dry_run(self) -> None: """Performs dry run for the operator - just render template fields.""" self.log.info('Dry run') for attr in self.template_fields: @@ -928,9 +945,9 @@ def dry_run(self): self.log.info('Rendering template for %s', attr) self.log.info(content) - def get_direct_relative_ids(self, upstream=False): + def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]: """ - Get the direct relative ids to the current task, upstream or + Get set of the direct relative ids to the current task, upstream or downstream. """ if upstream: @@ -938,9 +955,9 @@ def get_direct_relative_ids(self, upstream=False): else: return self._downstream_task_ids - def get_direct_relatives(self, upstream=False): + def get_direct_relatives(self, upstream: bool = False) -> List[str]: """ - Get the direct relatives to the current task, upstream or + Get list of the direct relatives to the current task, upstream or downstream. """ if upstream: @@ -953,11 +970,11 @@ def __repr__(self): self=self) @property - def task_type(self): + def task_type(self) -> str: """@property: type of the task""" return self.__class__.__name__ - def add_only_new(self, item_set, item): + def add_only_new(self, item_set: Set[str], item: str) -> None: """Adds only new items to item set""" if item in item_set: self.log.warning( @@ -965,12 +982,14 @@ def add_only_new(self, item_set, item): else: item_set.add(item) - def _set_relatives(self, task_or_task_list, upstream=False): - """Sets relatives for the task.""" + def _set_relatives(self, + task_or_task_list: Union['BaseOperator', List['BaseOperator']], + upstream: bool = False) -> None: + """Sets relatives for the task or task list.""" try: - task_list = list(task_or_task_list) + task_list = list(task_or_task_list) # type: ignore except TypeError: - task_list = [task_or_task_list] + task_list = [task_or_task_list] # type: ignore for task in task_list: if not isinstance(task, BaseOperator): @@ -980,8 +999,9 @@ def _set_relatives(self, task_or_task_list, upstream=False): # relationships can only be set if the tasks share a single DAG. Tasks # without a DAG are assigned to that DAG. + # noinspection PyProtectedMember dags = { - task._dag.dag_id: task._dag # pylint: disable=protected-access + task._dag.dag_id: task._dag # type: ignore # pylint: disable=protected-access for task in [self] + task_list if task.has_dag()} if len(dags) > 1: @@ -1009,14 +1029,14 @@ def _set_relatives(self, task_or_task_list, upstream=False): self.add_only_new(self._downstream_task_ids, task.task_id) task.add_only_new(task.get_direct_relative_ids(upstream=True), self.task_id) - def set_downstream(self, task_or_task_list): + def set_downstream(self, task_or_task_list: Union['BaseOperator', List['BaseOperator']]) -> None: """ Set a task or a task list to be directly downstream from the current task. """ self._set_relatives(task_or_task_list, upstream=False) - def set_upstream(self, task_or_task_list): + def set_upstream(self, task_or_task_list: Union['BaseOperator', List['BaseOperator']]) -> None: """ Set a task or a task list to be directly upstream from the current task. @@ -1025,10 +1045,10 @@ def set_upstream(self, task_or_task_list): @staticmethod def xcom_push( - context, - key, - value, - execution_date=None): + context: Any, + key: str, + value: Any, + execution_date: Optional[datetime] = None) -> None: """ See TaskInstance.xcom_push() """ @@ -1039,11 +1059,11 @@ def xcom_push( @staticmethod def xcom_pull( - context, - task_ids=None, - dag_id=None, - key=XCOM_RETURN_KEY, - include_prior_dates=None): + context: Any, + task_ids: Optional[List[str]] = None, + dag_id: Optional[str] = None, + key: str = XCOM_RETURN_KEY, + include_prior_dates: Optional[bool] = None) -> Any: """ See TaskInstance.xcom_pull() """ @@ -1054,12 +1074,12 @@ def xcom_pull( include_prior_dates=include_prior_dates) @cached_property - def extra_links(self) -> Iterable[str]: + def extra_links(self) -> List[str]: """@property: extra links for the task. """ return list(set(self.operator_extra_link_dict.keys()) .union(self.global_operator_extra_link_dict.keys())) - def get_extra_links(self, dttm, link_name): + def get_extra_links(self, dttm: datetime, link_name: str) -> Optional[Dict[str, Any]]: """ For an operator, gets the URL that the external links specified in `extra_links` should point to. diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 5f07021288ef7..ce08a8c44bc16 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -20,7 +20,7 @@ import enum import logging from inspect import Parameter, signature -from typing import Dict, Optional, Set, Union +from typing import Any, Dict, Optional, Set, Union import pendulum from dateutil import relativedelta @@ -74,13 +74,14 @@ def from_json(cls, serialized_obj: str) -> Union['BaseSerialization', dict, list return cls.from_dict(json.loads(serialized_obj)) @classmethod - def from_dict(cls, serialized_obj: dict) -> Union['BaseSerialization', dict, list, set, tuple]: + def from_dict(cls, serialized_obj: Dict[Encoding, Any]) -> \ + Union['BaseSerialization', dict, list, set, tuple]: """Deserializes a python dict stored with type decorators and reconstructs all DAGs and operators it contains.""" return cls._deserialize(serialized_obj) @classmethod - def validate_schema(cls, serialized_obj: Union[str, dict]): + def validate_schema(cls, serialized_obj: Union[str, dict]) -> None: """Validate serialized_obj satisfies JSON schema.""" if cls._json_schema is None: raise AirflowException('JSON schema of {:s} is not set.'.format(cls.__name__)) @@ -93,17 +94,17 @@ def validate_schema(cls, serialized_obj: Union[str, dict]): raise TypeError("Invalid type: Only dict and str are supported.") @staticmethod - def _encode(x, type_): + def _encode(x: Any, type_: Any) -> Dict[Encoding, Any]: """Encode data by a JSON dict.""" return {Encoding.VAR: x, Encoding.TYPE: type_} @classmethod - def _is_primitive(cls, var): + def _is_primitive(cls, var: Any) -> bool: """Primitive types.""" return var is None or isinstance(var, cls._primitive_types) @classmethod - def _is_excluded(cls, var, attrname, instance): + def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool: """Types excluded from serialization.""" # pylint: disable=unused-argument return ( @@ -113,9 +114,10 @@ def _is_excluded(cls, var, attrname, instance): ) @classmethod - def serialize_to_json(cls, object_to_serialize: Union[BaseOperator, DAG], decorated_fields: Set): + def serialize_to_json(cls, object_to_serialize: Union[BaseOperator, DAG], decorated_fields: Set) \ + -> Dict[str, Any]: """Serializes an object to json""" - serialized_object = {} + serialized_object: Dict[str, Any] = {} keys_to_serialize = object_to_serialize.get_serialized_fields() for key in keys_to_serialize: # None is ignored in serialized form and is added back in deserialization. @@ -132,8 +134,9 @@ def serialize_to_json(cls, object_to_serialize: Union[BaseOperator, DAG], decora serialized_object[key] = value return serialized_object + # pylint: disable=too-many-return-statements @classmethod - def _serialize(cls, var): # pylint: disable=too-many-return-statements + def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for recursive types in mypy """Helper function of depth first search for serialization. The serialization protocol is: @@ -191,9 +194,10 @@ def _serialize(cls, var): # pylint: disable=too-many-return-statements except Exception: # pylint: disable=broad-except LOG.warning('Failed to stringify.', exc_info=True) return FAILED + # pylint: enable=too-many-return-statements @classmethod - def _deserialize(cls, encoded_var): # pylint: disable=too-many-return-statements + def _deserialize(cls, encoded_var: Any) -> Any: # pylint: disable=too-many-return-statements """Helper function of depth first search for deserialization.""" # JSON primitives (except for dict) are not encoded. if cls._is_primitive(encoded_var): @@ -219,7 +223,7 @@ def _deserialize(cls, encoded_var): # pylint: disable=too-many-return-statement return pendulum.timezone(var) elif type_ == DAT.RELATIVEDELTA: if 'weekday' in var: - var['weekday'] = relativedelta.weekday(*var['weekday']) + var['weekday'] = relativedelta.weekday(*var['weekday']) # type: ignore return relativedelta.relativedelta(**var) elif type_ == DAT.SET: return {cls._deserialize(v) for v in var} @@ -232,11 +236,11 @@ def _deserialize(cls, encoded_var): # pylint: disable=too-many-return-statement _deserialize_timezone = pendulum.timezone @classmethod - def _deserialize_timedelta(cls, seconds): + def _deserialize_timedelta(cls, seconds: int) -> datetime.timedelta: return datetime.timedelta(seconds=seconds) @classmethod - def _value_is_hardcoded_default(cls, attrname, value): + def _value_is_hardcoded_default(cls, attrname: str, value: Any) -> bool: """ Return true if ``value`` is the hard-coded default for the given attribute. @@ -298,7 +302,7 @@ def serialize_operator(cls, op: BaseOperator) -> dict: return serialize_op @classmethod - def deserialize_operator(cls, encoded_op: dict) -> BaseOperator: + def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator: """Deserializes an operator from a JSON object. """ from airflow.plugins_manager import operator_extra_links @@ -338,7 +342,7 @@ def deserialize_operator(cls, encoded_op: dict) -> BaseOperator: return op @classmethod - def _is_excluded(cls, var, attrname, op): + def _is_excluded(cls, var: Any, attrname: str, op: BaseOperator): if var is not None and op.has_dag() and attrname.endswith("_date"): # If this date is the same as the matching field in the dag, then # don't store it again at the task level. @@ -393,7 +397,7 @@ def serialize_dag(cls, dag: DAG) -> dict: return serialize_dag @classmethod - def deserialize_dag(cls, encoded_dag: dict) -> 'SerializedDAG': + def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': """Deserializes a DAG from a JSON object. """ dag = SerializedDAG(dag_id=encoded_dag['_dag_id']) @@ -443,7 +447,7 @@ def deserialize_dag(cls, encoded_dag: dict) -> 'SerializedDAG': return dag @classmethod - def to_dict(cls, var) -> dict: + def to_dict(cls, var: Any) -> dict: """Stringifies DAGs and operators contained by var and returns a dict of var. """ json_dict = { diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py index 975977566dd19..e0c337ddc8d24 100644 --- a/tests/plugins/test_plugins_manager.py +++ b/tests/plugins/test_plugins_manager.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file diff --git a/tests/test_local_settings.py b/tests/test_local_settings.py index ea774c20edc44..2ac2c7be5cd91 100644 --- a/tests/test_local_settings.py +++ b/tests/test_local_settings.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file @@ -94,10 +93,10 @@ def test_import_with_dunder_all_not_specified(self): """ with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"): from airflow import settings - settings.import_local_settings() # pylint: ignore + settings.import_local_settings() with self.assertRaises(AttributeError): - settings.not_policy() + settings.not_policy() # pylint: disable=no-member def test_import_with_dunder_all(self): """ @@ -106,7 +105,7 @@ def test_import_with_dunder_all(self): """ with SettingsContext(SETTINGS_FILE_POLICY_WITH_DUNDER_ALL, "airflow_local_settings"): from airflow import settings - settings.import_local_settings() # pylint: ignore + settings.import_local_settings() task_instance = MagicMock() settings.policy(task_instance) @@ -130,7 +129,7 @@ def test_policy_function(self): """ with SettingsContext(SETTINGS_FILE_POLICY, "airflow_local_settings"): from airflow import settings - settings.import_local_settings() # pylint: ignore + settings.import_local_settings() task_instance = MagicMock() settings.policy(task_instance) @@ -144,7 +143,7 @@ def test_pod_mutation_hook(self): """ with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"): from airflow import settings - settings.import_local_settings() # pylint: ignore + settings.import_local_settings() pod = MagicMock() settings.pod_mutation_hook(pod) diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index 37386d6b18fef..fbbafd5984ab5 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file