diff --git a/.gitignore b/.gitignore index 114236d467885..089d7b1e2f7e3 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,8 @@ airflow-*.err airflow-*.out airflow-*.log airflow-*.pid + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json diff --git a/.travis.yml b/.travis.yml index fed12b95268bf..adf19a63a2364 100644 --- a/.travis.yml +++ b/.travis.yml @@ -45,6 +45,10 @@ jobs: stage: pre-test install: pip install flake8 script: flake8 + - name: mypy + stage: pre-test + install: pip install mypy + script: mypy airflow - name: Check license header stage: pre-test install: skip diff --git a/airflow/__init__.py b/airflow/__init__.py index 6f0da74e1bc73..c9b0bdd678ddd 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -38,7 +38,7 @@ from airflow.exceptions import AirflowException if settings.DAGS_FOLDER not in sys.path: - sys.path.append(settings.DAGS_FOLDER) + sys.path.append(settings.DAGS_FOLDER) # type: ignore login = None diff --git a/airflow/api/__init__.py b/airflow/api/__init__.py index b4a2f8f5bc36f..67cb51d199e63 100644 --- a/airflow/api/__init__.py +++ b/airflow/api/__init__.py @@ -19,13 +19,15 @@ from __future__ import print_function +from typing import Any + from airflow.exceptions import AirflowException from airflow import configuration as conf from importlib import import_module from airflow.utils.log.logging_mixin import LoggingMixin -api_auth = None +api_auth = None # type: Any log = LoggingMixin().log diff --git a/airflow/api/auth/backend/kerberos_auth.py b/airflow/api/auth/backend/kerberos_auth.py index 50a88106fce29..f95f4d257191c 100644 --- a/airflow/api/auth/backend/kerberos_auth.py +++ b/airflow/api/auth/backend/kerberos_auth.py @@ -34,7 +34,7 @@ from airflow import configuration as conf from flask import Response -from flask import _request_ctx_stack as stack +from flask import _request_ctx_stack as stack # type: ignore from flask import make_response from flask import request from flask import g diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index fae3cb7a27489..9b1b8dbb4f6c2 100644 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -33,7 +33,6 @@ import argparse from argparse import RawTextHelpFormatter from builtins import input -from collections import namedtuple from airflow.utils.timezone import parse as parsedate import json @@ -49,6 +48,7 @@ import psutil import re from urllib.parse import urlunparse +from typing import Any import airflow from airflow import api @@ -69,7 +69,7 @@ from sqlalchemy.orm import exc api.load_auth() -api_module = import_module(conf.get('cli', 'api_client')) +api_module = import_module(conf.get('cli', 'api_client')) # type: Any api_client = api_module.Client(api_base_url=conf.get('cli', 'endpoint_url'), auth=api.api_auth.client_auth) @@ -1664,9 +1664,17 @@ def sync_perm(args): dag.access_control) -Arg = namedtuple( - 'Arg', ['flags', 'help', 'action', 'default', 'nargs', 'type', 'choices', 'metavar']) -Arg.__new__.__defaults__ = (None, None, None, None, None, None, None) +class Arg(object): + def __init__(self, flags=None, help=None, action=None, default=None, nargs=None, + type=None, choices=None, metavar=None): + self.flags = flags + self.help = help + self.action = action + self.default = default + self.nargs = nargs + self.type = type + self.choices = choices + self.metavar = metavar class CLIFactory(object): @@ -2380,8 +2388,8 @@ def get_parser(cls, dag_parser=False): continue arg = cls.args[arg] kwargs = { - f: getattr(arg, f) - for f in arg._fields if f != 'flags' and getattr(arg, f)} + f: v + for f, v in vars(arg).items() if f != 'flags' and v} sp.add_argument(*arg.flags, **kwargs) sp.set_defaults(func=sub['func']) return parser diff --git a/airflow/config_templates/airflow_local_settings.py b/airflow/config_templates/airflow_local_settings.py index b82755c939120..a25f56d076647 100644 --- a/airflow/config_templates/airflow_local_settings.py +++ b/airflow/config_templates/airflow_local_settings.py @@ -18,6 +18,7 @@ # under the License. import os +from typing import Dict, Any from airflow import configuration as conf from airflow.utils.file import mkdirs @@ -107,7 +108,7 @@ 'handlers': ['console'], 'level': LOG_LEVEL, } -} +} # type: Dict[str, Any] DEFAULT_DAG_PARSING_LOGGING_CONFIG = { 'handlers': { diff --git a/airflow/contrib/example_dags/example_gcs_to_bq_operator.py b/airflow/contrib/example_dags/example_gcs_to_bq_operator.py index ee9fe093913e2..b76d05eec138e 100644 --- a/airflow/contrib/example_dags/example_gcs_to_bq_operator.py +++ b/airflow/contrib/example_dags/example_gcs_to_bq_operator.py @@ -17,13 +17,17 @@ # specific language governing permissions and limitations # under the License. +from typing import Any + import airflow +from airflow import models +from airflow.operators import bash_operator + +gcs_to_bq = None # type: Any try: from airflow.contrib.operators import gcs_to_bq except ImportError: - gcs_to_bq = None -from airflow import models -from airflow.operators import bash_operator + pass if gcs_to_bq is not None: diff --git a/airflow/contrib/example_dags/example_winrm_operator.py b/airflow/contrib/example_dags/example_winrm_operator.py index 195bf5d98d03c..83e1844ebb042 100644 --- a/airflow/contrib/example_dags/example_winrm_operator.py +++ b/airflow/contrib/example_dags/example_winrm_operator.py @@ -31,7 +31,7 @@ from airflow.models import DAG from datetime import timedelta -from airflow.contrib.hooks import WinRMHook +from airflow.contrib.hooks.winrm_hook import WinRMHook from airflow.contrib.operators.winrm_operator import WinRMOperator diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py index 4bca95d44ad0a..a3cd31617a4c0 100644 --- a/airflow/contrib/hooks/databricks_hook.py +++ b/airflow/contrib/hooks/databricks_hook.py @@ -25,11 +25,7 @@ from requests import exceptions as requests_exceptions from requests.auth import AuthBase from time import sleep - -try: - from urllib import parse as urlparse -except ImportError: - import urlparse +from six.moves.urllib import parse as urlparse RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart") START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start") diff --git a/airflow/contrib/hooks/gcp_api_base_hook.py b/airflow/contrib/hooks/gcp_api_base_hook.py index f24ad48c28415..853bf4010817d 100644 --- a/airflow/contrib/hooks/gcp_api_base_hook.py +++ b/airflow/contrib/hooks/gcp_api_base_hook.py @@ -159,6 +159,7 @@ def _get_field(self, f, default=None): def project_id(self): return self._get_field('project') + @staticmethod def fallback_to_default_project_id(func): """ Decorator that provides fallback for Google Cloud Platform project id. If @@ -186,8 +187,6 @@ def inner_wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) return inner_wrapper - fallback_to_default_project_id = staticmethod(fallback_to_default_project_id) - def _get_project_id(self, project_id): """ In case project_id is None, overrides it with default project_id from diff --git a/airflow/contrib/kubernetes/pod_launcher.py b/airflow/contrib/kubernetes/pod_launcher.py index 2704fd9d32715..7a9b92c591f3a 100644 --- a/airflow/contrib/kubernetes/pod_launcher.py +++ b/airflow/contrib/kubernetes/pod_launcher.py @@ -17,6 +17,7 @@ import json import time +from typing import Tuple, Optional from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State from datetime import datetime as dt @@ -70,7 +71,7 @@ def delete_pod(self, pod): raise def run_pod(self, pod, startup_timeout=120, get_logs=True): - # type: (Pod, int, bool) -> (State, str) + # type: (Pod, int, bool) -> Tuple[State, Optional[str]] """ Launches the pod synchronously and waits for completion. Args: @@ -91,7 +92,7 @@ def run_pod(self, pod, startup_timeout=120, get_logs=True): return self._monitor_pod(pod, get_logs) def _monitor_pod(self, pod, get_logs): - # type: (Pod, bool) -> (State, str) + # type: (Pod, bool) -> Tuple[State, Optional[str]] if get_logs: logs = self._client.read_namespaced_pod_log( diff --git a/airflow/contrib/operators/adls_list_operator.py b/airflow/contrib/operators/adls_list_operator.py index 7d03e86b176b1..33c99064aa27b 100644 --- a/airflow/contrib/operators/adls_list_operator.py +++ b/airflow/contrib/operators/adls_list_operator.py @@ -17,6 +17,8 @@ # specific language governing permissions and limitations # under the License. +from typing import Iterable + from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults @@ -46,7 +48,7 @@ class AzureDataLakeStorageListOperator(BaseOperator): azure_data_lake_conn_id='azure_data_lake_default' ) """ - template_fields = ('path',) + template_fields = ('path',) # type: Iterable[str] ui_color = '#901dd2' @apply_defaults diff --git a/airflow/contrib/operators/azure_container_instances_operator.py b/airflow/contrib/operators/azure_container_instances_operator.py index 7ec6ca6ab8d5e..2d01e13f65d1b 100644 --- a/airflow/contrib/operators/azure_container_instances_operator.py +++ b/airflow/contrib/operators/azure_container_instances_operator.py @@ -17,7 +17,9 @@ # specific language governing permissions and limitations # under the License. +from collections import namedtuple from time import sleep +from typing import Dict, Sequence from airflow.contrib.hooks.azure_container_instance_hook import AzureContainerInstanceHook from airflow.contrib.hooks.azure_container_registry_hook import AzureContainerRegistryHook @@ -36,8 +38,13 @@ from msrestazure.azure_exceptions import CloudError -DEFAULT_ENVIRONMENT_VARIABLES = {} -DEFAULT_VOLUMES = [] +Volume = namedtuple( + 'Volume', + ['conn_id', 'account_name', 'share_name', 'mount_path', 'read_only'], +) + +DEFAULT_ENVIRONMENT_VARIABLES = {} # type: Dict[str, str] +DEFAULT_VOLUMES = [] # type: Sequence[Volume] DEFAULT_MEMORY_IN_GB = 2.0 DEFAULT_CPU = 1.0 @@ -98,7 +105,6 @@ class AzureContainerInstancesOperator(BaseOperator): """ template_fields = ('name', 'environment_variables') - template_ext = tuple() @apply_defaults def __init__(self, ci_conn_id, registry_conn_id, resource_group, name, image, region, diff --git a/airflow/contrib/operators/gcp_bigtable_operator.py b/airflow/contrib/operators/gcp_bigtable_operator.py index 48fd632de024d..d99746798dc28 100644 --- a/airflow/contrib/operators/gcp_bigtable_operator.py +++ b/airflow/contrib/operators/gcp_bigtable_operator.py @@ -17,6 +17,7 @@ # specific language governing permissions and limitations # under the License. +from typing import Iterable import google.api_core.exceptions from airflow import AirflowException @@ -33,7 +34,7 @@ class BigtableValidationMixin(object): Common class for Cloud Bigtable operators for validating required fields. """ - REQUIRED_ATTRIBUTES = [] + REQUIRED_ATTRIBUTES = [] # type: Iterable[str] def _validate_inputs(self): for attr_name in self.REQUIRED_ATTRIBUTES: diff --git a/airflow/contrib/operators/gcs_list_operator.py b/airflow/contrib/operators/gcs_list_operator.py index 7b37b269a6035..056b349394016 100644 --- a/airflow/contrib/operators/gcs_list_operator.py +++ b/airflow/contrib/operators/gcs_list_operator.py @@ -17,6 +17,8 @@ # specific language governing permissions and limitations # under the License. +from typing import Iterable + from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults @@ -58,7 +60,7 @@ class GoogleCloudStorageListOperator(BaseOperator): google_cloud_storage_conn_id=google_cloud_conn_id ) """ - template_fields = ('bucket', 'prefix', 'delimiter') + template_fields = ('bucket', 'prefix', 'delimiter') # type: Iterable[str] ui_color = '#f0eee4' @apply_defaults diff --git a/airflow/contrib/operators/jenkins_job_trigger_operator.py b/airflow/contrib/operators/jenkins_job_trigger_operator.py index 8bb81c90c7dbd..3af87f3f3bed7 100644 --- a/airflow/contrib/operators/jenkins_job_trigger_operator.py +++ b/airflow/contrib/operators/jenkins_job_trigger_operator.py @@ -27,13 +27,9 @@ import jenkins from jenkins import JenkinsException from requests import Request +import six from six.moves.urllib.error import HTTPError, URLError -try: - basestring -except NameError: - basestring = str # For python3 compatibility - def jenkins_request_with_headers(jenkins_server, req): """ @@ -138,7 +134,7 @@ def build_job(self, jenkins_server): """ # Warning if the parameter is too long, the URL can be longer than # the maximum allowed size - if self.parameters and isinstance(self.parameters, basestring): + if self.parameters and isinstance(self.parameters, six.string_types): import ast self.parameters = ast.literal_eval(self.parameters) diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py index b8ec35fb639a6..848d010b8cba6 100644 --- a/airflow/contrib/operators/kubernetes_pod_operator.py +++ b/airflow/contrib/operators/kubernetes_pod_operator.py @@ -25,10 +25,6 @@ from airflow.contrib.kubernetes.volume import Volume # noqa from airflow.contrib.kubernetes.secret import Secret # noqa -template_fields = ('templates_dict',) -template_ext = tuple() -ui_color = '#ffefeb' - class KubernetesPodOperator(BaseOperator): """ diff --git a/airflow/contrib/operators/qubole_operator.py b/airflow/contrib/operators/qubole_operator.py index 2f4642ba65dab..71f84a293005b 100755 --- a/airflow/contrib/operators/qubole_operator.py +++ b/airflow/contrib/operators/qubole_operator.py @@ -17,6 +17,8 @@ # specific language governing permissions and limitations # under the License. +from typing import Iterable + from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults from airflow.contrib.hooks.qubole_hook import QuboleHook, COMMAND_ARGS, HYPHEN_ARGS, \ @@ -141,9 +143,9 @@ class QuboleOperator(BaseOperator): 'extract_query', 'boundary_query', 'macros', 'name', 'parameters', 'dbtap_id', 'hive_table', 'db_table', 'split_column', 'note_id', 'db_update_keys', 'export_dir', 'partition_spec', 'qubole_conn_id', - 'arguments', 'user_program_arguments', 'cluster_label') + 'arguments', 'user_program_arguments', 'cluster_label') # type: Iterable[str] - template_ext = ('.txt',) + template_ext = ('.txt',) # type: Iterable[str] ui_color = '#3064A1' ui_fgcolor = '#fff' qubole_hook_allowed_args_list = ['command_type', 'qubole_conn_id', 'fetch_logs'] diff --git a/airflow/contrib/operators/s3_list_operator.py b/airflow/contrib/operators/s3_list_operator.py index 9c67c2fa3b78e..d5cccab88ec8f 100644 --- a/airflow/contrib/operators/s3_list_operator.py +++ b/airflow/contrib/operators/s3_list_operator.py @@ -17,6 +17,8 @@ # specific language governing permissions and limitations # under the License. +from typing import Iterable + from airflow.hooks.S3_hook import S3Hook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults @@ -64,7 +66,7 @@ class S3ListOperator(BaseOperator): aws_conn_id='aws_customers_conn' ) """ - template_fields = ('bucket', 'prefix', 'delimiter') + template_fields = ('bucket', 'prefix', 'delimiter') # type: Iterable[str] ui_color = '#ffd700' @apply_defaults diff --git a/airflow/contrib/operators/sagemaker_base_operator.py b/airflow/contrib/operators/sagemaker_base_operator.py index 08d6d0eb6a83c..d6717fd6b41fc 100644 --- a/airflow/contrib/operators/sagemaker_base_operator.py +++ b/airflow/contrib/operators/sagemaker_base_operator.py @@ -19,6 +19,8 @@ import json +from typing import Iterable + from airflow.contrib.hooks.sagemaker_hook import SageMakerHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults @@ -38,7 +40,7 @@ class SageMakerBaseOperator(BaseOperator): template_ext = () ui_color = '#ededed' - integer_fields = [] + integer_fields = [] # type: Iterable[Iterable[str]] @apply_defaults def __init__(self, diff --git a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py index ecdc7e6c8ea40..44d98e695b38d 100644 --- a/airflow/contrib/sensors/python_sensor.py +++ b/airflow/contrib/sensors/python_sensor.py @@ -52,7 +52,6 @@ class PythonSensor(BaseSensorOperator): """ template_fields = ('templates_dict',) - template_ext = tuple() @apply_defaults def __init__( diff --git a/airflow/contrib/utils/gcp_field_sanitizer.py b/airflow/contrib/utils/gcp_field_sanitizer.py index bf9f9348f56a7..6621bb4ee1e39 100644 --- a/airflow/contrib/utils/gcp_field_sanitizer.py +++ b/airflow/contrib/utils/gcp_field_sanitizer.py @@ -98,6 +98,8 @@ components in all elements of the array. """ +from typing import List + from airflow import LoggingMixin, AirflowException @@ -118,7 +120,7 @@ class GcpBodyFieldSanitizer(LoggingMixin): """ def __init__(self, sanitize_specs): - # type: ([str]) -> None + # type: (List[str]) -> None super(GcpBodyFieldSanitizer, self).__init__() self._sanitize_specs = sanitize_specs diff --git a/airflow/contrib/utils/gcp_field_validator.py b/airflow/contrib/utils/gcp_field_validator.py index 0337be1422454..d2a01ac7fa229 100644 --- a/airflow/contrib/utils/gcp_field_validator.py +++ b/airflow/contrib/utils/gcp_field_validator.py @@ -132,8 +132,8 @@ """ import re +from typing import Sequence, Dict, Callable -from typing import Callable from airflow import LoggingMixin, AirflowException COMPOSITE_FIELD_TYPES = ['union', 'dict', 'list'] @@ -195,7 +195,7 @@ class GcpBodyFieldValidator(LoggingMixin): """ def __init__(self, validation_specs, api_version): - # type: ([dict], str) -> None + # type: (Sequence[Dict], str) -> None super(GcpBodyFieldValidator, self).__init__() self._validation_specs = validation_specs self._api_version = api_version diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py index 4fce22786997e..2b42325897811 100644 --- a/airflow/hooks/dbapi_hook.py +++ b/airflow/hooks/dbapi_hook.py @@ -22,6 +22,7 @@ from datetime import datetime from contextlib import closing import sys +from typing import Optional from sqlalchemy import create_engine @@ -34,7 +35,7 @@ class DbApiHook(BaseHook): Abstract base class for sql hooks. """ # Override to provide the connection name. - conn_name_attr = None + conn_name_attr = None # type: Optional[str] # Override to have a default connection id for a particular dbHook default_conn_name = 'default_conn_id' # Override if this db supports autocommit. diff --git a/airflow/jobs.py b/airflow/jobs.py index fbb8c558c5340..6b351988d4c38 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -32,6 +32,7 @@ import time from collections import defaultdict, OrderedDict from time import sleep +from typing import Any import six from past.builtins import basestring @@ -63,7 +64,7 @@ from airflow.utils.sqlalchemy import UtcDateTime from airflow.utils.state import State -Base = models.base.Base +Base = models.base.Base # type: Any ID_LEN = models.base.ID_LEN diff --git a/airflow/lineage/datasets.py b/airflow/lineage/datasets.py index 49dd492bb775a..260277065b6f0 100644 --- a/airflow/lineage/datasets.py +++ b/airflow/lineage/datasets.py @@ -18,6 +18,7 @@ # under the License. import six +from typing import List from jinja2 import Environment @@ -28,7 +29,7 @@ def _inherited(cls): class DataSet(object): - attributes = [] + attributes = [] # type: List[str] type_name = "dataSet" def __init__(self, qualified_name=None, data=None, **kwargs): diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py index c35b98ceb0ff9..02d0f706666a6 100755 --- a/airflow/models/__init__.py +++ b/airflow/models/__init__.py @@ -22,9 +22,10 @@ from __future__ import print_function from __future__ import unicode_literals -import copy from builtins import ImportError as BuiltinImportError, bytes, object, str from collections import defaultdict, namedtuple, OrderedDict +import copy +from typing import Iterable from future.standard_library import install_aliases @@ -316,6 +317,10 @@ def size(self): """ return len(self.dags) + @property + def dag_ids(self): + return self.dags.keys() + def get_dag(self, dag_id): """ Gets the DAG out of the dictionary, and refreshes it if expired @@ -1022,11 +1027,8 @@ def are_dependents_done(self, session=None): count = ti[0][0] return count == len(task.downstream_task_ids) - @property @provide_session - def previous_ti(self, session=None): - """ The task instance for the task that ran before this task instance """ - + def _get_previous_ti(self, session=None): dag = self.task.dag if dag: dr = self.get_dagrun(session=session) @@ -1052,6 +1054,11 @@ def previous_ti(self, session=None): return None + @property + def previous_ti(self): + """The task instance for the task that ran before this task instance.""" + return self._get_previous_ti() + @provide_session def are_dependencies_met( self, @@ -2034,9 +2041,9 @@ class derived from this one results in the creation of a task object, """ # For derived classes to define which fields will get jinjaified - template_fields = [] + template_fields = [] # type: Iterable[str] # Defines which files extensions to look for in the templated fields - template_ext = [] + template_ext = [] # type: Iterable[str] # Defines the color in the UI ui_color = '#fff' ui_fgcolor = '#000' @@ -2048,7 +2055,7 @@ class derived from this one results in the creation of a task object, '_log',) # each operator should override this class attr for shallow copy attrs. - shallow_copy_attrs = () + shallow_copy_attrs = () # type: Iterable[str] @apply_defaults def __init__( @@ -2984,6 +2991,7 @@ def __init__( orientation=configuration.conf.get('webserver', 'dag_orientation'), catchup=configuration.conf.getboolean('scheduler', 'catchup_by_default'), on_success_callback=None, on_failure_callback=None, + doc_md=None, params=None, access_control=None): @@ -3060,6 +3068,7 @@ def __init__( self.partial = False self.on_success_callback = on_success_callback self.on_failure_callback = on_failure_callback + self.doc_md = doc_md self._old_context_manager_dags = [] self._access_control = access_control @@ -3345,13 +3354,8 @@ def folder(self): def owner(self): return ", ".join(list(set([t.owner for t in self.tasks]))) - @property @provide_session - def concurrency_reached(self, session=None): - """ - Returns a boolean indicating whether the concurrency limit for this DAG - has been reached - """ + def _get_concurrency_reached(self, session=None): TI = TaskInstance qry = session.query(func.count(TI.task_id)).filter( TI.dag_id == self.dag_id, @@ -3360,15 +3364,26 @@ def concurrency_reached(self, session=None): return qry.scalar() >= self.concurrency @property - @provide_session - def is_paused(self, session=None): + def concurrency_reached(self): """ - Returns a boolean indicating whether this DAG is paused + Returns a boolean indicating whether the concurrency limit for this DAG + has been reached """ + return self._get_concurrency_reached() + + @provide_session + def _get_is_paused(self, session=None): qry = session.query(DagModel).filter( DagModel.dag_id == self.dag_id) return qry.value('is_paused') + @property + def is_paused(self): + """ + Returns a boolean indicating whether this DAG is paused + """ + return self._get_is_paused() + @provide_session def handle_callback(self, dagrun, success=True, reason=None, session=None): """ @@ -3450,16 +3465,18 @@ def get_dagrun(self, execution_date, session=None): return dagrun - @property @provide_session - def latest_execution_date(self, session=None): + def _get_latest_execution_date(self, session=None): + return session.query(func.max(DagRun.execution_date)).filter( + DagRun.dag_id == self.dag_id + ).scalar() + + @property + def latest_execution_date(self): """ Returns the latest date for which at least one dag run exists """ - execution_date = session.query(func.max(DagRun.execution_date)).filter( - DagRun.dag_id == self.dag_id - ).scalar() - return execution_date + return self._get_latest_execution_date() @property def subdags(self): diff --git a/airflow/models/base.py b/airflow/models/base.py index 3f5eb80154b73..97c6b777984d8 100644 --- a/airflow/models/base.py +++ b/airflow/models/base.py @@ -17,6 +17,7 @@ # specific language governing permissions and limitations # under the License. +from typing import Any from sqlalchemy import MetaData from sqlalchemy.ext.declarative import declarative_base @@ -24,9 +25,11 @@ SQL_ALCHEMY_SCHEMA = airflow.configuration.get("core", "SQL_ALCHEMY_SCHEMA") -if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace(): - Base = declarative_base() -else: - Base = declarative_base(metadata=MetaData(schema=SQL_ALCHEMY_SCHEMA)) +metadata = ( + None + if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace() + else MetaData(schema=SQL_ALCHEMY_SCHEMA) +) +Base = declarative_base(metadata=metadata) # type: Any ID_LEN = 250 diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py index 404964dad6cc8..dab0ead350661 100644 --- a/airflow/operators/check_operator.py +++ b/airflow/operators/check_operator.py @@ -19,6 +19,7 @@ from builtins import zip from builtins import str +from typing import Iterable from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook @@ -62,8 +63,8 @@ class CheckOperator(BaseOperator): :type sql: str """ - template_fields = ('sql',) - template_ext = ('.hql', '.sql',) + template_fields = ('sql',) # type: Iterable[str] + template_ext = ('.hql', '.sql',) # type: Iterable[str] ui_color = '#fff7e6' @apply_defaults @@ -120,8 +121,8 @@ class ValueCheckOperator(BaseOperator): __mapper_args__ = { 'polymorphic_identity': 'ValueCheckOperator' } - template_fields = ('sql', 'pass_value',) - template_ext = ('.hql', '.sql',) + template_fields = ('sql', 'pass_value',) # type: Iterable[str] + template_ext = ('.hql', '.sql',) # type: Iterable[str] ui_color = '#fff7e6' @apply_defaults @@ -196,8 +197,8 @@ class IntervalCheckOperator(BaseOperator): __mapper_args__ = { 'polymorphic_identity': 'IntervalCheckOperator' } - template_fields = ('sql1', 'sql2') - template_ext = ('.hql', '.sql',) + template_fields = ('sql1', 'sql2') # type: Iterable[str] + template_ext = ('.hql', '.sql',) # type: Iterable[str] ui_color = '#fff7e6' @apply_defaults diff --git a/airflow/operators/dagrun_operator.py b/airflow/operators/dagrun_operator.py index 7000c6c90bc18..1bc1979a1afb4 100644 --- a/airflow/operators/dagrun_operator.py +++ b/airflow/operators/dagrun_operator.py @@ -53,7 +53,6 @@ class TriggerDagRunOperator(BaseOperator): :type execution_date: str or datetime.datetime """ template_fields = ('trigger_dag_id', 'execution_date') - template_ext = tuple() ui_color = '#ffefeb' @apply_defaults diff --git a/airflow/operators/dummy_operator.py b/airflow/operators/dummy_operator.py index 025a242fd1b94..222c853b1d1fb 100644 --- a/airflow/operators/dummy_operator.py +++ b/airflow/operators/dummy_operator.py @@ -27,7 +27,6 @@ class DummyOperator(BaseOperator): DAG. """ - template_fields = tuple() ui_color = '#e8f7e4' @apply_defaults diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 96b7e6ee2d077..dd8298b28a7aa 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -69,7 +69,6 @@ class PythonOperator(BaseOperator): :type templates_exts: list[str] """ template_fields = ('templates_dict', 'op_args', 'op_kwargs') - template_ext = tuple() ui_color = '#ffefeb' # since we won't mutate the arguments, we should just do the shallow copy diff --git a/airflow/operators/subdag_operator.py b/airflow/operators/subdag_operator.py index bdee51548d874..812f486e16428 100644 --- a/airflow/operators/subdag_operator.py +++ b/airflow/operators/subdag_operator.py @@ -38,7 +38,6 @@ class SubDagOperator(BaseOperator): :type executor: airflow.executors.base_executor.BaseExecutor """ - template_fields = tuple() ui_color = '#555' ui_fgcolor = '#fff' diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index b3966fad098d2..1c9a2b3613746 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -29,6 +29,7 @@ import re import sys import pkg_resources +from typing import List, Any from airflow import configuration from airflow.utils.log.logging_mixin import LoggingMixin @@ -43,17 +44,17 @@ class AirflowPluginException(Exception): class AirflowPlugin(object): - name = None - operators = [] - sensors = [] - hooks = [] - executors = [] - macros = [] - admin_views = [] - flask_blueprints = [] - menu_links = [] - appbuilder_views = [] - appbuilder_menu_items = [] + name = None # type: str + operators = [] # type: List[Any] + sensors = [] # type: List[Any] + hooks = [] # type: List[Any] + executors = [] # type: List[Any] + macros = [] # type: List[Any] + admin_views = [] # type: List[Any] + flask_blueprints = [] # type: List[Any] + menu_links = [] # type: List[Any] + appbuilder_views = [] # type: List[Any] + appbuilder_menu_items = [] # type: List[Any] @classmethod def validate(cls): @@ -122,7 +123,7 @@ def is_valid_plugin(plugin_obj, existing_plugins): if plugins_folder not in sys.path: sys.path.append(plugins_folder) -plugins = [] +plugins = [] # type: List[AirflowPlugin] norm_pattern = re.compile(r'[/|.]') @@ -176,11 +177,11 @@ def make_module(name, objects): macros_modules = [] # Plugin components to integrate directly -admin_views = [] -flask_blueprints = [] -menu_links = [] -flask_appbuilder_views = [] -flask_appbuilder_menu_links = [] +admin_views = [] # type: List[Any] +flask_blueprints = [] # type: List[Any] +menu_links = [] # type: List[Any] +flask_appbuilder_views = [] # type: List[Any] +flask_appbuilder_menu_links = [] # type: List[Any] for p in plugins: operators_modules.append( diff --git a/airflow/sensors/sql_sensor.py b/airflow/sensors/sql_sensor.py index 6f942c4f3b8b6..b1cbe1861efe2 100644 --- a/airflow/sensors/sql_sensor.py +++ b/airflow/sensors/sql_sensor.py @@ -18,6 +18,7 @@ # under the License. from builtins import str +from typing import Iterable from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook @@ -38,8 +39,8 @@ class SqlSensor(BaseSensorOperator): :param parameters: The parameters to render the SQL query with (optional). :type parameters: mapping or iterable """ - template_fields = ('sql',) - template_ext = ('.hql', '.sql',) + template_fields = ('sql',) # type: Iterable[str] + template_ext = ('.hql', '.sql',) # type: Iterable[str] ui_color = '#7c7287' @apply_defaults diff --git a/airflow/utils/cli_action_loggers.py b/airflow/utils/cli_action_loggers.py index 21304936f3ff5..658fbd2d1e3e0 100644 --- a/airflow/utils/cli_action_loggers.py +++ b/airflow/utils/cli_action_loggers.py @@ -24,6 +24,7 @@ from __future__ import absolute_import import logging +from typing import List, Callable from airflow.utils.db import create_session @@ -98,8 +99,8 @@ def default_action_log(log, **_): session.add(log) -__pre_exec_callbacks = [] -__post_exec_callbacks = [] +__pre_exec_callbacks = [] # type: List[Callable] +__post_exec_callbacks = [] # type: List[Callable] # By default, register default action log into pre-execution callback register_pre_exec_callback(default_action_log) diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index c449c3fac9745..0ef4d01dcc62c 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -35,6 +35,7 @@ from collections import namedtuple from datetime import timedelta from importlib import import_module +import enum import psutil from six.moves import range, reload_module @@ -436,10 +437,11 @@ def file_path(self): 'all_files_processed', 'result_count']) -DagParsingSignal = namedtuple( - 'DagParsingSignal', - ['AGENT_HEARTBEAT', 'MANAGER_DONE', 'TERMINATE_MANAGER', 'END_MANAGER'])( - 'agent_heartbeat', 'manager_done', 'terminate_manager', 'end_manager') +class DagParsingSignal(enum.Enum): + AGENT_HEARTBEAT = 'agent_heartbeat' + MANAGER_DONE = 'manager_done' + TERMINATE_MANAGER = 'terminate_manager' + END_MANAGER = 'end_manager' class DagFileProcessorAgent(LoggingMixin): diff --git a/airflow/utils/trigger_rule.py b/airflow/utils/trigger_rule.py index 4f7db65f7bae5..81fad6b7bfe76 100644 --- a/airflow/utils/trigger_rule.py +++ b/airflow/utils/trigger_rule.py @@ -20,6 +20,7 @@ from __future__ import unicode_literals from builtins import object +from typing import Set class TriggerRule(object): @@ -31,7 +32,7 @@ class TriggerRule(object): DUMMY = 'dummy' NONE_FAILED = 'none_failed' - _ALL_TRIGGER_RULES = {} + _ALL_TRIGGER_RULES = set() # type: Set[str] @classmethod def is_valid(cls, trigger_rule): diff --git a/airflow/utils/weight_rule.py b/airflow/utils/weight_rule.py index f34856be8311e..f7f85c0734ac6 100644 --- a/airflow/utils/weight_rule.py +++ b/airflow/utils/weight_rule.py @@ -20,6 +20,7 @@ from __future__ import unicode_literals from builtins import object +from typing import Set class WeightRule(object): @@ -27,7 +28,7 @@ class WeightRule(object): UPSTREAM = 'upstream' ABSOLUTE = 'absolute' - _ALL_WEIGHT_RULES = {} + _ALL_WEIGHT_RULES = set() # type: Set[str] @classmethod def is_valid(cls, weight_rule): diff --git a/airflow/www/app.py b/airflow/www/app.py index ca82175bc6aab..a92f0613cd623 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -19,8 +19,9 @@ # import logging import socket -import six +from typing import Any +import six from flask import Flask from flask_appbuilder import AppBuilder, SQLA from flask_caching import Cache @@ -34,7 +35,7 @@ from airflow.logging_config import configure_logging from airflow.www.static_config import configure_manifest_files -app = None +app = None # type: Any appbuilder = None csrf = CSRFProtect() diff --git a/airflow/www/static_config.py b/airflow/www/static_config.py index 278c499eee5e1..4be0e38a19517 100644 --- a/airflow/www/static_config.py +++ b/airflow/www/static_config.py @@ -18,10 +18,11 @@ # under the License. from __future__ import print_function -import json import os +import json +from typing import Dict -manifest = dict() +manifest = dict() # type: Dict[str, str] def configure_manifest_files(app): diff --git a/airflow/www/views.py b/airflow/www/views.py index 6c556bf411d8a..9486678ba4f02 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -84,7 +84,7 @@ if os.environ.get('SKIP_DAGS_PARSING') != 'True': dagbag = models.DagBag(settings.DAGS_FOLDER) else: - dagbag = models.DagBag + dagbag = models.DagBag(os.devnull, include_examples=False) def get_date_time_num_runs_dag_runs_form_data(request, session, dag): diff --git a/setup.cfg b/setup.cfg index f29267e269727..35785162aaf45 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,3 +35,15 @@ all_files = 1 upload-dir = docs/_build/html [easy_install] + +[mypy] +ignore_missing_imports = True + +[mypy-airflow.migrations.*] +ignore_errors = True + +[mypy-airflow._vendor.*] +ignore_errors = True + +[mypy-airflow.contrib.auth.*] +ignore_errors = True diff --git a/setup.py b/setup.py index fb09f95279a8b..5586b6901a429 100644 --- a/setup.py +++ b/setup.py @@ -251,7 +251,6 @@ def write_version(filename=os.path.join(*['airflow', 'rednose', 'requests_mock', 'flake8>=3.6.0', - 'typing', ] if not PY3: @@ -320,6 +319,7 @@ def do_setup(): 'tabulate>=0.7.5, <0.9', 'tenacity==4.12.0', 'text-unidecode==1.2', + 'typing;python_version<"3.5"', 'thrift>=0.9.2', 'tzlocal>=1.4', 'unicodecsv>=0.14.1',