diff --git a/.codecov.yml b/.codecov.yml deleted file mode 100644 index 4094f35dccf0d..0000000000000 --- a/.codecov.yml +++ /dev/null @@ -1,19 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# keep default diff --git a/.readthedocs.yml b/.readthedocs.yml index c6a4da8d690c8..87fb227d81cb5 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -19,7 +19,14 @@ python: pip_install: true extra_requirements: + - all_dbs + - databricks - doc - docker - - gcp_api - emr + - gcp_api + - s3 + - salesforce + - sendgrid + - ssh + - slack diff --git a/.travis.yml b/.travis.yml index 9c7cfd02084db..5bd750453a563 100644 --- a/.travis.yml +++ b/.travis.yml @@ -42,7 +42,7 @@ cache: before_install: - sudo ls -lh $HOME/.cache/pip/ - sudo rm -rf $HOME/.cache/pip/* $HOME/.wheelhouse/* - - sudo chown -R travis.travis $HOME/.cache/pip + - sudo chown -R travis:travis $HOME/.cache/pip install: # Use recent docker-compose version - sudo rm /usr/local/bin/docker-compose @@ -50,9 +50,5 @@ install: - chmod +x docker-compose - sudo mv docker-compose /usr/local/bin - pip install --upgrade pip - - pip install codecov script: - docker-compose --log-level ERROR -f scripts/ci/docker-compose.yml run airflow-testing /app/scripts/ci/run-ci.sh -after_success: - - sudo chown -R travis.travis . - - codecov diff --git a/CHANGELOG.txt b/CHANGELOG.txt index 06d5aed3f1011..b4ee1755b467a 100644 --- a/CHANGELOG.txt +++ b/CHANGELOG.txt @@ -780,7 +780,6 @@ AIRFLOW 1.10.0, 2018-08-03 [AIRFLOW-1609] Fix gitignore to ignore all venvs [AIRFLOW-1601] Add configurable task cleanup time ->>>>>>> 862ad8b9... [AIRFLOW-XXX] Update changelog for 1.10 AIRFLOW 1.9.0, 2018-01-02 ------------------------- diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index beaf609b5a836..152d5d9aabf16 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -88,7 +88,7 @@ There are three ways to setup an Apache Airflow development environment. 1. Using tools and libraries installed directly on your system. Install Python (2.7.x or 3.4.x), MySQL, and libxml by using system-level package - managers like yum, apt-get for Linux, or Homebrew for Mac OS at first. Refer to the [base CI Dockerfile](https://github.com/apache/incubator-airflow-ci/blob/master/Dockerfile.base) for + managers like yum, apt-get for Linux, or Homebrew for Mac OS at first. Refer to the [base CI Dockerfile](https://github.com/apache/incubator-airflow-ci/blob/master/Dockerfile) for a comprehensive list of required packages. Then install python development requirements. It is usually best to work in a virtualenv: diff --git a/LICENSE b/LICENSE index f2c490e99b0b3..68ca6b6117c19 100644 --- a/LICENSE +++ b/LICENSE @@ -220,6 +220,7 @@ at licenses/LICENSE-[project].txt. (ALv2 License) hue (https://github.com/cloudera/hue/) (ALv2 License) jqclock (https://github.com/JohnRDOrazio/jQuery-Clock-Plugin) (ALv2 License) bootstrap3-typeahead (https://github.com/bassjobsen/Bootstrap-3-Typeahead) + (ALv2 License) airflow.contrib.auth.backends.github_enterprise_auth ======================================================================== MIT licenses diff --git a/NOTICE b/NOTICE index 544c0807667fb..99069f7a40bce 100644 --- a/NOTICE +++ b/NOTICE @@ -13,6 +13,11 @@ is subject to the terms and conditions of their respective licenses. See the LICENSE file for a list of subcomponents and dependencies and their respective licenses. +airflow.contrib.auth.backends.github_enterprise_auth: +----------------------------------------------------- + +* Copyright 2015 Matthew Pelland (matt@pelland.io) + hue: ----- This product contains a modified portion of 'Hue' developed by Cloudera, Inc. diff --git a/README.md b/README.md index c4ae6c08fd1fe..211d9844d1890 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ [![Coverage Status](https://img.shields.io/codecov/c/github/apache/incubator-airflow/master.svg)](https://codecov.io/github/apache/incubator-airflow?branch=master) [![Documentation Status](https://readthedocs.org/projects/airflow/badge/?version=latest)](https://airflow.readthedocs.io/en/latest/?badge=latest) [![License](http://img.shields.io/:license-Apache%202-blue.svg)](http://www.apache.org/licenses/LICENSE-2.0.txt) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/apache-airflow.svg)](https://pypi.org/project/apache-airflow/) [![Join the chat at https://gitter.im/apache/incubator-airflow](https://badges.gitter.im/apache/incubator-airflow.svg)](https://gitter.im/apache/incubator-airflow?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) _NOTE: The transition from 1.8.0 (or before) to 1.8.1 (or after) requires uninstalling Airflow before installing the new version. The package name was changed from `airflow` to `apache-airflow` as of version 1.8.1._ @@ -93,6 +94,7 @@ if you may. Currently **officially** using Airflow: 1. [6play](https://www.6play.fr) [[@lemourA](https://github.com/lemoura), [@achaussende](https://github.com/achaussende), [@d-nguyen](https://github.com/d-nguyen), [@julien-gm](https://github.com/julien-gm)] +1. [8fit](https://8fit.com/) [[@nicor88](https://github.com/nicor88), [@frnzska](https://github.com/frnzska)] 1. [AdBOOST](https://www.adboost.sk) [[AdBOOST](https://github.com/AdBOOST)] 1. [Agari](https://github.com/agaridata) [[@r39132](https://github.com/r39132)] 1. [Airbnb](http://airbnb.io/) [[@mistercrunch](https://github.com/mistercrunch), [@artwr](https://github.com/artwr)] @@ -152,6 +154,7 @@ Currently **officially** using Airflow: 1. [eRevalue](https://www.datamaran.com) [[@hamedhsn](https://github.com/hamedhsn)] 1. [evo.company](https://evo.company/) [[@orhideous](https://github.com/orhideous)] 1. [Flipp](https://www.flipp.com) [[@sethwilsonwishabi](https://github.com/sethwilsonwishabi)] +1. [Format](https://www.format.com) [[@format](https://github.com/4ormat) & [@jasonicarter](https://github.com/jasonicarter)] 1. [FreshBooks](https://github.com/freshbooks) [[@DinoCow](https://github.com/DinoCow)] 1. [Fundera](https://fundera.com) [[@andyxhadji](https://github.com/andyxhadji)] 1. [G Adventures](https://gadventures.com) [[@samuelmullin](https://github.com/samuelmullin)] @@ -250,6 +253,7 @@ Currently **officially** using Airflow: 1. [Stripe](https://stripe.com) [[@jbalogh](https://github.com/jbalogh)] 1. [Strongmind](https://www.strongmind.com) [[@tomchapin](https://github.com/tomchapin) & [@wongstein](https://github.com/wongstein)] 1. [Tails.com](https://tails.com/) [[@alanmcruickshank](https://github.com/alanmcruickshank)] +1. [THE ICONIC](https://www.theiconic.com.au/) [[@revathijay](https://github.com/revathijay)] [[@ilikedata](https://github.com/ilikedata)] 1. [Thinking Machines](https://thinkingmachin.es) [[@marksteve](https://github.com/marksteve)] 1. [Thinknear](https://www.thinknear.com/) [[@d3cay1](https://github.com/d3cay1), [@ccson](https://github.com/ccson), & [@ababian](https://github.com/ababian)] 1. [Thumbtack](https://www.thumbtack.com/) [[@natekupp](https://github.com/natekupp)] diff --git a/airflow/api/common/experimental/mark_tasks.py b/airflow/api/common/experimental/mark_tasks.py index 88c5275f5a4e8..2fac1254cd698 100644 --- a/airflow/api/common/experimental/mark_tasks.py +++ b/airflow/api/common/experimental/mark_tasks.py @@ -208,6 +208,7 @@ def _set_dag_run_state(dag_id, execution_date, state, session=None): dr.state = state if state == State.RUNNING: dr.start_date = timezone.utcnow() + dr.end_date = None else: dr.end_date = timezone.utcnow() session.commit() diff --git a/airflow/configuration.py b/airflow/configuration.py index 9e80648c74ec0..33376285be483 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -139,6 +139,7 @@ class AirflowConfigParser(ConfigParser): 'celery': { # Remove these keys in Airflow 1.11 'worker_concurrency': 'celeryd_concurrency', + 'result_backend': 'celery_result_backend', 'broker_url': 'celery_broker_url', 'ssl_active': 'celery_ssl_active', 'ssl_cert': 'celery_ssl_cert', diff --git a/airflow/contrib/auth/backends/github_enterprise_auth.py b/airflow/contrib/auth/backends/github_enterprise_auth.py index 08fa0d7929d2c..a0f23935b5434 100644 --- a/airflow/contrib/auth/backends/github_enterprise_auth.py +++ b/airflow/contrib/auth/backends/github_enterprise_auth.py @@ -1,21 +1,19 @@ # -*- coding: utf-8 -*- # -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# See the NOTICE file distributed with this work for additional information +# regarding copyright ownership. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import flask_login # Need to expose these downstream diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py index 8ca1f3d744a53..448de63ffe989 100644 --- a/airflow/contrib/hooks/aws_hook.py +++ b/airflow/contrib/hooks/aws_hook.py @@ -84,8 +84,9 @@ class AwsHook(BaseHook): This class is a thin wrapper around the boto3 python library. """ - def __init__(self, aws_conn_id='aws_default'): + def __init__(self, aws_conn_id='aws_default', verify=None): self.aws_conn_id = aws_conn_id + self.verify = verify def _get_credentials(self, region_name): aws_access_key_id = None @@ -162,12 +163,14 @@ def _get_credentials(self, region_name): def get_client_type(self, client_type, region_name=None): session, endpoint_url = self._get_credentials(region_name) - return session.client(client_type, endpoint_url=endpoint_url) + return session.client(client_type, endpoint_url=endpoint_url, + verify=self.verify) def get_resource_type(self, resource_type, region_name=None): session, endpoint_url = self._get_credentials(region_name) - return session.resource(resource_type, endpoint_url=endpoint_url) + return session.resource(resource_type, endpoint_url=endpoint_url, + verify=self.verify) def get_session(self, region_name=None): """Get the underlying boto3.session.""" diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py index e4c0653bfe278..44ecd49e9edcd 100644 --- a/airflow/contrib/hooks/bigquery_hook.py +++ b/airflow/contrib/hooks/bigquery_hook.py @@ -24,6 +24,7 @@ import time from builtins import range +from copy import deepcopy from past.builtins import basestring @@ -195,10 +196,19 @@ class BigQueryBaseCursor(LoggingMixin): PEP 249 cursor isn't needed. """ - def __init__(self, service, project_id, use_legacy_sql=True): + def __init__(self, + service, + project_id, + use_legacy_sql=True, + api_resource_configs=None): + self.service = service self.project_id = project_id self.use_legacy_sql = use_legacy_sql + if api_resource_configs: + _validate_value("api_resource_configs", api_resource_configs, dict) + self.api_resource_configs = api_resource_configs \ + if api_resource_configs else {} self.running_job_id = None def create_empty_table(self, @@ -238,8 +248,7 @@ def create_empty_table(self, :return: """ - if time_partitioning is None: - time_partitioning = dict() + project_id = project_id if project_id is not None else self.project_id table_resource = { @@ -473,11 +482,11 @@ def create_external_table(self, def run_query(self, bql=None, sql=None, - destination_dataset_table=False, + destination_dataset_table=None, write_disposition='WRITE_EMPTY', allow_large_results=False, - flatten_results=False, - udf_config=False, + flatten_results=None, + udf_config=None, use_legacy_sql=None, maximum_billing_tier=None, maximum_bytes_billed=None, @@ -486,7 +495,8 @@ def run_query(self, labels=None, schema_update_options=(), priority='INTERACTIVE', - time_partitioning=None): + time_partitioning=None, + api_resource_configs=None): """ Executes a BigQuery SQL query. Optionally persists results in a BigQuery table. See here: @@ -518,6 +528,13 @@ def run_query(self, :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false). If `None`, defaults to `self.use_legacy_sql`. :type use_legacy_sql: boolean + :param api_resource_configs: a dictionary that contain params + 'configuration' applied for Google BigQuery Jobs API: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs + for example, {'query': {'useQueryCache': False}}. You could use it + if you need to provide some params that are not supported by the + BigQueryHook like args. + :type api_resource_configs: dict :type udf_config: list :param maximum_billing_tier: Positive integer that serves as a multiplier of the basic price. @@ -550,12 +567,22 @@ def run_query(self, :type time_partitioning: dict """ + if not api_resource_configs: + api_resource_configs = self.api_resource_configs + else: + _validate_value('api_resource_configs', + api_resource_configs, dict) + configuration = deepcopy(api_resource_configs) + if 'query' not in configuration: + configuration['query'] = {} + + else: + _validate_value("api_resource_configs['query']", + configuration['query'], dict) - # TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513] - if time_partitioning is None: - time_partitioning = {} sql = bql if sql is None else sql + # TODO remove `bql` in Airflow 2.0 - Jira: [AIRFLOW-2513] if bql: import warnings warnings.warn('Deprecated parameter `bql` used in ' @@ -566,95 +593,109 @@ def run_query(self, 'Airflow.', category=DeprecationWarning) - if sql is None: - raise TypeError('`BigQueryBaseCursor.run_query` missing 1 required ' - 'positional argument: `sql`') + if sql is None and not configuration['query'].get('query', None): + raise TypeError('`BigQueryBaseCursor.run_query` ' + 'missing 1 required positional argument: `sql`') # BigQuery also allows you to define how you want a table's schema to change # as a side effect of a query job # for more details: # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions + allowed_schema_update_options = [ 'ALLOW_FIELD_ADDITION', "ALLOW_FIELD_RELAXATION" ] - if not set(allowed_schema_update_options).issuperset( - set(schema_update_options)): - raise ValueError( - "{0} contains invalid schema update options. " - "Please only use one or more of the following options: {1}" - .format(schema_update_options, allowed_schema_update_options)) - if use_legacy_sql is None: - use_legacy_sql = self.use_legacy_sql + if not set(allowed_schema_update_options + ).issuperset(set(schema_update_options)): + raise ValueError("{0} contains invalid schema update options. " + "Please only use one or more of the following " + "options: {1}" + .format(schema_update_options, + allowed_schema_update_options)) - configuration = { - 'query': { - 'query': sql, - 'useLegacySql': use_legacy_sql, - 'maximumBillingTier': maximum_billing_tier, - 'maximumBytesBilled': maximum_bytes_billed, - 'priority': priority - } - } + if schema_update_options: + if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: + raise ValueError("schema_update_options is only " + "allowed if write_disposition is " + "'WRITE_APPEND' or 'WRITE_TRUNCATE'.") if destination_dataset_table: - if '.' not in destination_dataset_table: - raise ValueError( - 'Expected destination_dataset_table name in the format of ' - '.. Got: {}'.format( - destination_dataset_table)) destination_project, destination_dataset, destination_table = \ _split_tablename(table_input=destination_dataset_table, default_project_id=self.project_id) - configuration['query'].update({ - 'allowLargeResults': allow_large_results, - 'flattenResults': flatten_results, - 'writeDisposition': write_disposition, - 'createDisposition': create_disposition, - 'destinationTable': { - 'projectId': destination_project, - 'datasetId': destination_dataset, - 'tableId': destination_table, - } - }) - if udf_config: - if not isinstance(udf_config, list): - raise TypeError("udf_config argument must have a type 'list'" - " not {}".format(type(udf_config))) - configuration['query'].update({ - 'userDefinedFunctionResources': udf_config - }) - if query_params: - if self.use_legacy_sql: - raise ValueError("Query parameters are not allowed when using " - "legacy SQL") - else: - configuration['query']['queryParameters'] = query_params + destination_dataset_table = { + 'projectId': destination_project, + 'datasetId': destination_dataset, + 'tableId': destination_table, + } - if labels: - configuration['labels'] = labels + query_param_list = [ + (sql, 'query', None, str), + (priority, 'priority', 'INTERACTIVE', str), + (use_legacy_sql, 'useLegacySql', self.use_legacy_sql, bool), + (query_params, 'queryParameters', None, dict), + (udf_config, 'userDefinedFunctionResources', None, list), + (maximum_billing_tier, 'maximumBillingTier', None, int), + (maximum_bytes_billed, 'maximumBytesBilled', None, float), + (time_partitioning, 'timePartitioning', {}, dict), + (schema_update_options, 'schemaUpdateOptions', None, tuple), + (destination_dataset_table, 'destinationTable', None, dict) + ] - time_partitioning = _cleanse_time_partitioning( - destination_dataset_table, - time_partitioning - ) - if time_partitioning: - configuration['query'].update({ - 'timePartitioning': time_partitioning - }) + for param_tuple in query_param_list: - if schema_update_options: - if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]: - raise ValueError("schema_update_options is only " - "allowed if write_disposition is " - "'WRITE_APPEND' or 'WRITE_TRUNCATE'.") - else: - self.log.info( - "Adding experimental " - "'schemaUpdateOptions': {0}".format(schema_update_options)) - configuration['query'][ - 'schemaUpdateOptions'] = schema_update_options + param, param_name, param_default, param_type = param_tuple + + if param_name not in configuration['query'] and param in [None, {}, ()]: + if param_name == 'timePartitioning': + param_default = _cleanse_time_partitioning( + destination_dataset_table, time_partitioning) + param = param_default + + if param not in [None, {}, ()]: + _api_resource_configs_duplication_check( + param_name, param, configuration['query']) + + configuration['query'][param_name] = param + + # check valid type of provided param, + # it last step because we can get param from 2 sources, + # and first of all need to find it + + _validate_value(param_name, configuration['query'][param_name], + param_type) + + if param_name == 'schemaUpdateOptions' and param: + self.log.info("Adding experimental 'schemaUpdateOptions': " + "{0}".format(schema_update_options)) + + if param_name == 'destinationTable': + for key in ['projectId', 'datasetId', 'tableId']: + if key not in configuration['query']['destinationTable']: + raise ValueError( + "Not correct 'destinationTable' in " + "api_resource_configs. 'destinationTable' " + "must be a dict with {'projectId':'', " + "'datasetId':'', 'tableId':''}") + + configuration['query'].update({ + 'allowLargeResults': allow_large_results, + 'flattenResults': flatten_results, + 'writeDisposition': write_disposition, + 'createDisposition': create_disposition, + }) + + if 'useLegacySql' in configuration['query'] and \ + 'queryParameters' in configuration['query']: + raise ValueError("Query parameters are not allowed " + "when using legacy SQL") + + if labels: + _api_resource_configs_duplication_check( + 'labels', labels, configuration) + configuration['labels'] = labels return self.run_with_configuration(configuration) @@ -888,8 +929,7 @@ def run_load(self, # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.tableDefinitions.(key).sourceFormat if src_fmt_configs is None: src_fmt_configs = {} - if time_partitioning is None: - time_partitioning = {} + source_format = source_format.upper() allowed_formats = [ "CSV", "NEWLINE_DELIMITED_JSON", "AVRO", "GOOGLE_SHEETS", @@ -1167,10 +1207,6 @@ def run_table_delete(self, deletion_dataset_table, :type ignore_if_missing: boolean :return: """ - if '.' not in deletion_dataset_table: - raise ValueError( - 'Expected deletion_dataset_table name in the format of ' - '.
. Got: {}'.format(deletion_dataset_table)) deletion_project, deletion_dataset, deletion_table = \ _split_tablename(table_input=deletion_dataset_table, default_project_id=self.project_id) @@ -1536,6 +1572,12 @@ def _bq_cast(string_field, bq_type): def _split_tablename(table_input, default_project_id, var_name=None): + + if '.' not in table_input: + raise ValueError( + 'Expected deletion_dataset_table name in the format of ' + '.
. Got: {}'.format(table_input)) + if not default_project_id: raise ValueError("INTERNAL: No default project is specified") @@ -1597,6 +1639,10 @@ def var_print(var_name): def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in): # if it is a partitioned table ($ is in the table name) add partition load option + + if time_partitioning_in is None: + time_partitioning_in = {} + time_partitioning_out = {} if destination_dataset_table and '$' in destination_dataset_table: if time_partitioning_in.get('field'): @@ -1607,3 +1653,20 @@ def _cleanse_time_partitioning(destination_dataset_table, time_partitioning_in): time_partitioning_out.update(time_partitioning_in) return time_partitioning_out + + +def _validate_value(key, value, expected_type): + """ function to check expected type and raise + error if type is not correct """ + if not isinstance(value, expected_type): + raise TypeError("{} argument must have a type {} not {}".format( + key, expected_type, type(value))) + + +def _api_resource_configs_duplication_check(key, value, config_dict): + if key in config_dict and value != config_dict[key]: + raise ValueError("Values of {param_name} param are duplicated. " + "`api_resource_configs` contained {param_name} param " + "in `query` config and {param_name} was also provided " + "with arg to run_query() method. Please remove duplicates." + .format(param_name=key)) diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py index 54f00e00907c0..cb2ba9bd00fd7 100644 --- a/airflow/contrib/hooks/databricks_hook.py +++ b/airflow/contrib/hooks/databricks_hook.py @@ -24,6 +24,7 @@ from airflow.hooks.base_hook import BaseHook from requests import exceptions as requests_exceptions from requests.auth import AuthBase +from time import sleep from airflow.utils.log.logging_mixin import LoggingMixin @@ -32,6 +33,9 @@ except ImportError: import urlparse +RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart") +START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start") +TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete") SUBMIT_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/submit') GET_RUN_ENDPOINT = ('GET', 'api/2.0/jobs/runs/get') @@ -47,7 +51,8 @@ def __init__( self, databricks_conn_id='databricks_default', timeout_seconds=180, - retry_limit=3): + retry_limit=3, + retry_delay=1.0): """ :param databricks_conn_id: The name of the databricks connection to use. :type databricks_conn_id: string @@ -57,6 +62,9 @@ def __init__( :param retry_limit: The number of times to retry the connection in case of service outages. :type retry_limit: int + :param retry_delay: The number of seconds to wait between retries (it + might be a floating point number). + :type retry_delay: float """ self.databricks_conn_id = databricks_conn_id self.databricks_conn = self.get_connection(databricks_conn_id) @@ -64,6 +72,7 @@ def __init__( if retry_limit < 1: raise ValueError('Retry limit must be greater than equal to 1') self.retry_limit = retry_limit + self.retry_delay = retry_delay @staticmethod def _parse_host(host): @@ -119,7 +128,8 @@ def _do_api_call(self, endpoint_info, json): else: raise AirflowException('Unexpected HTTP Method: ' + method) - for attempt_num in range(1, self.retry_limit + 1): + attempt_num = 1 + while True: try: response = request_func( url, @@ -127,21 +137,29 @@ def _do_api_call(self, endpoint_info, json): auth=auth, headers=USER_AGENT_HEADER, timeout=self.timeout_seconds) - if response.status_code == requests.codes.ok: - return response.json() - else: + response.raise_for_status() + return response.json() + except requests_exceptions.RequestException as e: + if not _retryable_error(e): # In this case, the user probably made a mistake. # Don't retry. raise AirflowException('Response: {0}, Status Code: {1}'.format( - response.content, response.status_code)) - except (requests_exceptions.ConnectionError, - requests_exceptions.Timeout) as e: - self.log.error( - 'Attempt %s API Request to Databricks failed with reason: %s', - attempt_num, e - ) - raise AirflowException(('API requests to Databricks failed {} times. ' + - 'Giving up.').format(self.retry_limit)) + e.response.content, e.response.status_code)) + + self._log_request_error(attempt_num, e) + + if attempt_num == self.retry_limit: + raise AirflowException(('API requests to Databricks failed {} times. ' + + 'Giving up.').format(self.retry_limit)) + + attempt_num += 1 + sleep(self.retry_delay) + + def _log_request_error(self, attempt_num, error): + self.log.error( + 'Attempt %s API Request to Databricks failed with reason: %s', + attempt_num, error + ) def submit_run(self, json): """ @@ -174,6 +192,21 @@ def cancel_run(self, run_id): json = {'run_id': run_id} self._do_api_call(CANCEL_RUN_ENDPOINT, json) + def restart_cluster(self, json): + self._do_api_call(RESTART_CLUSTER_ENDPOINT, json) + + def start_cluster(self, json): + self._do_api_call(START_CLUSTER_ENDPOINT, json) + + def terminate_cluster(self, json): + self._do_api_call(TERMINATE_CLUSTER_ENDPOINT, json) + + +def _retryable_error(exception): + return isinstance(exception, requests_exceptions.ConnectionError) \ + or isinstance(exception, requests_exceptions.Timeout) \ + or exception.response is not None and exception.response.status_code >= 500 + RUN_LIFE_CYCLE_STATES = [ 'PENDING', diff --git a/airflow/contrib/hooks/gcp_container_hook.py b/airflow/contrib/hooks/gcp_container_hook.py index e5fbda138e0fa..0047b8dbebf2f 100644 --- a/airflow/contrib/hooks/gcp_container_hook.py +++ b/airflow/contrib/hooks/gcp_container_hook.py @@ -48,6 +48,7 @@ def __init__(self, project_id, location): def _dict_to_proto(py_dict, proto): """ Converts a python dictionary to the proto supplied + :param py_dict: The dictionary to convert :type py_dict: dict :param proto: The proto object to merge with dictionary @@ -63,6 +64,7 @@ def wait_for_operation(self, operation): """ Given an operation, continuously fetches the status from Google Cloud until either completion or an error occurring + :param operation: The Operation to wait for :type operation: A google.cloud.container_V1.gapic.enums.Operator :return: A new, updated operation fetched from Google Cloud @@ -83,6 +85,7 @@ def wait_for_operation(self, operation): def get_operation(self, operation_name): """ Fetches the operation from Google Cloud + :param operation_name: Name of operation to fetch :type operation_name: str :return: The new, updated operation from Google Cloud @@ -196,6 +199,7 @@ def create_cluster(self, cluster, retry=DEFAULT, timeout=DEFAULT): def get_cluster(self, name, retry=DEFAULT, timeout=DEFAULT): """ Gets details of specified cluster + :param name: The name of the cluster to retrieve :type name: str :param retry: A retry object used to retry requests. If None is specified, diff --git a/airflow/contrib/hooks/gcp_dataflow_hook.py b/airflow/contrib/hooks/gcp_dataflow_hook.py index 8c9b7423e0e6d..a9b7e71a5ee06 100644 --- a/airflow/contrib/hooks/gcp_dataflow_hook.py +++ b/airflow/contrib/hooks/gcp_dataflow_hook.py @@ -250,7 +250,7 @@ def _build_dataflow_job_name(task_id, append_job_name=True): 'letter and ending with a letter or number '.format(task_id)) if append_job_name: - job_name = task_id + "-" + str(uuid.uuid1())[:8] + job_name = task_id + "-" + str(uuid.uuid4())[:8] else: job_name = task_id diff --git a/airflow/contrib/hooks/gcp_dataproc_hook.py b/airflow/contrib/hooks/gcp_dataproc_hook.py index 57c48bde59328..f9e7a9050989b 100644 --- a/airflow/contrib/hooks/gcp_dataproc_hook.py +++ b/airflow/contrib/hooks/gcp_dataproc_hook.py @@ -81,7 +81,7 @@ def get(self): class _DataProcJobBuilder: def __init__(self, project_id, task_id, cluster_name, job_type, properties): - name = task_id + "_" + str(uuid.uuid1())[:8] + name = task_id + "_" + str(uuid.uuid4())[:8] self.job_type = job_type self.job = { "job": { @@ -141,7 +141,7 @@ def set_python_main(self, main): self.job["job"][self.job_type]["mainPythonFileUri"] = main def set_job_name(self, name): - self.job["job"]["reference"]["jobId"] = name + "_" + str(uuid.uuid1())[:8] + self.job["job"]["reference"]["jobId"] = name + "_" + str(uuid.uuid4())[:8] def build(self): return self.job diff --git a/airflow/contrib/kubernetes/kubernetes_request_factory/kubernetes_request_factory.py b/airflow/contrib/kubernetes/kubernetes_request_factory/kubernetes_request_factory.py index 27e0ebd29c0a3..97bcdf2abc9ea 100644 --- a/airflow/contrib/kubernetes/kubernetes_request_factory/kubernetes_request_factory.py +++ b/airflow/contrib/kubernetes/kubernetes_request_factory/kubernetes_request_factory.py @@ -175,6 +175,11 @@ def extract_service_account_name(pod, req): if pod.service_account_name: req['spec']['serviceAccountName'] = pod.service_account_name + @staticmethod + def extract_hostnetwork(pod, req): + if pod.hostnetwork: + req['spec']['hostNetwork'] = pod.hostnetwork + @staticmethod def extract_image_pull_secrets(pod, req): if pod.image_pull_secrets: diff --git a/airflow/contrib/kubernetes/kubernetes_request_factory/pod_request_factory.py b/airflow/contrib/kubernetes/kubernetes_request_factory/pod_request_factory.py index 95d6c829dec59..877d7aafe2b2e 100644 --- a/airflow/contrib/kubernetes/kubernetes_request_factory/pod_request_factory.py +++ b/airflow/contrib/kubernetes/kubernetes_request_factory/pod_request_factory.py @@ -59,6 +59,7 @@ def create(self, pod): self.extract_image_pull_secrets(pod, req) self.extract_annotations(pod, req) self.extract_affinity(pod, req) + self.extract_hostnetwork(pod, req) return req @@ -116,4 +117,5 @@ def create(self, pod): self.extract_image_pull_secrets(pod, req) self.extract_annotations(pod, req) self.extract_affinity(pod, req) + self.extract_hostnetwork(pod, req) return req diff --git a/airflow/contrib/kubernetes/pod.py b/airflow/contrib/kubernetes/pod.py index 6fcf354459b80..221c8f4180ef8 100644 --- a/airflow/contrib/kubernetes/pod.py +++ b/airflow/contrib/kubernetes/pod.py @@ -77,7 +77,8 @@ def __init__( service_account_name=None, resources=None, annotations=None, - affinity=None + affinity=None, + hostnetwork=False ): self.image = image self.envs = envs or {} @@ -98,3 +99,4 @@ def __init__( self.resources = resources or Resources() self.annotations = annotations or {} self.affinity = affinity or {} + self.hostnetwork = hostnetwork or False diff --git a/airflow/contrib/kubernetes/pod_generator.py b/airflow/contrib/kubernetes/pod_generator.py index 6d8d83ef054a9..bee7f5b9572a0 100644 --- a/airflow/contrib/kubernetes/pod_generator.py +++ b/airflow/contrib/kubernetes/pod_generator.py @@ -149,7 +149,7 @@ def make_pod(self, namespace, image, pod_id, cmds, arguments, labels): return Pod( namespace=namespace, - name=pod_id + "-" + str(uuid.uuid1())[:8], + name=pod_id + "-" + str(uuid.uuid4())[:8], image=image, cmds=cmds, args=arguments, diff --git a/airflow/contrib/kubernetes/pod_launcher.py b/airflow/contrib/kubernetes/pod_launcher.py index 42f2bfea8adec..8c8d949107494 100644 --- a/airflow/contrib/kubernetes/pod_launcher.py +++ b/airflow/contrib/kubernetes/pod_launcher.py @@ -22,7 +22,7 @@ from datetime import datetime as dt from airflow.contrib.kubernetes.kubernetes_request_factory import \ pod_request_factory as pod_factory -from kubernetes import watch +from kubernetes import watch, client from kubernetes.client.rest import ApiException from kubernetes.stream import stream as kubernetes_stream from airflow import AirflowException @@ -59,6 +59,15 @@ def run_pod_async(self, pod): raise return resp + def delete_pod(self, pod): + try: + self._client.delete_namespaced_pod( + pod.name, pod.namespace, body=client.V1DeleteOptions()) + except ApiException as e: + # If the pod is already deleted + if e.status != 404: + raise + def run_pod(self, pod, startup_timeout=120, get_logs=True): # type: (Pod) -> (State, result) """ diff --git a/airflow/contrib/operators/awsbatch_operator.py b/airflow/contrib/operators/awsbatch_operator.py index 50c6c2c319717..4008c90c47bbe 100644 --- a/airflow/contrib/operators/awsbatch_operator.py +++ b/airflow/contrib/operators/awsbatch_operator.py @@ -42,18 +42,20 @@ class AWSBatchOperator(BaseOperator): :type job_definition: str :param job_queue: the queue name on AWS Batch :type job_queue: str - :param: overrides: the same parameter that boto3 will receive on - containerOverrides (templated): - http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job - :type: overrides: dict - :param max_retries: exponential backoff retries while waiter is not merged, 4200 = 48 hours + :param overrides: the same parameter that boto3 will receive on + containerOverrides (templated). + http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job + :type overrides: dict + :param max_retries: exponential backoff retries while waiter is not merged, + 4200 = 48 hours :type max_retries: int :param aws_conn_id: connection id of AWS credentials / region name. If None, - credential boto3 strategy will be used - (http://boto3.readthedocs.io/en/latest/guide/configuration.html). + credential boto3 strategy will be used + (http://boto3.readthedocs.io/en/latest/guide/configuration.html). :type aws_conn_id: str :param region_name: region name to use in AWS Hook. Override the region_name in connection (if provided) + :type region_name: str """ ui_color = '#c3dae0' diff --git a/airflow/contrib/operators/bigquery_check_operator.py b/airflow/contrib/operators/bigquery_check_operator.py index a9c493f4fd418..3eba0771db91b 100644 --- a/airflow/contrib/operators/bigquery_check_operator.py +++ b/airflow/contrib/operators/bigquery_check_operator.py @@ -56,7 +56,7 @@ class BigQueryCheckOperator(CheckOperator): :param bigquery_conn_id: reference to the BigQuery database :type bigquery_conn_id: string :param use_legacy_sql: Whether to use legacy SQL (true) - or standard SQL (false). + or standard SQL (false). :type use_legacy_sql: boolean """ @@ -83,7 +83,7 @@ class BigQueryValueCheckOperator(ValueCheckOperator): :param sql: the sql to be executed :type sql: string :param use_legacy_sql: Whether to use legacy SQL (true) - or standard SQL (false). + or standard SQL (false). :type use_legacy_sql: boolean """ @@ -125,7 +125,7 @@ class BigQueryIntervalCheckOperator(IntervalCheckOperator): between the current day, and the prior days_back. :type metrics_threshold: dict :param use_legacy_sql: Whether to use legacy SQL (true) - or standard SQL (false). + or standard SQL (false). :type use_legacy_sql: boolean """ diff --git a/airflow/contrib/operators/bigquery_operator.py b/airflow/contrib/operators/bigquery_operator.py index 09f442ae7b1a6..b0c0ce2d6e31b 100644 --- a/airflow/contrib/operators/bigquery_operator.py +++ b/airflow/contrib/operators/bigquery_operator.py @@ -75,6 +75,13 @@ class BigQueryOperator(BaseOperator): (without incurring a charge). If unspecified, this will be set to your project default. :type maximum_bytes_billed: float + :param api_resource_configs: a dictionary that contain params + 'configuration' applied for Google BigQuery Jobs API: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs + for example, {'query': {'useQueryCache': False}}. You could use it + if you need to provide some params that are not supported by BigQueryOperator + like args. + :type api_resource_configs: dict :param schema_update_options: Allows the schema of the destination table to be updated as a side effect of the load job. :type schema_update_options: tuple @@ -106,7 +113,7 @@ def __init__(self, destination_dataset_table=False, write_disposition='WRITE_EMPTY', allow_large_results=False, - flatten_results=False, + flatten_results=None, bigquery_conn_id='bigquery_default', delegate_to=None, udf_config=False, @@ -118,7 +125,8 @@ def __init__(self, query_params=None, labels=None, priority='INTERACTIVE', - time_partitioning={}, + time_partitioning=None, + api_resource_configs=None, *args, **kwargs): super(BigQueryOperator, self).__init__(*args, **kwargs) @@ -140,7 +148,10 @@ def __init__(self, self.labels = labels self.bq_cursor = None self.priority = priority - self.time_partitioning = time_partitioning + if time_partitioning is None: + self.time_partitioning = {} + if api_resource_configs is None: + self.api_resource_configs = {} # TODO remove `bql` in Airflow 2.0 if self.bql: @@ -179,7 +190,8 @@ def execute(self, context): labels=self.labels, schema_update_options=self.schema_update_options, priority=self.priority, - time_partitioning=self.time_partitioning + time_partitioning=self.time_partitioning, + api_resource_configs=self.api_resource_configs, ) def on_kill(self): @@ -234,49 +246,49 @@ class BigQueryCreateEmptyTableOperator(BaseOperator): work, the service account making the request must have domain-wide delegation enabled. :type delegate_to: string - :param labels a dictionary containing labels for the table, passed to BigQuery + :param labels: a dictionary containing labels for the table, passed to BigQuery + + **Example (with schema JSON in GCS)**: :: + + CreateTable = BigQueryCreateEmptyTableOperator( + task_id='BigQueryCreateEmptyTableOperator_task', + dataset_id='ODS', + table_id='Employees', + project_id='internal-gcp-project', + gcs_schema_object='gs://schema-bucket/employee_schema.json', + bigquery_conn_id='airflow-service-account', + google_cloud_storage_conn_id='airflow-service-account' + ) + + **Corresponding Schema file** (``employee_schema.json``): :: + + [ + { + "mode": "NULLABLE", + "name": "emp_name", + "type": "STRING" + }, + { + "mode": "REQUIRED", + "name": "salary", + "type": "INTEGER" + } + ] + + **Example (with schema in the DAG)**: :: + + CreateTable = BigQueryCreateEmptyTableOperator( + task_id='BigQueryCreateEmptyTableOperator_task', + dataset_id='ODS', + table_id='Employees', + project_id='internal-gcp-project', + schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}], + bigquery_conn_id='airflow-service-account', + google_cloud_storage_conn_id='airflow-service-account' + ) :type labels: dict - **Example (with schema JSON in GCS)**: :: - - CreateTable = BigQueryCreateEmptyTableOperator( - task_id='BigQueryCreateEmptyTableOperator_task', - dataset_id='ODS', - table_id='Employees', - project_id='internal-gcp-project', - gcs_schema_object='gs://schema-bucket/employee_schema.json', - bigquery_conn_id='airflow-service-account', - google_cloud_storage_conn_id='airflow-service-account' - ) - - **Corresponding Schema file** (``employee_schema.json``): :: - - [ - { - "mode": "NULLABLE", - "name": "emp_name", - "type": "STRING" - }, - { - "mode": "REQUIRED", - "name": "salary", - "type": "INTEGER" - } - ] - - **Example (with schema in the DAG)**: :: - - CreateTable = BigQueryCreateEmptyTableOperator( - task_id='BigQueryCreateEmptyTableOperator_task', - dataset_id='ODS', - table_id='Employees', - project_id='internal-gcp-project', - schema_fields=[{"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, - {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}], - bigquery_conn_id='airflow-service-account', - google_cloud_storage_conn_id='airflow-service-account' - ) - """ template_fields = ('dataset_id', 'table_id', 'project_id', 'gcs_schema_object', 'labels') diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py index 7b8d522dba85b..3245a99256502 100644 --- a/airflow/contrib/operators/databricks_operator.py +++ b/airflow/contrib/operators/databricks_operator.py @@ -146,6 +146,9 @@ class DatabricksSubmitRunOperator(BaseOperator): :param databricks_retry_limit: Amount of times retry if the Databricks backend is unreachable. Its value must be greater than or equal to 1. :type databricks_retry_limit: int + :param databricks_retry_delay: Number of seconds to wait between retries (it + might be a floating point number). + :type databricks_retry_delay: float :param do_xcom_push: Whether we should push run_id and run_page_url to xcom. :type do_xcom_push: boolean """ @@ -168,6 +171,7 @@ def __init__( databricks_conn_id='databricks_default', polling_period_seconds=30, databricks_retry_limit=3, + databricks_retry_delay=1, do_xcom_push=False, **kwargs): """ @@ -178,6 +182,7 @@ def __init__( self.databricks_conn_id = databricks_conn_id self.polling_period_seconds = polling_period_seconds self.databricks_retry_limit = databricks_retry_limit + self.databricks_retry_delay = databricks_retry_delay if spark_jar_task is not None: self.json['spark_jar_task'] = spark_jar_task if notebook_task is not None: @@ -232,7 +237,8 @@ def _log_run_page_url(self, url): def get_hook(self): return DatabricksHook( self.databricks_conn_id, - retry_limit=self.databricks_retry_limit) + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay) def execute(self, context): hook = self.get_hook() diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py index 3a6980cefaeae..3c69fb759abd3 100644 --- a/airflow/contrib/operators/dataflow_operator.py +++ b/airflow/contrib/operators/dataflow_operator.py @@ -16,7 +16,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +import os import re import uuid import copy @@ -252,6 +252,38 @@ def execute(self, context): class DataFlowPythonOperator(BaseOperator): + """ + Create a new DataFlowPythonOperator. Note that both + dataflow_default_options and options will be merged to specify pipeline + execution parameter, and dataflow_default_options is expected to save + high-level options, for instances, project and zone information, which + apply to all dataflow operators in the DAG. + + .. seealso:: + For more detail on job submission have a look at the reference: + https://cloud.google.com/dataflow/pipelines/specifying-exec-params + + :param py_file: Reference to the python dataflow pipleline file.py, e.g., + /some/local/file/path/to/your/python/pipeline/file. + :type py_file: string + :param py_options: Additional python options. + :type pyt_options: list of strings, e.g., ["-m", "-v"]. + :param dataflow_default_options: Map of default job options. + :type dataflow_default_options: dict + :param options: Map of job specific options. + :type options: dict + :param gcp_conn_id: The connection ID to use connecting to Google Cloud + Platform. + :type gcp_conn_id: string + :param delegate_to: The account to impersonate, if any. + For this to work, the service account making the request must have + domain-wide delegation enabled. + :type delegate_to: string + :param poll_sleep: The time in seconds to sleep between polling Google + Cloud Platform for the dataflow job status while the job is in the + JOB_STATE_RUNNING state. + :type poll_sleep: int + """ template_fields = ['options', 'dataflow_default_options'] @@ -267,38 +299,7 @@ def __init__( poll_sleep=10, *args, **kwargs): - """ - Create a new DataFlowPythonOperator. Note that both - dataflow_default_options and options will be merged to specify pipeline - execution parameter, and dataflow_default_options is expected to save - high-level options, for instances, project and zone information, which - apply to all dataflow operators in the DAG. - - .. seealso:: - For more detail on job submission have a look at the reference: - https://cloud.google.com/dataflow/pipelines/specifying-exec-params - :param py_file: Reference to the python dataflow pipleline file.py, e.g., - /some/local/file/path/to/your/python/pipeline/file. - :type py_file: string - :param py_options: Additional python options. - :type pyt_options: list of strings, e.g., ["-m", "-v"]. - :param dataflow_default_options: Map of default job options. - :type dataflow_default_options: dict - :param options: Map of job specific options. - :type options: dict - :param gcp_conn_id: The connection ID to use connecting to Google Cloud - Platform. - :type gcp_conn_id: string - :param delegate_to: The account to impersonate, if any. - For this to work, the service account making the request must have - domain-wide delegation enabled. - :type delegate_to: string - :param poll_sleep: The time in seconds to sleep between polling Google - Cloud Platform for the dataflow job status while the job is in the - JOB_STATE_RUNNING state. - :type poll_sleep: int - """ super(DataFlowPythonOperator, self).__init__(*args, **kwargs) self.py_file = py_file @@ -358,18 +359,18 @@ def google_cloud_to_local(self, file_name): # Extracts bucket_id and object_id by first removing 'gs://' prefix and # then split the remaining by path delimiter '/'. path_components = file_name[self.GCS_PREFIX_LENGTH:].split('/') - if path_components < 2: + if len(path_components) < 2: raise Exception( 'Invalid Google Cloud Storage (GCS) object path: {}.' .format(file_name)) bucket_id = path_components[0] object_id = '/'.join(path_components[1:]) - local_file = '/tmp/dataflow{}-{}'.format(str(uuid.uuid1())[:8], + local_file = '/tmp/dataflow{}-{}'.format(str(uuid.uuid4())[:8], path_components[-1]) file_size = self._gcs_hook.download(bucket_id, object_id, local_file) - if file_size > 0: + if os.stat(file_size).st_size > 0: return local_file raise Exception( 'Failed to download Google Cloud Storage GCS object: {}' diff --git a/airflow/contrib/operators/dataproc_operator.py b/airflow/contrib/operators/dataproc_operator.py index 6dfa2da095e38..4b0cd899f02dc 100644 --- a/airflow/contrib/operators/dataproc_operator.py +++ b/airflow/contrib/operators/dataproc_operator.py @@ -75,10 +75,20 @@ class DataprocClusterCreateOperator(BaseOperator): :type properties: dict :param master_machine_type: Compute engine machine type to use for the master node :type master_machine_type: string + :param master_disk_type: Type of the boot disk for the master node + (default is ``pd-standard``). + Valid values: ``pd-ssd`` (Persistent Disk Solid State Drive) or + ``pd-standard`` (Persistent Disk Hard Disk Drive). + :type master_disk_type: string :param master_disk_size: Disk size for the master node :type master_disk_size: int :param worker_machine_type: Compute engine machine type to use for the worker nodes :type worker_machine_type: string + :param worker_disk_type: Type of the boot disk for the worker node + (default is ``pd-standard``). + Valid values: ``pd-ssd`` (Persistent Disk Solid State Drive) or + ``pd-standard`` (Persistent Disk Hard Disk Drive). + :type worker_disk_type: string :param worker_disk_size: Disk size for the worker nodes :type worker_disk_size: int :param num_preemptible_workers: The # of preemptible worker nodes to spin up @@ -141,8 +151,10 @@ def __init__(self, image_version=None, properties=None, master_machine_type='n1-standard-4', + master_disk_type='pd-standard', master_disk_size=500, worker_machine_type='n1-standard-4', + worker_disk_type='pd-standard', worker_disk_size=500, num_preemptible_workers=0, labels=None, @@ -171,8 +183,10 @@ def __init__(self, self.image_version = image_version self.properties = properties self.master_machine_type = master_machine_type + self.master_disk_type = master_disk_type self.master_disk_size = master_disk_size self.worker_machine_type = worker_machine_type + self.worker_disk_type = worker_disk_type self.worker_disk_size = worker_disk_size self.labels = labels self.zone = zone @@ -272,6 +286,7 @@ def _build_cluster_data(self): 'numInstances': 1, 'machineTypeUri': master_type_uri, 'diskConfig': { + 'bootDiskType': self.master_disk_type, 'bootDiskSizeGb': self.master_disk_size } }, @@ -279,6 +294,7 @@ def _build_cluster_data(self): 'numInstances': self.num_workers, 'machineTypeUri': worker_type_uri, 'diskConfig': { + 'bootDiskType': self.worker_disk_type, 'bootDiskSizeGb': self.worker_disk_size } }, @@ -292,6 +308,7 @@ def _build_cluster_data(self): 'numInstances': self.num_preemptible_workers, 'machineTypeUri': worker_type_uri, 'diskConfig': { + 'bootDiskType': self.worker_disk_type, 'bootDiskSizeGb': self.worker_disk_size }, 'isPreemptible': True @@ -395,14 +412,14 @@ class DataprocClusterScaleOperator(BaseOperator): **Example**: :: - t1 = DataprocClusterScaleOperator( - task_id='dataproc_scale', - project_id='my-project', - cluster_name='cluster-1', - num_workers=10, - num_preemptible_workers=10, - graceful_decommission_timeout='1h' - dag=dag) + t1 = DataprocClusterScaleOperator( + task_id='dataproc_scale', + project_id='my-project', + cluster_name='cluster-1', + num_workers=10, + num_preemptible_workers=10, + graceful_decommission_timeout='1h', + dag=dag) .. seealso:: For more detail on about scaling clusters have a look at the reference: @@ -1158,7 +1175,7 @@ class DataProcPySparkOperator(BaseOperator): @staticmethod def _generate_temp_filename(filename): dt = time.strftime('%Y%m%d%H%M%S') - return "{}_{}_{}".format(dt, str(uuid.uuid1())[:8], ntpath.basename(filename)) + return "{}_{}_{}".format(dt, str(uuid.uuid4())[:8], ntpath.basename(filename)) """ Upload a local file to a Google Cloud Storage bucket @@ -1312,7 +1329,7 @@ def start(self): .instantiate( name=('projects/%s/regions/%s/workflowTemplates/%s' % (self.project_id, self.region, self.template_id)), - body={'instanceId': str(uuid.uuid1())}) + body={'instanceId': str(uuid.uuid4())}) .execute()) @@ -1355,6 +1372,6 @@ def start(self): self.hook.get_conn().projects().regions().workflowTemplates() .instantiateInline( parent='projects/%s/regions/%s' % (self.project_id, self.region), - instanceId=str(uuid.uuid1()), + instanceId=str(uuid.uuid4()), body=self.template) .execute()) diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py index 319441d297af5..c85ae15b771ec 100644 --- a/airflow/contrib/operators/ecs_operator.py +++ b/airflow/contrib/operators/ecs_operator.py @@ -33,17 +33,18 @@ class ECSOperator(BaseOperator): :type task_definition: str :param cluster: the cluster name on EC2 Container Service :type cluster: str - :param: overrides: the same parameter that boto3 will receive (templated): - http://boto3.readthedocs.org/en/latest/reference/services/ecs.html#ECS.Client.run_task - :type: overrides: dict + :param overrides: the same parameter that boto3 will receive (templated): + http://boto3.readthedocs.org/en/latest/reference/services/ecs.html#ECS.Client.run_task + :type overrides: dict :param aws_conn_id: connection id of AWS credentials / region name. If None, - credential boto3 strategy will be used - (http://boto3.readthedocs.io/en/latest/guide/configuration.html). + credential boto3 strategy will be used + (http://boto3.readthedocs.io/en/latest/guide/configuration.html). :type aws_conn_id: str :param region_name: region name to use in AWS Hook. Override the region_name in connection (if provided) + :type region_name: str :param launch_type: the launch type on which to run your task ('EC2' or 'FARGATE') - :type: launch_type: str + :type launch_type: str """ ui_color = '#f0ede4' diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py index fb27e8f205661..69acb616594d9 100644 --- a/airflow/contrib/operators/gcs_to_bq.py +++ b/airflow/contrib/operators/gcs_to_bq.py @@ -38,7 +38,7 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator): :type bucket: string :param source_objects: List of Google cloud storage URIs to load from. (templated) If source_format is 'DATASTORE_BACKUP', the list must only contain a single URI. - :type object: list + :type source_objects: list of str :param destination_project_dataset_table: The dotted (.).
BigQuery table to load data into. If is not included, project will be the project defined in the connection json. (templated) diff --git a/airflow/contrib/operators/gcs_to_s3.py b/airflow/contrib/operators/gcs_to_s3.py index a87aa3af5c531..0df6170eab377 100644 --- a/airflow/contrib/operators/gcs_to_s3.py +++ b/airflow/contrib/operators/gcs_to_s3.py @@ -47,6 +47,16 @@ class GoogleCloudStorageToS3Operator(GoogleCloudStorageListOperator): :type dest_aws_conn_id: str :param dest_s3_key: The base S3 key to be used to store the files. (templated) :type dest_s3_key: str + :parame dest_verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + - False: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - path/to/cert/bundle.pem: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type dest_verify: bool or str """ template_fields = ('bucket', 'prefix', 'delimiter', 'dest_s3_key') ui_color = '#f0eee4' @@ -60,6 +70,7 @@ def __init__(self, delegate_to=None, dest_aws_conn_id=None, dest_s3_key=None, + dest_verify=None, replace=False, *args, **kwargs): @@ -75,12 +86,13 @@ def __init__(self, ) self.dest_aws_conn_id = dest_aws_conn_id self.dest_s3_key = dest_s3_key + self.dest_verify = dest_verify self.replace = replace def execute(self, context): # use the super to list all files in an Google Cloud Storage bucket files = super(GoogleCloudStorageToS3Operator, self).execute(context) - s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id) + s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify) if not self.replace: # if we are not replacing -> list all files in the S3 bucket diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py index fb905622d8a47..bb4bf7fca1e33 100644 --- a/airflow/contrib/operators/kubernetes_pod_operator.py +++ b/airflow/contrib/operators/kubernetes_pod_operator.py @@ -102,6 +102,7 @@ def execute(self, context): labels=self.labels, ) + pod.service_account_name = self.service_account_name pod.secrets = self.secrets pod.envs = self.env_vars pod.image_pull_policy = self.image_pull_policy @@ -109,6 +110,7 @@ def execute(self, context): pod.resources = self.resources pod.affinity = self.affinity pod.node_selectors = self.node_selectors + pod.hostnetwork = self.hostnetwork launcher = pod_launcher.PodLauncher(kube_client=client, extract_xcom=self.xcom_push) @@ -116,6 +118,10 @@ def execute(self, context): pod, startup_timeout=self.startup_timeout_seconds, get_logs=self.get_logs) + + if self.is_delete_operator_pod: + launcher.delete_pod(pod) + if final_state != State.SUCCESS: raise AirflowException( 'Pod returned a failure: {state}'.format(state=final_state) @@ -148,6 +154,10 @@ def __init__(self, config_file=None, xcom_push=False, node_selectors=None, + image_pull_secrets=None, + service_account_name="default", + is_delete_operator_pod=False, + hostnetwork=False, *args, **kwargs): super(KubernetesPodOperator, self).__init__(*args, **kwargs) @@ -172,3 +182,7 @@ def __init__(self, self.xcom_push = xcom_push self.resources = resources or Resources() self.config_file = config_file + self.image_pull_secrets = image_pull_secrets + self.service_account_name = service_account_name + self.is_delete_operator_pod = is_delete_operator_pod + self.hostnetwork = hostnetwork diff --git a/airflow/contrib/operators/qubole_check_operator.py b/airflow/contrib/operators/qubole_check_operator.py index 235af08ca753e..8b6b5d351cd86 100644 --- a/airflow/contrib/operators/qubole_check_operator.py +++ b/airflow/contrib/operators/qubole_check_operator.py @@ -215,11 +215,11 @@ def get_sql_from_qbol_cmd(params): def handle_airflow_exception(airflow_exception, hook): cmd = hook.cmd if cmd is not None: - if cmd.is_success: + if cmd.is_success(cmd.status): qubole_command_results = hook.get_query_results() qubole_command_id = cmd.id exception_message = '\nQubole Command Id: {qubole_command_id}' \ '\nQubole Command Results:' \ '\n{qubole_command_results}'.format(**locals()) raise AirflowException(str(airflow_exception) + exception_message) - raise AirflowException(airflow_exception.message) + raise AirflowException(str(airflow_exception)) diff --git a/airflow/contrib/operators/s3_list_operator.py b/airflow/contrib/operators/s3_list_operator.py index b85691b005fb9..a9e005eed3f65 100644 --- a/airflow/contrib/operators/s3_list_operator.py +++ b/airflow/contrib/operators/s3_list_operator.py @@ -38,6 +38,16 @@ class S3ListOperator(BaseOperator): :type delimiter: string :param aws_conn_id: The connection ID to use when connecting to S3 storage. :type aws_conn_id: string + :parame verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + - False: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - path/to/cert/bundle.pem: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str **Example**: The following operator would list all the files @@ -61,6 +71,7 @@ def __init__(self, prefix='', delimiter='', aws_conn_id='aws_default', + verify=None, *args, **kwargs): super(S3ListOperator, self).__init__(*args, **kwargs) @@ -68,9 +79,10 @@ def __init__(self, self.prefix = prefix self.delimiter = delimiter self.aws_conn_id = aws_conn_id + self.verify = verify def execute(self, context): - hook = S3Hook(aws_conn_id=self.aws_conn_id) + hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) self.log.info( 'Getting the list of files from bucket: {0} in prefix: {1} (Delimiter {2})'. diff --git a/airflow/contrib/operators/s3_to_gcs_operator.py b/airflow/contrib/operators/s3_to_gcs_operator.py index 64d7dc7cab976..81c48a9e157bc 100644 --- a/airflow/contrib/operators/s3_to_gcs_operator.py +++ b/airflow/contrib/operators/s3_to_gcs_operator.py @@ -41,6 +41,16 @@ class S3ToGoogleCloudStorageOperator(S3ListOperator): :type delimiter: string :param aws_conn_id: The source S3 connection :type aws_conn_id: string + :parame verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + - False: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - path/to/cert/bundle.pem: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str :param dest_gcs_conn_id: The destination connection ID to use when connecting to Google Cloud Storage. :type dest_gcs_conn_id: string @@ -80,6 +90,7 @@ def __init__(self, prefix='', delimiter='', aws_conn_id='aws_default', + verify=None, dest_gcs_conn_id=None, dest_gcs=None, delegate_to=None, @@ -98,6 +109,7 @@ def __init__(self, self.dest_gcs = dest_gcs self.delegate_to = delegate_to self.replace = replace + self.verify = verify if dest_gcs and not self._gcs_object_is_directory(self.dest_gcs): self.log.info( @@ -146,7 +158,7 @@ def execute(self, context): 'There are no new files to sync. Have a nice day!') if files: - hook = S3Hook(aws_conn_id=self.aws_conn_id) + hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) for file in files: # GCS hook builds its own in-memory file so we have to create diff --git a/airflow/contrib/operators/s3_to_sftp_operator.py b/airflow/contrib/operators/s3_to_sftp_operator.py new file mode 100644 index 0000000000000..2d63042216f7e --- /dev/null +++ b/airflow/contrib/operators/s3_to_sftp_operator.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.models import BaseOperator +from airflow.hooks.S3_hook import S3Hook +from airflow.contrib.hooks.ssh_hook import SSHHook +from tempfile import NamedTemporaryFile +from urllib.parse import urlparse + + +def get_s3_key(s3_key): + """This parses the correct format for S3 keys + regardless of how the S3 url is passed. """ + + parsed_s3_key = urlparse(s3_key) + return parsed_s3_key.path.lstrip('/') + + +class S3ToSFTPOperator(BaseOperator): + """ + S3 To SFTP Operator + :param sftp_conn_id: The sftp connection id. + :type sftp_conn_id: string + :param sftp_path: The sftp remote path. + :type sftp_path: string + :param s3_conn_id: The s3 connnection id. + :type s3_conn_id: string + :param s3_bucket: The targeted s3 bucket. + :type s3_bucket: string + :param s3_key: The targeted s3 key. + :type s3_key: string + """ + + template_fields = ('s3_key', 'sftp_path') + + def __init__(self, + sftp_conn_id=None, + s3_conn_id=None, + s3_bucket=None, + s3_key=None, + sftp_path=None, + *args, + **kwargs): + super(S3ToSFTPOperator, self).__init__(*args, **kwargs) + self.sftp_conn_id = sftp_conn_id + self.sftp_path = sftp_path + self.s3_bucket = s3_bucket + self.s3_key = s3_key + self.s3_conn_id = s3_conn_id + + self.ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id) + self.s3_hook = S3Hook(self.s3_conn_id) + + def execute(self, context): + self.s3_key = get_s3_key(self.s3_key) + + s3_client = self.s3_hook.get_conn() + ssh_client = self.ssh_hook.get_conn() + sftp_client = ssh_client.open_sftp() + + with NamedTemporaryFile("w") as f: + s3_client.download_file(self.s3_bucket, self.s3_key, f.name) + sftp_client.put(f.name, self.sftp_path) diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py index 3c736c8b95101..a3b5c1f24492b 100644 --- a/airflow/contrib/operators/sftp_operator.py +++ b/airflow/contrib/operators/sftp_operator.py @@ -33,11 +33,15 @@ class SFTPOperator(BaseOperator): This operator uses ssh_hook to open sftp trasport channel that serve as basis for file transfer. - :param ssh_hook: predefined ssh_hook to use for remote execution + :param ssh_hook: predefined ssh_hook to use for remote execution. + Either `ssh_hook` or `ssh_conn_id` needs to be provided. :type ssh_hook: :class:`SSHHook` - :param ssh_conn_id: connection id from airflow Connections + :param ssh_conn_id: connection id from airflow Connections. + `ssh_conn_id` will be ingored if `ssh_hook` is provided. :type ssh_conn_id: str :param remote_host: remote host to connect (templated) + Nullable. If provided, it will replace the `remote_host` which was + defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`. :type remote_host: str :param local_filepath: local file path to get or put. (templated) :type local_filepath: str @@ -77,13 +81,21 @@ def __init__(self, def execute(self, context): file_msg = None try: - if self.ssh_conn_id and not self.ssh_hook: - self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) + if self.ssh_conn_id: + if self.ssh_hook and isinstance(self.ssh_hook, SSHHook): + self.log.info("ssh_conn_id is ignored when ssh_hook is provided.") + else: + self.log.info("ssh_hook is not provided or invalid. " + + "Trying ssh_conn_id to create SSHHook.") + self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) if not self.ssh_hook: - raise AirflowException("can not operate without ssh_hook or ssh_conn_id") + raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.") if self.remote_host is not None: + self.log.info("remote_host is provided explicitly. " + + "It will replace the remote_host which was defined " + + "in ssh_hook or predefined in connection of ssh_conn_id.") self.ssh_hook.remote_host = self.remote_host with self.ssh_hook.get_conn() as ssh_client: diff --git a/airflow/contrib/operators/sftp_to_s3_operator.py b/airflow/contrib/operators/sftp_to_s3_operator.py new file mode 100644 index 0000000000000..fcb63388126be --- /dev/null +++ b/airflow/contrib/operators/sftp_to_s3_operator.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.models import BaseOperator +from airflow.hooks.S3_hook import S3Hook +from airflow.contrib.hooks.ssh_hook import SSHHook +from tempfile import NamedTemporaryFile +from urllib.parse import urlparse + + +def get_s3_key(s3_key): + """This parses the correct format for S3 keys + regardless of how the S3 url is passed. """ + + parsed_s3_key = urlparse(s3_key) + return parsed_s3_key.path.lstrip('/') + + +class SFTPToS3Operator(BaseOperator): + """ + S3 To SFTP Operator + :param sftp_conn_id: The sftp connection id. + :type sftp_conn_id: string + :param sftp_path: The sftp remote path. + :type sftp_path: string + :param s3_conn_id: The s3 connnection id. + :type s3_conn_id: string + :param s3_bucket: The targeted s3 bucket. + :type s3_bucket: string + :param s3_key: The targeted s3 key. + :type s3_key: string + """ + + template_fields = ('s3_key', 'sftp_path') + + def __init__(self, + sftp_conn_id=None, + s3_conn_id=None, + s3_bucket=None, + s3_key=None, + sftp_path=None, + *args, + **kwargs): + super(SFTPToS3Operator, self).__init__(*args, **kwargs) + self.sftp_conn_id = sftp_conn_id + self.sftp_path = sftp_path + self.s3_bucket = s3_bucket + self.s3_key = s3_key + self.s3_conn_id = s3_conn_id + + self.ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id) + self.s3_hook = S3Hook(self.s3_conn_id) + + def execute(self, context): + self.s3_key = get_s3_key(self.s3_key) + + ssh_client = self.ssh_hook.get_conn() + sftp_client = ssh_client.open_sftp() + + with NamedTemporaryFile("w") as f: + sftp_client.get(self.sftp_path, f.name) + + self.s3_hook.load_file( + filename=f.name, + key=self.s3_key, + bucket_name=self.s3_bucket, + replace=True + ) diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py index c0e8953d2c344..2bf342935d60c 100644 --- a/airflow/contrib/operators/ssh_operator.py +++ b/airflow/contrib/operators/ssh_operator.py @@ -31,11 +31,15 @@ class SSHOperator(BaseOperator): """ SSHOperator to execute commands on given remote host using the ssh_hook. - :param ssh_hook: predefined ssh_hook to use for remote execution + :param ssh_hook: predefined ssh_hook to use for remote execution. + Either `ssh_hook` or `ssh_conn_id` needs to be provided. :type ssh_hook: :class:`SSHHook` - :param ssh_conn_id: connection id from airflow Connections + :param ssh_conn_id: connection id from airflow Connections. + `ssh_conn_id` will be ingored if `ssh_hook` is provided. :type ssh_conn_id: str :param remote_host: remote host to connect (templated) + Nullable. If provided, it will replace the `remote_host` which was + defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`. :type remote_host: str :param command: command to execute on remote host. (templated) :type command: str @@ -68,14 +72,22 @@ def __init__(self, def execute(self, context): try: - if self.ssh_conn_id and not self.ssh_hook: - self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, - timeout=self.timeout) + if self.ssh_conn_id: + if self.ssh_hook and isinstance(self.ssh_hook, SSHHook): + self.log.info("ssh_conn_id is ignored when ssh_hook is provided.") + else: + self.log.info("ssh_hook is not provided or invalid. " + + "Trying ssh_conn_id to create SSHHook.") + self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, + timeout=self.timeout) if not self.ssh_hook: raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.") if self.remote_host is not None: + self.log.info("remote_host is provided explicitly. " + + "It will replace the remote_host which was defined " + + "in ssh_hook or predefined in connection of ssh_conn_id.") self.ssh_hook.remote_host = self.remote_host if not self.command: diff --git a/airflow/contrib/task_runner/cgroup_task_runner.py b/airflow/contrib/task_runner/cgroup_task_runner.py index 78a240f2db4e2..4662b0fe82f5a 100644 --- a/airflow/contrib/task_runner/cgroup_task_runner.py +++ b/airflow/contrib/task_runner/cgroup_task_runner.py @@ -123,7 +123,7 @@ def start(self): # Create a unique cgroup name cgroup_name = "airflow/{}/{}".format(datetime.datetime.utcnow(). strftime("%Y-%m-%d"), - str(uuid.uuid1())) + str(uuid.uuid4())) self.mem_cgroup_name = "memory/{}".format(cgroup_name) self.cpu_cgroup_name = "cpu/{}".format(cgroup_name) diff --git a/airflow/example_dags/example_kubernetes_operator.py b/airflow/example_dags/example_kubernetes_operator.py index e8d35c4c5bf66..4b3f54bd04ef9 100644 --- a/airflow/example_dags/example_kubernetes_operator.py +++ b/airflow/example_dags/example_kubernetes_operator.py @@ -48,7 +48,8 @@ in_cluster=False, task_id="task", get_logs=True, - dag=dag) + dag=dag, + is_delete_operator_pod=False) except ImportError as e: log.warn("Could not import KubernetesPodOperator: " + str(e)) diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 2128ae7b09228..61bbc667160d0 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -82,7 +82,7 @@ def execute_async(self, key, command, self.log.info("[celery] queuing {key} through celery, " "queue={queue}".format(**locals())) self.tasks[key] = execute_command.apply_async( - args=command, queue=queue) + args=[command], queue=queue) self.last_state[key] = celery_states.PENDING def sync(self): diff --git a/airflow/models.py b/airflow/models.py index 55badf4828541..3f8f6c6736cf2 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -337,7 +337,8 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): return found_dags mods = [] - if not zipfile.is_zipfile(filepath): + is_zipfile = zipfile.is_zipfile(filepath) + if not is_zipfile: if safe_mode and os.path.isfile(filepath): with open(filepath, 'rb') as f: content = f.read() @@ -409,7 +410,7 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): if isinstance(dag, DAG): if not dag.full_filepath: dag.full_filepath = filepath - if dag.fileloc != filepath: + if dag.fileloc != filepath and not is_zipfile: dag.fileloc = filepath try: dag.is_subdag = False @@ -1815,12 +1816,16 @@ def get_template_context(self, session=None): next_execution_date = task.dag.following_schedule(self.execution_date) next_ds = None + next_ds_nodash = None if next_execution_date: next_ds = next_execution_date.strftime('%Y-%m-%d') + next_ds_nodash = next_ds.replace('-', '') prev_ds = None + prev_ds_nodash = None if prev_execution_date: prev_ds = prev_execution_date.strftime('%Y-%m-%d') + prev_ds_nodash = prev_ds.replace('-', '') ds_nodash = ds.replace('-', '') ts_nodash = ts.replace('-', '').replace(':', '') @@ -1887,7 +1892,9 @@ def __repr__(self): 'dag': task.dag, 'ds': ds, 'next_ds': next_ds, + 'next_ds_nodash': next_ds_nodash, 'prev_ds': prev_ds, + 'prev_ds_nodash': prev_ds_nodash, 'ds_nodash': ds_nodash, 'ts': ts, 'ts_nodash': ts_nodash, @@ -2328,14 +2335,17 @@ class derived from this one results in the creation of a task object, :param executor_config: Additional task-level configuration parameters that are interpreted by a specific executor. Parameters are namespaced by the name of executor. - ``example: to run this task in a specific docker container through - the KubernetesExecutor - MyOperator(..., - executor_config={ - "KubernetesExecutor": - {"image": "myCustomDockerImage"} - } - )`` + + **Example**: to run this task in a specific docker container through + the KubernetesExecutor :: + + MyOperator(..., + executor_config={ + "KubernetesExecutor": + {"image": "myCustomDockerImage"} + } + ) + :type executor_config: dict """ @@ -2413,10 +2423,17 @@ def __init__( self.email = email self.email_on_retry = email_on_retry self.email_on_failure = email_on_failure + self.start_date = start_date if start_date and not isinstance(start_date, datetime): self.log.warning("start_date for %s isn't datetime.datetime", self) + elif start_date: + self.start_date = timezone.convert_to_utc(start_date) + self.end_date = end_date + if end_date: + self.end_date = timezone.convert_to_utc(end_date) + if not TriggerRule.is_valid(trigger_rule): raise AirflowException( "The trigger_rule must be one of {all_triggers}," @@ -4840,6 +4857,8 @@ def get_state(self): def set_state(self, state): if self._state != state: self._state = state + self.end_date = timezone.utcnow() if self._state in State.finished() else None + if self.dag_id is not None: # FIXME: Due to the scoped_session factor we we don't get a clean # session here, so something really weird goes on: @@ -5063,7 +5082,7 @@ def update_state(self, session=None): if (not unfinished_tasks and any(r.state in (State.FAILED, State.UPSTREAM_FAILED) for r in roots)): self.log.info('Marking run %s failed', self) - self.state = State.FAILED + self.set_state(State.FAILED) dag.handle_callback(self, success=False, reason='task_failure', session=session) @@ -5071,20 +5090,20 @@ def update_state(self, session=None): elif not unfinished_tasks and all(r.state in (State.SUCCESS, State.SKIPPED) for r in roots): self.log.info('Marking run %s successful', self) - self.state = State.SUCCESS + self.set_state(State.SUCCESS) dag.handle_callback(self, success=True, reason='success', session=session) # if *all tasks* are deadlocked, the run failed elif (unfinished_tasks and none_depends_on_past and none_task_concurrency and no_dependencies_met): self.log.info('Deadlock; marking run %s failed', self) - self.state = State.FAILED + self.set_state(State.FAILED) dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session) # finally, if the roots aren't done, the dag is still running else: - self.state = State.RUNNING + self.set_state(State.RUNNING) # todo: determine we want to use with_for_update to make sure to lock the run session.merge(self) diff --git a/airflow/operators/__init__.py b/airflow/operators/__init__.py index 4e2060baaec1c..f732d8f8a85b1 100644 --- a/airflow/operators/__init__.py +++ b/airflow/operators/__init__.py @@ -54,21 +54,6 @@ 'PrestoValueCheckOperator', 'PrestoIntervalCheckOperator', ], - 'sensors': [ - 'BaseSensorOperator', - 'ExternalTaskSensor', - 'HdfsSensor', - 'HivePartitionSensor', - 'HttpSensor', - 'MetastorePartitionSensor', - 'NamedHivePartitionSensor', - 'S3KeySensor', - 'S3PrefixSensor', - 'SqlSensor', - 'TimeDeltaSensor', - 'TimeSensor', - 'WebHdfsSensor', - ], 'dagrun_operator': ['TriggerDagRunOperator'], 'dummy_operator': ['DummyOperator'], 'email_operator': ['EmailOperator'], diff --git a/airflow/operators/redshift_to_s3_operator.py b/airflow/operators/redshift_to_s3_operator.py index 9c1b621dae965..e6682c78df3f7 100644 --- a/airflow/operators/redshift_to_s3_operator.py +++ b/airflow/operators/redshift_to_s3_operator.py @@ -39,6 +39,16 @@ class RedshiftToS3Transfer(BaseOperator): :type redshift_conn_id: string :param aws_conn_id: reference to a specific S3 connection :type aws_conn_id: string + :parame verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + - False: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - path/to/cert/bundle.pem: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str :param unload_options: reference to a list of UNLOAD options :type unload_options: list """ @@ -56,6 +66,7 @@ def __init__( s3_key, redshift_conn_id='redshift_default', aws_conn_id='aws_default', + verify=None, unload_options=tuple(), autocommit=False, parameters=None, @@ -68,6 +79,7 @@ def __init__( self.s3_key = s3_key self.redshift_conn_id = redshift_conn_id self.aws_conn_id = aws_conn_id + self.verify = verify self.unload_options = unload_options self.autocommit = autocommit self.parameters = parameters @@ -79,7 +91,7 @@ def __init__( def execute(self, context): self.hook = PostgresHook(postgres_conn_id=self.redshift_conn_id) - self.s3 = S3Hook(aws_conn_id=self.aws_conn_id) + self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) credentials = self.s3.get_credentials() unload_options = '\n\t\t\t'.join(self.unload_options) diff --git a/airflow/operators/s3_file_transform_operator.py b/airflow/operators/s3_file_transform_operator.py index 84a6eda0c8101..da82fa952cc1f 100644 --- a/airflow/operators/s3_file_transform_operator.py +++ b/airflow/operators/s3_file_transform_operator.py @@ -47,6 +47,18 @@ class S3FileTransformOperator(BaseOperator): :type source_s3_key: str :param source_aws_conn_id: source s3 connection :type source_aws_conn_id: str + :param source_verify: Whether or not to verify SSL certificates for S3 connetion. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + This is also applicable to ``dest_verify``. + :type source_verify: bool or str :param dest_s3_key: The key to be written from S3. (templated) :type dest_s3_key: str :param dest_aws_conn_id: destination s3 connection @@ -71,14 +83,18 @@ def __init__( transform_script=None, select_expression=None, source_aws_conn_id='aws_default', + source_verify=None, dest_aws_conn_id='aws_default', + dest_verify=None, replace=False, *args, **kwargs): super(S3FileTransformOperator, self).__init__(*args, **kwargs) self.source_s3_key = source_s3_key self.source_aws_conn_id = source_aws_conn_id + self.source_verify = source_verify self.dest_s3_key = dest_s3_key self.dest_aws_conn_id = dest_aws_conn_id + self.dest_verify = dest_verify self.replace = replace self.transform_script = transform_script self.select_expression = select_expression @@ -88,8 +104,10 @@ def execute(self, context): raise AirflowException( "Either transform_script or select_expression must be specified") - source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id) - dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id) + source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id, + verify=self.source_verify) + dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id, + verify=self.dest_verify) self.log.info("Downloading source S3 file %s", self.source_s3_key) if not source_s3.check_for_key(self.source_s3_key): diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py index b82ebce6fa295..85f05325f65be 100644 --- a/airflow/operators/s3_to_hive_operator.py +++ b/airflow/operators/s3_to_hive_operator.py @@ -78,6 +78,16 @@ class S3ToHiveTransfer(BaseOperator): :type delimiter: str :param aws_conn_id: source s3 connection :type aws_conn_id: str + :parame verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + - False: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - path/to/cert/bundle.pem: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str :param hive_cli_conn_id: destination hive connection :type hive_cli_conn_id: str :param input_compressed: Boolean to determine if file decompression is @@ -107,6 +117,7 @@ def __init__( check_headers=False, wildcard_match=False, aws_conn_id='aws_default', + verify=None, hive_cli_conn_id='hive_cli_default', input_compressed=False, tblproperties=None, @@ -125,6 +136,7 @@ def __init__( self.wildcard_match = wildcard_match self.hive_cli_conn_id = hive_cli_conn_id self.aws_conn_id = aws_conn_id + self.verify = verify self.input_compressed = input_compressed self.tblproperties = tblproperties self.select_expression = select_expression @@ -136,7 +148,7 @@ def __init__( def execute(self, context): # Downloading file from S3 - self.s3 = S3Hook(aws_conn_id=self.aws_conn_id) + self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) self.log.info("Downloading S3 file") diff --git a/airflow/operators/s3_to_redshift_operator.py b/airflow/operators/s3_to_redshift_operator.py index 0d7921e9ed0f1..8c83f4437267f 100644 --- a/airflow/operators/s3_to_redshift_operator.py +++ b/airflow/operators/s3_to_redshift_operator.py @@ -39,6 +39,16 @@ class S3ToRedshiftTransfer(BaseOperator): :type redshift_conn_id: string :param aws_conn_id: reference to a specific S3 connection :type aws_conn_id: string + :parame verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + - False: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - path/to/cert/bundle.pem: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str :param copy_options: reference to a list of COPY options :type copy_options: list """ @@ -56,6 +66,7 @@ def __init__( s3_key, redshift_conn_id='redshift_default', aws_conn_id='aws_default', + verify=None, copy_options=tuple(), autocommit=False, parameters=None, @@ -67,13 +78,14 @@ def __init__( self.s3_key = s3_key self.redshift_conn_id = redshift_conn_id self.aws_conn_id = aws_conn_id + self.verify = verify self.copy_options = copy_options self.autocommit = autocommit self.parameters = parameters def execute(self, context): self.hook = PostgresHook(postgres_conn_id=self.redshift_conn_id) - self.s3 = S3Hook(aws_conn_id=self.aws_conn_id) + self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) credentials = self.s3.get_credentials() copy_options = '\n\t\t\t'.join(self.copy_options) diff --git a/airflow/sensors/http_sensor.py b/airflow/sensors/http_sensor.py index 33f8531368ec1..f665737e96225 100644 --- a/airflow/sensors/http_sensor.py +++ b/airflow/sensors/http_sensor.py @@ -28,8 +28,11 @@ class HttpSensor(BaseSensorOperator): """ - Executes a HTTP get statement and returns False on failure: - 404 not found or response_check function returned False + Executes a HTTP GET statement and returns False on failure caused by + 404 Not Found or `response_check` returning False. + + HTTP Error codes other than 404 (like 403) or Connection Refused Error + would fail the sensor itself directly (no more poking). :param http_conn_id: The connection to run the sensor against :type http_conn_id: string diff --git a/airflow/sensors/s3_key_sensor.py b/airflow/sensors/s3_key_sensor.py index 246b4c3e73641..a743ba2e0cdd3 100644 --- a/airflow/sensors/s3_key_sensor.py +++ b/airflow/sensors/s3_key_sensor.py @@ -43,6 +43,16 @@ class S3KeySensor(BaseSensorOperator): :type wildcard_match: bool :param aws_conn_id: a reference to the s3 connection :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + - False: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - path/to/cert/bundle.pem: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str """ template_fields = ('bucket_key', 'bucket_name') @@ -52,6 +62,7 @@ def __init__(self, bucket_name=None, wildcard_match=False, aws_conn_id='aws_default', + verify=None, *args, **kwargs): super(S3KeySensor, self).__init__(*args, **kwargs) @@ -76,10 +87,11 @@ def __init__(self, self.bucket_key = bucket_key self.wildcard_match = wildcard_match self.aws_conn_id = aws_conn_id + self.verify = verify def poke(self, context): from airflow.hooks.S3_hook import S3Hook - hook = S3Hook(aws_conn_id=self.aws_conn_id) + hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) full_url = "s3://" + self.bucket_name + "/" + self.bucket_key self.log.info('Poking for key : {full_url}'.format(**locals())) if self.wildcard_match: diff --git a/airflow/sensors/s3_prefix_sensor.py b/airflow/sensors/s3_prefix_sensor.py index 917dd46e26c28..4617c97cf360e 100644 --- a/airflow/sensors/s3_prefix_sensor.py +++ b/airflow/sensors/s3_prefix_sensor.py @@ -38,6 +38,18 @@ class S3PrefixSensor(BaseSensorOperator): :param delimiter: The delimiter intended to show hierarchy. Defaults to '/'. :type delimiter: str + :param aws_conn_id: a reference to the s3 connection + :type aws_conn_id: str + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + - False: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - path/to/cert/bundle.pem: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :type verify: bool or str """ template_fields = ('prefix', 'bucket_name') @@ -47,6 +59,7 @@ def __init__(self, prefix, delimiter='/', aws_conn_id='aws_default', + verify=None, *args, **kwargs): super(S3PrefixSensor, self).__init__(*args, **kwargs) @@ -56,12 +69,13 @@ def __init__(self, self.delimiter = delimiter self.full_url = "s3://" + bucket_name + '/' + prefix self.aws_conn_id = aws_conn_id + self.verify = verify def poke(self, context): self.log.info('Poking for prefix : {self.prefix}\n' 'in bucket s3://{self.bucket_name}'.format(**locals())) from airflow.hooks.S3_hook import S3Hook - hook = S3Hook(aws_conn_id=self.aws_conn_id) + hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) return hook.check_for_prefix( prefix=self.prefix, delimiter=self.delimiter, diff --git a/airflow/utils/state.py b/airflow/utils/state.py index e95f58c3505dd..a351df07b9654 100644 --- a/airflow/utils/state.py +++ b/airflow/utils/state.py @@ -101,7 +101,6 @@ def finished(cls): """ return [ cls.SUCCESS, - cls.SHUTDOWN, cls.FAILED, cls.SKIPPED, ] @@ -117,5 +116,6 @@ def unfinished(cls): cls.SCHEDULED, cls.QUEUED, cls.RUNNING, + cls.SHUTDOWN, cls.UP_FOR_RETRY ] diff --git a/airflow/www/static/main.css b/airflow/www/static/main.css index 57164b94e5ccf..147695c4a9591 100644 --- a/airflow/www/static/main.css +++ b/airflow/www/static/main.css @@ -262,3 +262,4 @@ div.square { .sc { color: #BA2121 } /* Literal.String.Char */ .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */ .s2 { color: #BA2121 } /* Literal.String.Double */ +.s1 { color: #BA2121 } /* Literal.String.Single */ diff --git a/airflow/www/utils.py b/airflow/www/utils.py index 9ce114d5eda3e..e85bc5909ae27 100644 --- a/airflow/www/utils.py +++ b/airflow/www/utils.py @@ -20,17 +20,21 @@ # flake8: noqa: E402 import inspect from future import standard_library -standard_library.install_aliases() +standard_library.install_aliases() # noqa: E402 from builtins import str, object from cgi import escape from io import BytesIO as IO import functools import gzip +import io import json +import os +import re import time import wtforms from wtforms.compat import text_type +import zipfile from flask import after_this_request, request, Response from flask_admin.model import filters @@ -372,6 +376,22 @@ def zipper(response): return view_func +def open_maybe_zipped(f, mode='r'): + """ + Opens the given file. If the path contains a folder with a .zip suffix, then + the folder is treated as a zip archive, opening the file inside the archive. + + :return: a file object, as in `open`, or as in `ZipFile.open`. + """ + + _, archive, filename = re.search( + r'((.*\.zip){})?(.*)'.format(re.escape(os.sep)), f).groups() + if archive and zipfile.is_zipfile(archive): + return zipfile.ZipFile(archive, mode=mode).open(filename) + else: + return io.open(f, mode=mode) + + def make_cache_key(*args, **kwargs): """ Used by cache to get a unique key per URL diff --git a/airflow/www/views.py b/airflow/www/views.py index e1a7caa8bb78c..aa2530e45827a 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -661,7 +661,7 @@ def code(self): dag = dagbag.get_dag(dag_id) title = dag_id try: - with open(dag.fileloc, 'r') as f: + with wwwutils.open_maybe_zipped(dag.fileloc, 'r') as f: code = f.read() html_code = highlight( code, lexers.PythonLexer(), HtmlFormatter(linenos=True)) diff --git a/airflow/www_rbac/static/css/main.css b/airflow/www_rbac/static/css/main.css index ac6189938cb3b..d3d198356e8db 100644 --- a/airflow/www_rbac/static/css/main.css +++ b/airflow/www_rbac/static/css/main.css @@ -265,3 +265,4 @@ div.square { .sc { color: #BA2121 } /* Literal.String.Char */ .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */ .s2 { color: #BA2121 } /* Literal.String.Double */ +.s1 { color: #BA2121 } /* Literal.String.Single */ diff --git a/airflow/www_rbac/static/js/clock.js b/airflow/www_rbac/static/js/base.js similarity index 84% rename from airflow/www_rbac/static/js/clock.js rename to airflow/www_rbac/static/js/base.js index afde7c2545f37..ea065792a533a 100644 --- a/airflow/www_rbac/static/js/clock.js +++ b/airflow/www_rbac/static/js/base.js @@ -33,4 +33,11 @@ function displayTime() { $(document).ready(function () { displayTime(); $('span').tooltip(); + $.ajaxSetup({ + beforeSend: function(xhr, settings) { + if (!/^(GET|HEAD|OPTIONS|TRACE)$/i.test(settings.type) && !this.crossDomain) { + xhr.setRequestHeader("X-CSRFToken", csrfToken); + } + } + }); }); diff --git a/airflow/www_rbac/templates/appbuilder/baselayout.html b/airflow/www_rbac/templates/appbuilder/baselayout.html index 1653a909d02b2..89e32cae783ec 100644 --- a/airflow/www_rbac/templates/appbuilder/baselayout.html +++ b/airflow/www_rbac/templates/appbuilder/baselayout.html @@ -67,9 +67,9 @@ {% block tail_js %} {{ super() }} - + {% endblock %} diff --git a/airflow/www_rbac/utils.py b/airflow/www_rbac/utils.py index a0e9258eae30e..0176a5312c373 100644 --- a/airflow/www_rbac/utils.py +++ b/airflow/www_rbac/utils.py @@ -26,6 +26,10 @@ import wtforms import bleach import markdown +import re +import zipfile +import os +import io from builtins import str from past.builtins import basestring @@ -202,6 +206,22 @@ def json_response(obj): mimetype="application/json") +def open_maybe_zipped(f, mode='r'): + """ + Opens the given file. If the path contains a folder with a .zip suffix, then + the folder is treated as a zip archive, opening the file inside the archive. + + :return: a file object, as in `open`, or as in `ZipFile.open`. + """ + + _, archive, filename = re.search( + r'((.*\.zip){})?(.*)'.format(re.escape(os.sep)), f).groups() + if archive and zipfile.is_zipfile(archive): + return zipfile.ZipFile(archive, mode=mode).open(filename) + else: + return io.open(f, mode=mode) + + def make_cache_key(*args, **kwargs): """ Used by cache to get a unique key per URL diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py index d011724cc6811..3dc3400968812 100644 --- a/airflow/www_rbac/views.py +++ b/airflow/www_rbac/views.py @@ -400,7 +400,7 @@ def code(self): dag = dagbag.get_dag(dag_id) title = dag_id try: - with open(dag.fileloc, 'r') as f: + with wwwutils.open_maybe_zipped(dag.fileloc, 'r') as f: code = f.read() html_code = highlight( code, lexers.PythonLexer(), HtmlFormatter(linenos=True)) diff --git a/airflow/www_rbac/webpack.config.js b/airflow/www_rbac/webpack.config.js index 29b99f408f20e..16f394cc616b9 100644 --- a/airflow/www_rbac/webpack.config.js +++ b/airflow/www_rbac/webpack.config.js @@ -35,7 +35,7 @@ const BUILD_DIR = path.resolve(__dirname, './static/dist'); const config = { entry: { connectionForm: `${STATIC_DIR}/js/connection_form.js`, - clock: `${STATIC_DIR}/js/clock.js`, + base: `${STATIC_DIR}/js/base.js`, graph: `${STATIC_DIR}/js/graph.js`, ganttChartD3v2: `${STATIC_DIR}/js/gantt-chart-d3v2.js`, main: `${STATIC_DIR}/css/main.css`, diff --git a/docs/code.rst b/docs/code.rst index 80ec76193fcf3..c9e9b3d431452 100644 --- a/docs/code.rst +++ b/docs/code.rst @@ -242,12 +242,14 @@ Variable Description ================================= ==================================== ``{{ ds }}`` the execution date as ``YYYY-MM-DD`` ``{{ ds_nodash }}`` the execution date as ``YYYYMMDD`` -``{{ prev_ds }}`` the previous execution date as ``YYYY-MM-DD``. +``{{ prev_ds }}`` the previous execution date as ``YYYY-MM-DD`` if ``{{ ds }}`` is ``2016-01-08`` and ``schedule_interval`` is ``@weekly``, - ``{{ prev_ds }}`` will be ``2016-01-01``. -``{{ next_ds }}`` the next execution date as ``YYYY-MM-DD``. + ``{{ prev_ds }}`` will be ``2016-01-01`` +``{{ prev_ds_nodash }}`` the previous execution date as ``YYYYMMDD`` if exists, else ``None` +``{{ next_ds }}`` the next execution date as ``YYYY-MM-DD`` if ``{{ ds }}`` is ``2016-01-01`` and ``schedule_interval`` is ``@weekly``, - ``{{ prev_ds }}`` will be ``2016-01-08``. + ``{{ prev_ds }}`` will be ``2016-01-08`` +``{{ next_ds_nodash }}`` the next execution date as ``YYYYMMDD`` if exists, else ``None` ``{{ yesterday_ds }}`` yesterday's date as ``YYYY-MM-DD`` ``{{ yesterday_ds_nodash }}`` yesterday's date as ``YYYYMMDD`` ``{{ tomorrow_ds }}`` tomorrow's date as ``YYYY-MM-DD`` diff --git a/docs/faq.rst b/docs/faq.rst index 46212084c58fd..61c1ba9ce1c96 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -141,15 +141,15 @@ What are all the ``airflow run`` commands in my process list? There are many layers of ``airflow run`` commands, meaning it can call itself. - Basic ``airflow run``: fires up an executor, and tell it to run an - ``airflow run --local`` command. if using Celery, this means it puts a - command in the queue for it to run remote, on the worker. If using + ``airflow run --local`` command. If using Celery, this means it puts a + command in the queue for it to run remotely on the worker. If using LocalExecutor, that translates into running it in a subprocess pool. - Local ``airflow run --local``: starts an ``airflow run --raw`` command (described below) as a subprocess and is in charge of emitting heartbeats, listening for external kill signals - and ensures some cleanup takes place if the subprocess fails + and ensures some cleanup takes place if the subprocess fails. - Raw ``airflow run --raw`` runs the actual operator's execute method and - performs the actual work + performs the actual work. How can my airflow dag run faster? diff --git a/docs/kubernetes.rst b/docs/kubernetes.rst index a4916858fc15e..4f6eeb14b2298 100644 --- a/docs/kubernetes.rst +++ b/docs/kubernetes.rst @@ -91,7 +91,9 @@ Kubernetes Operator volume_mounts=[volume_mount] name="test", task_id="task", - affinity=affinity + affinity=affinity, + is_delete_operator_pod=True, + hostnetwork=False ) diff --git a/docs/timezone.rst b/docs/timezone.rst index 9e8598e2ed464..fe44ecfbb9f62 100644 --- a/docs/timezone.rst +++ b/docs/timezone.rst @@ -2,23 +2,23 @@ Time zones ========== Support for time zones is enabled by default. Airflow stores datetime information in UTC internally and in the database. -It allows you to run your DAGs with time zone dependent schedules. At the moment Airflow does not convert them to the -end user’s time zone in the user interface. There it will always be displayed in UTC. Also templates used in Operators +It allows you to run your DAGs with time zone dependent schedules. At the moment Airflow does not convert them to the +end user’s time zone in the user interface. There it will always be displayed in UTC. Also templates used in Operators are not converted. Time zone information is exposed and it is up to the writer of DAG what do with it. -This is handy if your users live in more than one time zone and you want to display datetime information according to +This is handy if your users live in more than one time zone and you want to display datetime information according to each user’s wall clock. -Even if you are running Airflow in only one time zone it is still good practice to store data in UTC in your database -(also before Airflow became time zone aware this was also to recommended or even required setup). The main reason is -Daylight Saving Time (DST). Many countries have a system of DST, where clocks are moved forward in spring and backward -in autumn. If you’re working in local time, you’re likely to encounter errors twice a year, when the transitions -happen. (The pendulum and pytz documentation discusses these issues in greater detail.) This probably doesn’t matter -for a simple DAG, but it’s a problem if you are in, for example, financial services where you have end of day -deadlines to meet. +Even if you are running Airflow in only one time zone it is still good practice to store data in UTC in your database +(also before Airflow became time zone aware this was also to recommended or even required setup). The main reason is +Daylight Saving Time (DST). Many countries have a system of DST, where clocks are moved forward in spring and backward +in autumn. If you’re working in local time, you’re likely to encounter errors twice a year, when the transitions +happen. (The pendulum and pytz documentation discusses these issues in greater detail.) This probably doesn’t matter +for a simple DAG, but it’s a problem if you are in, for example, financial services where you have end of day +deadlines to meet. -The time zone is set in `airflow.cfg`. By default it is set to utc, but you change it to use the system’s settings or -an arbitrary IANA time zone, e.g. `Europe/Amsterdam`. It is dependent on `pendulum`, which is more accurate than `pytz`. +The time zone is set in `airflow.cfg`. By default it is set to utc, but you change it to use the system’s settings or +an arbitrary IANA time zone, e.g. `Europe/Amsterdam`. It is dependent on `pendulum`, which is more accurate than `pytz`. Pendulum is installed when you install Airflow. Please note that the Web UI currently only runs in UTC. @@ -28,8 +28,8 @@ Concepts Naïve and aware datetime objects '''''''''''''''''''''''''''''''' -Python’s datetime.datetime objects have a tzinfo attribute that can be used to store time zone information, -represented as an instance of a subclass of datetime.tzinfo. When this attribute is set and describes an offset, +Python’s datetime.datetime objects have a tzinfo attribute that can be used to store time zone information, +represented as an instance of a subclass of datetime.tzinfo. When this attribute is set and describes an offset, a datetime object is aware. Otherwise, it’s naive. You can use timezone.is_aware() and timezone.is_naive() to determine whether datetimes are aware or naive. @@ -39,7 +39,7 @@ Because Airflow uses time-zone-aware datetime objects. If your code creates date .. code:: python from airflow.utils import timezone - + now = timezone.utcnow() a_date = timezone.datetime(2017,1,1) @@ -49,9 +49,9 @@ Interpretation of naive datetime objects Although Airflow operates fully time zone aware, it still accepts naive date time objects for `start_dates` and `end_dates` in your DAG definitions. This is mostly in order to preserve backwards compatibility. In -case a naive `start_date` or `end_date` is encountered the default time zone is applied. It is applied +case a naive `start_date` or `end_date` is encountered the default time zone is applied. It is applied in such a way that it is assumed that the naive date time is already in the default time zone. In other -words if you have a default time zone setting of `Europe/Amsterdam` and create a naive datetime `start_date` of +words if you have a default time zone setting of `Europe/Amsterdam` and create a naive datetime `start_date` of `datetime(2017,1,1)` it is assumed to be a `start_date` of Jan 1, 2017 Amsterdam time. .. code:: python @@ -65,16 +65,16 @@ words if you have a default time zone setting of `Europe/Amsterdam` and create a op = DummyOperator(task_id='dummy', dag=dag) print(op.owner) # Airflow -Unfortunately, during DST transitions, some datetimes don’t exist or are ambiguous. -In such situations, pendulum raises an exception. That’s why you should always create aware +Unfortunately, during DST transitions, some datetimes don’t exist or are ambiguous. +In such situations, pendulum raises an exception. That’s why you should always create aware datetime objects when time zone support is enabled. -In practice, this is rarely an issue. Airflow gives you aware datetime objects in the models and DAGs, and most often, -new datetime objects are created from existing ones through timedelta arithmetic. The only datetime that’s often +In practice, this is rarely an issue. Airflow gives you aware datetime objects in the models and DAGs, and most often, +new datetime objects are created from existing ones through timedelta arithmetic. The only datetime that’s often created in application code is the current time, and timezone.utcnow() automatically does the right thing. -Default time zone +Default time zone ''''''''''''''''' The default time zone is the time zone defined by the `default_timezone` setting under `[core]`. If @@ -92,15 +92,15 @@ it is therefore important to make sure this setting is equal on all Airflow node Time zone aware DAGs -------------------- -Creating a time zone aware DAG is quite simple. Just make sure to supply a time zone aware `start_date`. It is +Creating a time zone aware DAG is quite simple. Just make sure to supply a time zone aware `start_date`. It is recommended to use `pendulum` for this, but `pytz` (to be installed manually) can also be used for this. .. code:: python import pendulum - + local_tz = pendulum.timezone("Europe/Amsterdam") - + default_args=dict( start_date=datetime(2016, 1, 1, tzinfo=local_tz), owner='Airflow' @@ -110,18 +110,21 @@ recommended to use `pendulum` for this, but `pytz` (to be installed manually) ca op = DummyOperator(task_id='dummy', dag=dag) print(dag.timezone) # - +Please note that while it is possible to set a `start_date` and `end_date` for Tasks always the DAG timezone +or global timezone (in that order) will be used to calculate the next execution date. Upon first encounter +the start date or end date will be converted to UTC using the timezone associated with start_date or end_date, +then for calculations this timezone information will be disregarded. Templates ''''''''' -Airflow returns time zone aware datetimes in templates, but does not convert them to local time so they remain in UTC. +Airflow returns time zone aware datetimes in templates, but does not convert them to local time so they remain in UTC. It is left up to the DAG to handle this. .. code:: python import pendulum - + local_tz = pendulum.timezone("Europe/Amsterdam") local_tz.convert(execution_date) @@ -129,10 +132,10 @@ It is left up to the DAG to handle this. Cron schedules '''''''''''''' -In case you set a cron schedule, Airflow assumes you will always want to run at the exact same time. It will -then ignore day light savings time. Thus, if you have a schedule that says -run at end of interval every day at 08:00 GMT+1 it will always run end of interval 08:00 GMT+1, -regardless if day light savings time is in place. +In case you set a cron schedule, Airflow assumes you will always want to run at the exact same time. It will +then ignore day light savings time. Thus, if you have a schedule that says +run at end of interval every day at 08:00 GMT+1 it will always run end of interval 08:00 GMT+1, +regardless if day light savings time is in place. Time deltas diff --git a/scripts/ci/docker-compose.yml b/scripts/ci/docker-compose.yml index 861cf9e8b89ce..32e0a536c1597 100644 --- a/scripts/ci/docker-compose.yml +++ b/scripts/ci/docker-compose.yml @@ -70,11 +70,17 @@ services: - SLUGIFY_USES_TEXT_UNIDECODE=yes - TOX_ENV - PYTHON_VERSION + - CI - TRAVIS - TRAVIS_BRANCH - TRAVIS_BUILD_DIR - TRAVIS_JOB_ID + - TRAVIS_JOB_NUMBER - TRAVIS_PULL_REQUEST + - TRAVIS_COMMIT + - TRAVIS_REPO_SLUG + - TRAVIS_OS_NAME + - TRAVIS_TAG depends_on: - postgres - mysql diff --git a/setup.cfg b/setup.cfg index 622cc1303a173..881fe0107d9b2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + [metadata] name = Airflow summary = Airflow is a system to programmatically author, schedule and monitor data pipelines. @@ -34,4 +35,3 @@ all_files = 1 upload-dir = docs/_build/html [easy_install] - diff --git a/setup.py b/setup.py index b92d267aaa6a3..aecc21817025c 100644 --- a/setup.py +++ b/setup.py @@ -181,7 +181,7 @@ def write_version(filename=os.path.join(*['airflow', 'elasticsearch>=5.0.0,<6.0.0', 'elasticsearch-dsl>=5.0.0,<6.0.0' ] -emr = ['boto3>=1.0.0'] +emr = ['boto3>=1.0.0, <1.8.0'] gcp_api = [ 'httplib2>=0.9.2', 'google-api-python-client>=1.6.0, <2.0.0dev', @@ -219,7 +219,7 @@ def write_version(filename=os.path.join(*['airflow', qds = ['qds-sdk>=1.9.6'] rabbitmq = ['librabbitmq>=1.6.1'] redis = ['redis>=2.10.5'] -s3 = ['boto3>=1.7.0'] +s3 = ['boto3>=1.7.0, <1.8.0'] salesforce = ['simple-salesforce>=0.72'] samba = ['pysmbclient>=0.1.3'] segment = ['analytics-python>=1.2.9'] @@ -322,7 +322,7 @@ def do_setup(): 'requests>=2.5.1, <3', 'setproctitle>=1.1.8, <2', 'sqlalchemy>=1.1.15, <1.2.0', - 'tabulate>=0.7.5, <0.8.0', + 'tabulate>=0.7.5, <=0.8.2', 'tenacity==4.8.0', 'thrift>=0.9.2', 'tzlocal>=1.4', diff --git a/tests/contrib/hooks/test_bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py index d7e9491b8b8d4..69a103bf0689f 100644 --- a/tests/contrib/hooks/test_bigquery_hook.py +++ b/tests/contrib/hooks/test_bigquery_hook.py @@ -25,7 +25,8 @@ import mock from airflow.contrib.hooks import bigquery_hook as hook -from airflow.contrib.hooks.bigquery_hook import _cleanse_time_partitioning +from airflow.contrib.hooks.bigquery_hook import _cleanse_time_partitioning, \ + _validate_value, _api_resource_configs_duplication_check bq_available = True @@ -206,6 +207,16 @@ def mock_job_cancel(projectId, jobId): class TestBigQueryBaseCursor(unittest.TestCase): + def test_bql_deprecation_warning(self): + with warnings.catch_warnings(record=True) as w: + hook.BigQueryBaseCursor("test", "test").run_query( + bql='select * from test_table' + ) + yield + self.assertIn( + 'Deprecated parameter `bql`', + w[0].message.args[0]) + def test_invalid_schema_update_options(self): with self.assertRaises(Exception) as context: hook.BigQueryBaseCursor("test", "test").run_load( @@ -216,16 +227,6 @@ def test_invalid_schema_update_options(self): ) self.assertIn("THIS IS NOT VALID", str(context.exception)) - @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') - def test_bql_deprecation_warning(self, mock_rwc): - with warnings.catch_warnings(record=True) as w: - hook.BigQueryBaseCursor("test", "test").run_query( - bql='select * from test_table' - ) - self.assertIn( - 'Deprecated parameter `bql`', - w[0].message.args[0]) - def test_nobql_nosql_param_error(self): with self.assertRaises(TypeError) as context: hook.BigQueryBaseCursor("test", "test").run_query( @@ -281,6 +282,39 @@ def test_run_query_sql_dialect_override(self, run_with_config): args, kwargs = run_with_config.call_args self.assertIs(args[0]['query']['useLegacySql'], bool_val) + @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') + def test_api_resource_configs(self, run_with_config): + for bool_val in [True, False]: + cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id") + cursor.run_query('query', + api_resource_configs={ + 'query': {'useQueryCache': bool_val}}) + args, kwargs = run_with_config.call_args + self.assertIs(args[0]['query']['useQueryCache'], bool_val) + self.assertIs(args[0]['query']['useLegacySql'], True) + + @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') + def test_api_resource_configs_duplication_warning(self, run_with_config): + with self.assertRaises(ValueError): + cursor = hook.BigQueryBaseCursor(mock.Mock(), "project_id") + cursor.run_query('query', + use_legacy_sql=True, + api_resource_configs={ + 'query': {'useLegacySql': False}}) + + def test_validate_value(self): + with self.assertRaises(TypeError): + _validate_value("case_1", "a", dict) + self.assertIsNone(_validate_value("case_2", 0, int)) + + def test_duplication_check(self): + with self.assertRaises(ValueError): + key_one = True + _api_resource_configs_duplication_check( + "key_one", key_one, {"key_one": False}) + self.assertIsNone(_api_resource_configs_duplication_check( + "key_one", key_one, {"key_one": True})) + class TestLabelsInRunJob(unittest.TestCase): @mock.patch.object(hook.BigQueryBaseCursor, 'run_with_configuration') diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py index aca8dd96004b4..04a7c8dc3c879 100644 --- a/tests/contrib/hooks/test_databricks_hook.py +++ b/tests/contrib/hooks/test_databricks_hook.py @@ -18,15 +18,21 @@ # under the License. # +import itertools import json import unittest +from requests import exceptions as requests_exceptions + from airflow import __version__ -from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT, _TokenAuth +from airflow.contrib.hooks.databricks_hook import ( + DatabricksHook, + RunState, + SUBMIT_RUN_ENDPOINT +) from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.utils import db -from requests import exceptions as requests_exceptions try: from unittest import mock @@ -46,6 +52,7 @@ 'node_type_id': 'r3.xlarge', 'num_workers': 1 } +CLUSTER_ID = 'cluster_id' RUN_ID = 1 HOST = 'xx.cloud.databricks.com' HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com' @@ -79,12 +86,68 @@ def get_run_endpoint(host): """ return 'https://{}/api/2.0/jobs/runs/get'.format(host) + def cancel_run_endpoint(host): """ Utility function to generate the get run endpoint given the host. """ return 'https://{}/api/2.0/jobs/runs/cancel'.format(host) + +def start_cluster_endpoint(host): + """ + Utility function to generate the get run endpoint given the host. + """ + return 'https://{}/api/2.0/clusters/start'.format(host) + + +def restart_cluster_endpoint(host): + """ + Utility function to generate the get run endpoint given the host. + """ + return 'https://{}/api/2.0/clusters/restart'.format(host) + + +def terminate_cluster_endpoint(host): + """ + Utility function to generate the get run endpoint given the host. + """ + return 'https://{}/api/2.0/clusters/delete'.format(host) + +def create_valid_response_mock(content): + response = mock.MagicMock() + response.json.return_value = content + return response + + +def create_post_side_effect(exception, status_code=500): + if exception != requests_exceptions.HTTPError: + return exception() + else: + response = mock.MagicMock() + response.status_code = status_code + response.raise_for_status.side_effect = exception(response=response) + return response + + +def setup_mock_requests( + mock_requests, + exception, + status_code=500, + error_count=None, + response_content=None): + + side_effect = create_post_side_effect(exception, status_code) + + if error_count is None: + # POST requests will fail indefinitely + mock_requests.post.side_effect = itertools.repeat(side_effect) + else: + # POST requests will fail 'error_count' times, and then they will succeed (once) + mock_requests.post.side_effect = \ + [side_effect] * error_count + [create_valid_response_mock(response_content)] + + class DatabricksHookTest(unittest.TestCase): """ Tests for DatabricksHook. @@ -99,7 +162,7 @@ def setUp(self, session=None): conn.password = PASSWORD session.commit() - self.hook = DatabricksHook() + self.hook = DatabricksHook(retry_delay=0) def test_parse_host_with_proper_host(self): host = self.hook._parse_host(HOST) @@ -111,34 +174,85 @@ def test_parse_host_with_scheme(self): def test_init_bad_retry_limit(self): with self.assertRaises(ValueError): - DatabricksHook(retry_limit = 0) - - @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - def test_do_api_call_with_error_retry(self, mock_requests): - for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]: - with mock.patch.object(self.hook.log, 'error') as mock_errors: - mock_requests.reset_mock() - mock_requests.post.side_effect = exception() + DatabricksHook(retry_limit=0) + + def test_do_api_call_retries_with_retryable_error(self): + for exception in [ + requests_exceptions.ConnectionError, + requests_exceptions.SSLError, + requests_exceptions.Timeout, + requests_exceptions.ConnectTimeout, + requests_exceptions.HTTPError]: + with mock.patch( + 'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ + mock.patch.object(self.hook.log, 'error') as mock_errors: + setup_mock_requests(mock_requests, exception) with self.assertRaises(AirflowException): self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) - self.assertEquals(len(mock_errors.mock_calls), self.hook.retry_limit) + self.assertEquals(mock_errors.call_count, self.hook.retry_limit) @mock.patch('airflow.contrib.hooks.databricks_hook.requests') - def test_do_api_call_with_bad_status_code(self, mock_requests): - mock_requests.codes.ok = 200 - status_code_mock = mock.PropertyMock(return_value=500) - type(mock_requests.post.return_value).status_code = status_code_mock - with self.assertRaises(AirflowException): - self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + def test_do_api_call_does_not_retry_with_non_retryable_error(self, mock_requests): + setup_mock_requests( + mock_requests, requests_exceptions.HTTPError, status_code=400 + ) + + with mock.patch.object(self.hook.log, 'error') as mock_errors: + with self.assertRaises(AirflowException): + self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + mock_errors.assert_not_called() + + def test_do_api_call_succeeds_after_retrying(self): + for exception in [ + requests_exceptions.ConnectionError, + requests_exceptions.SSLError, + requests_exceptions.Timeout, + requests_exceptions.ConnectTimeout, + requests_exceptions.HTTPError]: + with mock.patch( + 'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ + mock.patch.object(self.hook.log, 'error') as mock_errors: + setup_mock_requests( + mock_requests, + exception, + error_count=2, + response_content={'run_id': '1'} + ) + + response = self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + self.assertEquals(mock_errors.call_count, 2) + self.assertEquals(response, {'run_id': '1'}) + + @mock.patch('airflow.contrib.hooks.databricks_hook.sleep') + def test_do_api_call_waits_between_retries(self, mock_sleep): + retry_delay = 5 + self.hook = DatabricksHook(retry_delay=retry_delay) + + for exception in [ + requests_exceptions.ConnectionError, + requests_exceptions.SSLError, + requests_exceptions.Timeout, + requests_exceptions.ConnectTimeout, + requests_exceptions.HTTPError]: + with mock.patch( + 'airflow.contrib.hooks.databricks_hook.requests') as mock_requests, \ + mock.patch.object(self.hook.log, 'error'): + mock_sleep.reset_mock() + setup_mock_requests(mock_requests, exception) + + with self.assertRaises(AirflowException): + self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {}) + + self.assertEquals(len(mock_sleep.mock_calls), self.hook.retry_limit - 1) + mock_sleep.assert_called_with(retry_delay) @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_submit_run(self, mock_requests): - mock_requests.codes.ok = 200 mock_requests.post.return_value.json.return_value = {'run_id': '1'} - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.post.return_value).status_code = status_code_mock json = { 'notebook_task': NOTEBOOK_TASK, 'new_cluster': NEW_CLUSTER @@ -158,10 +272,7 @@ def test_submit_run(self, mock_requests): @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_get_run_page_url(self, mock_requests): - mock_requests.codes.ok = 200 mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.get.return_value).status_code = status_code_mock run_page_url = self.hook.get_run_page_url(RUN_ID) @@ -175,10 +286,7 @@ def test_get_run_page_url(self, mock_requests): @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_get_run_state(self, mock_requests): - mock_requests.codes.ok = 200 mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.get.return_value).status_code = status_code_mock run_state = self.hook.get_run_state(RUN_ID) @@ -195,10 +303,7 @@ def test_get_run_state(self, mock_requests): @mock.patch('airflow.contrib.hooks.databricks_hook.requests') def test_cancel_run(self, mock_requests): - mock_requests.codes.ok = 200 mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE - status_code_mock = mock.PropertyMock(return_value=200) - type(mock_requests.post.return_value).status_code = status_code_mock self.hook.cancel_run(RUN_ID) @@ -209,6 +314,54 @@ def test_cancel_run(self, mock_requests): headers=USER_AGENT_HEADER, timeout=self.hook.timeout_seconds) + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_start_cluster(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.return_value.json.return_value = {} + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + + self.hook.start_cluster({"cluster_id": CLUSTER_ID}) + + mock_requests.post.assert_called_once_with( + start_cluster_endpoint(HOST), + json={'cluster_id': CLUSTER_ID}, + auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_restart_cluster(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.return_value.json.return_value = {} + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + + self.hook.restart_cluster({"cluster_id": CLUSTER_ID}) + + mock_requests.post.assert_called_once_with( + restart_cluster_endpoint(HOST), + json={'cluster_id': CLUSTER_ID}, + auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + + @mock.patch('airflow.contrib.hooks.databricks_hook.requests') + def test_terminate_cluster(self, mock_requests): + mock_requests.codes.ok = 200 + mock_requests.post.return_value.json.return_value = {} + status_code_mock = mock.PropertyMock(return_value=200) + type(mock_requests.post.return_value).status_code = status_code_mock + + self.hook.terminate_cluster({"cluster_id": CLUSTER_ID}) + + mock_requests.post.assert_called_once_with( + terminate_cluster_endpoint(HOST), + json={'cluster_id': CLUSTER_ID}, + auth=(LOGIN, PASSWORD), + headers=USER_AGENT_HEADER, + timeout=self.hook.timeout_seconds) + class DatabricksHookTokenTest(unittest.TestCase): """ diff --git a/tests/contrib/hooks/test_gcp_dataflow_hook.py b/tests/contrib/hooks/test_gcp_dataflow_hook.py index 7811e4aabd59f..a86db5b12bfe1 100644 --- a/tests/contrib/hooks/test_gcp_dataflow_hook.py +++ b/tests/contrib/hooks/test_gcp_dataflow_hook.py @@ -79,7 +79,7 @@ def setUp(self): new=mock_init): self.dataflow_hook = DataFlowHook(gcp_conn_id='test') - @mock.patch(DATAFLOW_STRING.format('uuid.uuid1')) + @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJob')) @mock.patch(DATAFLOW_STRING.format('_Dataflow')) @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn')) @@ -103,7 +103,7 @@ def test_start_python_dataflow(self, mock_conn, self.assertListEqual(sorted(mock_dataflow.call_args[0][0]), sorted(EXPECTED_CMD)) - @mock.patch(DATAFLOW_STRING.format('uuid.uuid1')) + @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJob')) @mock.patch(DATAFLOW_STRING.format('_Dataflow')) @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn')) @@ -127,7 +127,7 @@ def test_start_java_dataflow(self, mock_conn, self.assertListEqual(sorted(mock_dataflow.call_args[0][0]), sorted(EXPECTED_CMD)) - @mock.patch(DATAFLOW_STRING.format('uuid.uuid1')) + @mock.patch(DATAFLOW_STRING.format('uuid.uuid4')) @mock.patch(DATAFLOW_STRING.format('_DataflowJob')) @mock.patch(DATAFLOW_STRING.format('_Dataflow')) @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn')) diff --git a/tests/contrib/minikube/test_kubernetes_pod_operator.py b/tests/contrib/minikube/test_kubernetes_pod_operator.py index 5cb02d1ff1ba6..595d7aa8b43e5 100644 --- a/tests/contrib/minikube/test_kubernetes_pod_operator.py +++ b/tests/contrib/minikube/test_kubernetes_pod_operator.py @@ -20,6 +20,7 @@ import shutil from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator from airflow import AirflowException +from kubernetes.client.rest import ApiException from subprocess import check_call import mock import json @@ -93,6 +94,34 @@ def test_working_pod(): ) k.execute(None) + @staticmethod + def test_delete_operator_pod(): + k = KubernetesPodOperator( + namespace='default', + image="ubuntu:16.04", + cmds=["bash", "-cx"], + arguments=["echo 10"], + labels={"foo": "bar"}, + name="test", + task_id="task", + is_delete_operator_pod=True + ) + k.execute(None) + + @staticmethod + def test_pod_hostnetwork(): + k = KubernetesPodOperator( + namespace='default', + image="ubuntu:16.04", + cmds=["bash", "-cx"], + arguments=["echo 10"], + labels={"foo": "bar"}, + name="test", + task_id="task", + hostnetwork=True + ) + k.execute(None) + @staticmethod def test_pod_node_selectors(): node_selectors = { @@ -200,10 +229,24 @@ def test_faulty_image(self): task_id="task", startup_timeout_seconds=5 ) - with self.assertRaises(AirflowException) as cm: - k.execute(None), + with self.assertRaises(AirflowException): + k.execute(None) - print("exception: {}".format(cm)) + def test_faulty_service_account(self): + bad_service_account_name = "foobar" + k = KubernetesPodOperator( + namespace='default', + image="ubuntu:16.04", + cmds=["bash", "-cx"], + arguments=["echo 10"], + labels={"foo": "bar"}, + name="test", + task_id="task", + startup_timeout_seconds=5, + service_account_name=bad_service_account_name + ) + with self.assertRaises(ApiException): + k.execute(None) def test_pod_failure(self): """ diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py index f77da2ec18eda..afe1a92f28d9e 100644 --- a/tests/contrib/operators/test_databricks_operator.py +++ b/tests/contrib/operators/test_databricks_operator.py @@ -190,8 +190,9 @@ def test_exec_success(self, db_mock_class): 'run_name': TASK_ID }) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit) + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay) db_mock.submit_run.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run_state.assert_called_once_with(RUN_ID) @@ -220,8 +221,9 @@ def test_exec_failure(self, db_mock_class): 'run_name': TASK_ID, }) db_mock_class.assert_called_once_with( - DEFAULT_CONN_ID, - retry_limit=op.databricks_retry_limit) + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay) db_mock.submit_run.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run_state.assert_called_once_with(RUN_ID) diff --git a/tests/contrib/operators/test_dataflow_operator.py b/tests/contrib/operators/test_dataflow_operator.py index 4ea5f65698f03..a373126b24e4b 100644 --- a/tests/contrib/operators/test_dataflow_operator.py +++ b/tests/contrib/operators/test_dataflow_operator.py @@ -20,9 +20,10 @@ import unittest -from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator, \ - DataFlowJavaOperator, DataflowTemplateOperator -from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator +from airflow.contrib.operators.dataflow_operator import \ + DataFlowPythonOperator, DataFlowJavaOperator, \ + DataflowTemplateOperator, GoogleCloudBucketHelper + from airflow.version import version try: @@ -186,3 +187,25 @@ def test_exec(self, dataflow_mock): } start_template_hook.assert_called_once_with(TASK_ID, expected_options, PARAMETERS, TEMPLATE) + + +class GoogleCloudBucketHelperTest(unittest.TestCase): + + @mock.patch( + 'airflow.contrib.operators.dataflow_operator.GoogleCloudBucketHelper.__init__' + ) + def test_invalid_object_path(self, mock_parent_init): + + # This is just the path of a bucket hence invalid filename + file_name = 'gs://test-bucket' + mock_parent_init.return_value = None + + gcs_bucket_helper = GoogleCloudBucketHelper() + gcs_bucket_helper._gcs_hook = mock.Mock() + + with self.assertRaises(Exception) as context: + gcs_bucket_helper.google_cloud_to_local(file_name) + + self.assertEquals( + 'Invalid Google Cloud Storage (GCS) object path: {}.'.format(file_name), + str(context.exception)) diff --git a/tests/contrib/operators/test_dataproc_operator.py b/tests/contrib/operators/test_dataproc_operator.py index e5cc770321b52..5b403a86bad81 100644 --- a/tests/contrib/operators/test_dataproc_operator.py +++ b/tests/contrib/operators/test_dataproc_operator.py @@ -61,8 +61,10 @@ IMAGE_VERSION = '1.1' MASTER_MACHINE_TYPE = 'n1-standard-2' MASTER_DISK_SIZE = 100 +MASTER_DISK_TYPE = 'pd-standard' WORKER_MACHINE_TYPE = 'n1-standard-2' WORKER_DISK_SIZE = 100 +WORKER_DISK_TYPE = 'pd-standard' NUM_PREEMPTIBLE_WORKERS = 2 GET_INIT_ACTION_TIMEOUT = "600s" # 10m LABEL1 = {} @@ -125,8 +127,10 @@ def setUp(self): storage_bucket=STORAGE_BUCKET, image_version=IMAGE_VERSION, master_machine_type=MASTER_MACHINE_TYPE, + master_disk_type=MASTER_DISK_TYPE, master_disk_size=MASTER_DISK_SIZE, worker_machine_type=WORKER_MACHINE_TYPE, + worker_disk_type=WORKER_DISK_TYPE, worker_disk_size=WORKER_DISK_SIZE, num_preemptible_workers=NUM_PREEMPTIBLE_WORKERS, labels=deepcopy(labels), @@ -159,8 +163,10 @@ def test_init(self): self.assertEqual(dataproc_operator.image_version, IMAGE_VERSION) self.assertEqual(dataproc_operator.master_machine_type, MASTER_MACHINE_TYPE) self.assertEqual(dataproc_operator.master_disk_size, MASTER_DISK_SIZE) + self.assertEqual(dataproc_operator.master_disk_type, MASTER_DISK_TYPE) self.assertEqual(dataproc_operator.worker_machine_type, WORKER_MACHINE_TYPE) self.assertEqual(dataproc_operator.worker_disk_size, WORKER_DISK_SIZE) + self.assertEqual(dataproc_operator.worker_disk_type, WORKER_DISK_TYPE) self.assertEqual(dataproc_operator.num_preemptible_workers, NUM_PREEMPTIBLE_WORKERS) self.assertEqual(dataproc_operator.labels, self.labels[suffix]) diff --git a/tests/contrib/operators/test_qubole_check_operator.py b/tests/contrib/operators/test_qubole_check_operator.py index 29044827ee6e5..972038005bb79 100644 --- a/tests/contrib/operators/test_qubole_check_operator.py +++ b/tests/contrib/operators/test_qubole_check_operator.py @@ -24,6 +24,7 @@ from airflow.contrib.operators.qubole_check_operator import QuboleValueCheckOperator from airflow.contrib.hooks.qubole_check_hook import QuboleCheckHook from airflow.contrib.hooks.qubole_hook import QuboleHook +from qds_sdk.commands import HiveCommand try: from unittest import mock @@ -80,11 +81,13 @@ def test_execute_pass(self, mock_get_hook): mock_hook.get_first.assert_called_with(query) @mock.patch.object(QuboleValueCheckOperator, 'get_hook') - def test_execute_fail(self, mock_get_hook): + def test_execute_assertion_fail(self, mock_get_hook): mock_cmd = mock.Mock() mock_cmd.status = 'done' mock_cmd.id = 123 + mock_cmd.is_success = mock.Mock( + return_value=HiveCommand.is_success(mock_cmd.status)) mock_hook = mock.Mock() mock_hook.get_first.return_value = [11] @@ -97,6 +100,30 @@ def test_execute_fail(self, mock_get_hook): 'Qubole Command Id: ' + str(mock_cmd.id)): operator.execute() + mock_cmd.is_success.assert_called_with(mock_cmd.status) + + @mock.patch.object(QuboleValueCheckOperator, 'get_hook') + def test_execute_assert_query_fail(self, mock_get_hook): + + mock_cmd = mock.Mock() + mock_cmd.status = 'error' + mock_cmd.id = 123 + mock_cmd.is_success = mock.Mock( + return_value=HiveCommand.is_success(mock_cmd.status)) + + mock_hook = mock.Mock() + mock_hook.get_first.return_value = [11] + mock_hook.cmd = mock_cmd + mock_get_hook.return_value = mock_hook + + operator = self.__construct_operator('select value from tab1 limit 1;', 5, 1) + + with self.assertRaises(AirflowException) as cm: + operator.execute() + + self.assertNotIn('Qubole Command Id: ', str(cm.exception)) + mock_cmd.is_success.assert_called_with(mock_cmd.status) + @mock.patch.object(QuboleCheckHook, 'get_query_results') @mock.patch.object(QuboleHook, 'execute') def test_results_parser_callable(self, mock_execute, mock_get_query_results): diff --git a/tests/contrib/operators/test_s3_to_gcs_operator.py b/tests/contrib/operators/test_s3_to_gcs_operator.py index 807882c324936..97d6eae916d28 100644 --- a/tests/contrib/operators/test_s3_to_gcs_operator.py +++ b/tests/contrib/operators/test_s3_to_gcs_operator.py @@ -88,8 +88,8 @@ def _assert_upload(bucket, object, tmp_filename): uploaded_files = operator.execute(None) - s3_one_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID) - s3_two_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID) + s3_one_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None) + s3_two_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None) gcs_mock_hook.assert_called_once_with( google_cloud_storage_conn_id=GCS_CONN_ID, delegate_to=None) diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py index 01446a6fddd49..5770c1b940eb5 100644 --- a/tests/contrib/operators/test_sftp_operator.py +++ b/tests/contrib/operators/test_sftp_operator.py @@ -20,6 +20,7 @@ import os import unittest from base64 import b64encode +import six from airflow import configuration from airflow import models @@ -219,6 +220,71 @@ def test_json_file_transfer_get(self): self.assertEqual(content_received.strip(), test_remote_file_content.encode('utf-8').decode('utf-8')) + def test_arg_checking(self): + from airflow.exceptions import AirflowException + conn_id = "conn_id_for_testing" + os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost" + + # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided + if six.PY2: + self.assertRaisesRegex = self.assertRaisesRegexp + with self.assertRaisesRegex(AirflowException, + "Cannot operate without ssh_hook or ssh_conn_id."): + task_0 = SFTPOperator( + task_id="test_sftp", + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + task_0.execute(None) + + # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook + task_1 = SFTPOperator( + task_id="test_sftp", + ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook + ssh_conn_id=conn_id, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + try: + task_1.execute(None) + except Exception: + pass + self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id) + + task_2 = SFTPOperator( + task_id="test_sftp", + ssh_conn_id=conn_id, # no ssh_hook provided + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + try: + task_2.execute(None) + except Exception: + pass + self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id) + + # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id + task_3 = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + ssh_conn_id=conn_id, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + try: + task_3.execute(None) + except Exception: + pass + self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id) + def delete_local_resource(self): if os.path.exists(self.test_local_filepath): os.remove(self.test_local_filepath) @@ -226,11 +292,11 @@ def delete_local_resource(self): def delete_remote_resource(self): # check the remote file content remove_file_task = SSHOperator( - task_id="test_check_file", - ssh_hook=self.hook, - command="rm {0}".format(self.test_remote_filepath), - do_xcom_push=True, - dag=self.dag + task_id="test_check_file", + ssh_hook=self.hook, + command="rm {0}".format(self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag ) self.assertIsNotNone(remove_file_task) ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow()) diff --git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py index 7ddd24b2ac2ca..1a2c788596671 100644 --- a/tests/contrib/operators/test_ssh_operator.py +++ b/tests/contrib/operators/test_ssh_operator.py @@ -19,6 +19,7 @@ import unittest from base64 import b64encode +import six from airflow import configuration from airflow import models @@ -148,6 +149,65 @@ def test_no_output_command(self): self.assertIsNotNone(ti.duration) self.assertEqual(ti.xcom_pull(task_ids='test', key='return_value'), b'') + def test_arg_checking(self): + import os + from airflow.exceptions import AirflowException + conn_id = "conn_id_for_testing" + TIMEOUT = 5 + os.environ['AIRFLOW_CONN_' + conn_id.upper()] = "ssh://test_id@localhost" + + # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided + if six.PY2: + self.assertRaisesRegex = self.assertRaisesRegexp + with self.assertRaisesRegex(AirflowException, + "Cannot operate without ssh_hook or ssh_conn_id."): + task_0 = SSHOperator(task_id="test", command="echo -n airflow", + timeout=TIMEOUT, dag=self.dag) + task_0.execute(None) + + # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook + task_1 = SSHOperator( + task_id="test_1", + ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook + ssh_conn_id=conn_id, + command="echo -n airflow", + timeout=TIMEOUT, + dag=self.dag + ) + try: + task_1.execute(None) + except Exception: + pass + self.assertEqual(task_1.ssh_hook.ssh_conn_id, conn_id) + + task_2 = SSHOperator( + task_id="test_2", + ssh_conn_id=conn_id, # no ssh_hook provided + command="echo -n airflow", + timeout=TIMEOUT, + dag=self.dag + ) + try: + task_2.execute(None) + except Exception: + pass + self.assertEqual(task_2.ssh_hook.ssh_conn_id, conn_id) + + # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id + task_3 = SSHOperator( + task_id="test_3", + ssh_hook=self.hook, + ssh_conn_id=conn_id, + command="echo -n airflow", + timeout=TIMEOUT, + dag=self.dag + ) + try: + task_3.execute(None) + except Exception: + pass + self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id) + if __name__ == '__main__': unittest.main() diff --git a/tests/core.py b/tests/core.py index 8df6312eeb2e6..f8b86919128cc 100644 --- a/tests/core.py +++ b/tests/core.py @@ -626,6 +626,35 @@ def __bool__(self): dag=self.dag) t.resolve_template_files() + def test_task_get_template(self): + TI = models.TaskInstance + ti = TI( + task=self.runme_0, execution_date=DEFAULT_DATE) + ti.dag = self.dag_bash + ti.run(ignore_ti_state=True) + context = ti.get_template_context() + + # DEFAULT DATE is 2015-01-01 + self.assertEquals(context['ds'], '2015-01-01') + self.assertEquals(context['ds_nodash'], '20150101') + + # next_ds is 2015-01-02 as the dag interval is daily + self.assertEquals(context['next_ds'], '2015-01-02') + self.assertEquals(context['next_ds_nodash'], '20150102') + + # prev_ds is 2014-12-31 as the dag interval is daily + self.assertEquals(context['prev_ds'], '2014-12-31') + self.assertEquals(context['prev_ds_nodash'], '20141231') + + self.assertEquals(context['ts'], '2015-01-01T00:00:00+00:00') + self.assertEquals(context['ts_nodash'], '20150101T000000+0000') + + self.assertEquals(context['yesterday_ds'], '2014-12-31') + self.assertEquals(context['yesterday_ds_nodash'], '20141231') + + self.assertEquals(context['tomorrow_ds'], '2015-01-02') + self.assertEquals(context['tomorrow_ds_nodash'], '20150102') + def test_import_examples(self): self.assertEqual(len(self.dagbag.dags), NUM_EXAMPLE_DAGS) diff --git a/tests/executors/dask_executor.py b/tests/executors/dask_executor.py index 9bf051f5805d2..4f0009e1ccb21 100644 --- a/tests/executors/dask_executor.py +++ b/tests/executors/dask_executor.py @@ -55,8 +55,8 @@ def assert_tasks_on_executor(self, executor): # start the executor executor.start() - success_command = ['true', ] - fail_command = ['false', ] + success_command = ['true', 'some_parameter'] + fail_command = ['false', 'some_parameter'] executor.execute_async(key='success', command=success_command) executor.execute_async(key='fail', command=fail_command) diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py index 69f9fbfe9fea8..95ad58f6a2d96 100644 --- a/tests/executors/test_celery_executor.py +++ b/tests/executors/test_celery_executor.py @@ -34,8 +34,8 @@ def test_celery_integration(self): executor.start() with start_worker(app=app, logfile=sys.stdout, loglevel='debug'): - success_command = ['true', ] - fail_command = ['false', ] + success_command = ['true', 'some_parameter'] + fail_command = ['false', 'some_parameter'] executor.execute_async(key='success', command=success_command) # errors are propagated for some reason diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py index 846e1325618ac..59cb09c74e6b0 100644 --- a/tests/executors/test_local_executor.py +++ b/tests/executors/test_local_executor.py @@ -33,8 +33,8 @@ def execution_parallelism(self, parallelism=0): executor.start() success_key = 'success {}' - success_command = ['true', ] - fail_command = ['false', ] + success_command = ['true', 'some_parameter'] + fail_command = ['false', 'some_parameter'] for i in range(self.TEST_SUCCESS_COMMANDS): key, command = success_key.format(i), success_command diff --git a/tests/jobs.py b/tests/jobs.py index fd3a96a4d87ad..f9c07b96c9f51 100644 --- a/tests/jobs.py +++ b/tests/jobs.py @@ -61,6 +61,7 @@ from airflow import configuration configuration.load_test_config() +logger = logging.getLogger(__name__) try: from unittest import mock @@ -194,23 +195,21 @@ def test_backfill_multi_dates(self): def test_backfill_examples(self): """ Test backfilling example dags - """ - # some DAGs really are just examples... but try to make them work! - skip_dags = [ - 'example_http_operator', - 'example_twitter_dag', - 'example_trigger_target_dag', - 'example_trigger_controller_dag', # tested above - 'test_utils', # sleeps forever - 'example_kubernetes_executor', # requires kubernetes cluster - 'example_kubernetes_operator' # requires kubernetes cluster - ] + Try to backfill some of the example dags. Be carefull, not all dags are suitable + for doing this. For example, a dag that sleeps forever, or does not have a + schedule won't work here since you simply can't backfill them. + """ + include_dags = { + 'example_branch_operator', + 'example_bash_operator', + 'example_skip_dag', + 'latest_only' + } - logger = logging.getLogger('BackfillJobTest.test_backfill_examples') dags = [ dag for dag in self.dagbag.dags.values() - if 'example_dags' in dag.full_filepath and dag.dag_id not in skip_dags + if 'example_dags' in dag.full_filepath and dag.dag_id in include_dags ] for dag in dags: @@ -218,6 +217,11 @@ def test_backfill_examples(self): start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + # Make sure that we have the dags that we want to test available + # in the example_dags folder, if this assertion fails, one of the + # dags in the include_dags array isn't available anymore + self.assertEqual(len(include_dags), len(dags)) + for i, dag in enumerate(sorted(dags, key=lambda d: d.dag_id)): logger.info('*** Running example DAG #{}: {}'.format(i, dag.dag_id)) job = BackfillJob( diff --git a/tests/models.py b/tests/models.py index a1fd1e991221a..a9c43dfd8e42a 100644 --- a/tests/models.py +++ b/tests/models.py @@ -801,7 +801,26 @@ def test_dagrun_deadlock(self): dr.update_state() self.assertEqual(dr.state, State.FAILED) - def test_dagrun_no_deadlock(self): + def test_dagrun_no_deadlock_with_shutdown(self): + session = settings.Session() + dag = DAG('test_dagrun_no_deadlock_with_shutdown', + start_date=DEFAULT_DATE) + with dag: + op1 = DummyOperator(task_id='upstream_task') + op2 = DummyOperator(task_id='downstream_task') + op2.set_upstream(op1) + + dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_with_shutdown', + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE) + upstream_ti = dr.get_task_instance(task_id='upstream_task') + upstream_ti.set_state(State.SHUTDOWN, session=session) + + dr.update_state() + self.assertEqual(dr.state, State.RUNNING) + + def test_dagrun_no_deadlock_with_depends_on_past(self): session = settings.Session() dag = DAG('test_dagrun_no_deadlock', start_date=DEFAULT_DATE) @@ -896,6 +915,124 @@ def on_failure_callable(context): updated_dag_state = dag_run.update_state() self.assertEqual(State.FAILED, updated_dag_state) + def test_dagrun_set_state_end_date(self): + session = settings.Session() + + dag = DAG( + 'test_dagrun_set_state_end_date', + start_date=DEFAULT_DATE, + default_args={'owner': 'owner1'}) + + dag.clear() + + now = timezone.utcnow() + dr = dag.create_dagrun(run_id='test_dagrun_set_state_end_date', + state=State.RUNNING, + execution_date=now, + start_date=now) + + # Initial end_date should be NULL + # State.SUCCESS and State.FAILED are all ending state and should set end_date + # State.RUNNING set end_date back to NULL + session.add(dr) + session.commit() + self.assertIsNone(dr.end_date) + + dr.set_state(State.SUCCESS) + session.merge(dr) + session.commit() + + dr_database = session.query(DagRun).filter( + DagRun.run_id == 'test_dagrun_set_state_end_date' + ).one() + self.assertIsNotNone(dr_database.end_date) + self.assertEqual(dr.end_date, dr_database.end_date) + + dr.set_state(State.RUNNING) + session.merge(dr) + session.commit() + + dr_database = session.query(DagRun).filter( + DagRun.run_id == 'test_dagrun_set_state_end_date' + ).one() + + self.assertIsNone(dr_database.end_date) + + dr.set_state(State.FAILED) + session.merge(dr) + session.commit() + dr_database = session.query(DagRun).filter( + DagRun.run_id == 'test_dagrun_set_state_end_date' + ).one() + + self.assertIsNotNone(dr_database.end_date) + self.assertEqual(dr.end_date, dr_database.end_date) + + def test_dagrun_update_state_end_date(self): + session = settings.Session() + + dag = DAG( + 'test_dagrun_update_state_end_date', + start_date=DEFAULT_DATE, + default_args={'owner': 'owner1'}) + + # A -> B + with dag: + op1 = DummyOperator(task_id='A') + op2 = DummyOperator(task_id='B') + op1.set_upstream(op2) + + dag.clear() + + now = timezone.utcnow() + dr = dag.create_dagrun(run_id='test_dagrun_update_state_end_date', + state=State.RUNNING, + execution_date=now, + start_date=now) + + # Initial end_date should be NULL + # State.SUCCESS and State.FAILED are all ending state and should set end_date + # State.RUNNING set end_date back to NULL + session.merge(dr) + session.commit() + self.assertIsNone(dr.end_date) + + ti_op1 = dr.get_task_instance(task_id=op1.task_id) + ti_op1.set_state(state=State.SUCCESS, session=session) + ti_op2 = dr.get_task_instance(task_id=op2.task_id) + ti_op2.set_state(state=State.SUCCESS, session=session) + + dr.update_state() + + dr_database = session.query(DagRun).filter( + DagRun.run_id == 'test_dagrun_update_state_end_date' + ).one() + self.assertIsNotNone(dr_database.end_date) + self.assertEqual(dr.end_date, dr_database.end_date) + + ti_op1.set_state(state=State.RUNNING, session=session) + ti_op2.set_state(state=State.RUNNING, session=session) + dr.update_state() + + dr_database = session.query(DagRun).filter( + DagRun.run_id == 'test_dagrun_update_state_end_date' + ).one() + + self.assertEqual(dr._state, State.RUNNING) + self.assertIsNone(dr.end_date) + self.assertIsNone(dr_database.end_date) + + ti_op1.set_state(state=State.FAILED, session=session) + ti_op2.set_state(state=State.FAILED, session=session) + dr.update_state() + + dr_database = session.query(DagRun).filter( + DagRun.run_id == 'test_dagrun_update_state_end_date' + ).one() + + self.assertIsNotNone(dr_database.end_date) + self.assertEqual(dr.end_date, dr_database.end_date) + def test_get_task_instance_on_empty_dagrun(self): """ Make sure that a proper value is returned when a dagrun has no task instances @@ -1521,6 +1658,16 @@ def test_timezone_awareness(self): ti = TI(task=op1, execution_date=execution_date) self.assertEquals(ti.execution_date, utc_date) + def test_task_naive_datetime(self): + NAIVE_DATETIME = DEFAULT_DATE.replace(tzinfo=None) + + op_no_dag = DummyOperator(task_id='test_task_naive_datetime', + start_date=NAIVE_DATETIME, + end_date=NAIVE_DATETIME) + + self.assertTrue(op_no_dag.start_date.tzinfo) + self.assertTrue(op_no_dag.end_date.tzinfo) + def test_set_dag(self): """ Test assigning Operators to Dags, including deferred assignment diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py index 9034b8b5fd28e..a06d6b066a113 100644 --- a/tests/www/test_utils.py +++ b/tests/www/test_utils.py @@ -195,6 +195,40 @@ def some_func(): self.assertEqual(anonymous_username, kwargs['owner']) mocked_session_instance.add.assert_called_once() + def test_open_maybe_zipped_normal_file(self): + with mock.patch( + 'io.open', mock.mock_open(read_data="data")) as mock_file: + utils.open_maybe_zipped('/path/to/some/file.txt') + mock_file.assert_called_with('/path/to/some/file.txt', mode='r') + + def test_open_maybe_zipped_normal_file_with_zip_in_name(self): + path = '/path/to/fakearchive.zip.other/file.txt' + with mock.patch( + 'io.open', mock.mock_open(read_data="data")) as mock_file: + utils.open_maybe_zipped(path) + mock_file.assert_called_with(path, mode='r') + + @mock.patch("zipfile.is_zipfile") + @mock.patch("zipfile.ZipFile") + def test_open_maybe_zipped_archive(self, mocked_ZipFile, mocked_is_zipfile): + mocked_is_zipfile.return_value = True + instance = mocked_ZipFile.return_value + instance.open.return_value = mock.mock_open(read_data="data") + + utils.open_maybe_zipped('/path/to/archive.zip/deep/path/to/file.txt') + + mocked_is_zipfile.assert_called_once() + (args, kwargs) = mocked_is_zipfile.call_args_list[0] + self.assertEqual('/path/to/archive.zip', args[0]) + + mocked_ZipFile.assert_called_once() + (args, kwargs) = mocked_ZipFile.call_args_list[0] + self.assertEqual('/path/to/archive.zip', args[0]) + + instance.open.assert_called_once() + (args, kwargs) = instance.open.call_args_list[0] + self.assertEqual('deep/path/to/file.txt', args[0]) + def test_get_python_source_from_method(self): class AMockClass(object): def a_method(self): diff --git a/tests/www_rbac/test_utils.py b/tests/www_rbac/test_utils.py index 1879ba082693d..68d1744ab8fa8 100644 --- a/tests/www_rbac/test_utils.py +++ b/tests/www_rbac/test_utils.py @@ -18,6 +18,7 @@ # under the License. import unittest +import mock from xml.dom import minidom from airflow.www_rbac import utils @@ -113,6 +114,40 @@ def test_params_all(self): self.assertEqual('page=3&search=bash_&showPaused=False', utils.get_params(showPaused=False, page=3, search='bash_')) + def test_open_maybe_zipped_normal_file(self): + with mock.patch( + 'io.open', mock.mock_open(read_data="data")) as mock_file: + utils.open_maybe_zipped('/path/to/some/file.txt') + mock_file.assert_called_with('/path/to/some/file.txt', mode='r') + + def test_open_maybe_zipped_normal_file_with_zip_in_name(self): + path = '/path/to/fakearchive.zip.other/file.txt' + with mock.patch( + 'io.open', mock.mock_open(read_data="data")) as mock_file: + utils.open_maybe_zipped(path) + mock_file.assert_called_with(path, mode='r') + + @mock.patch("zipfile.is_zipfile") + @mock.patch("zipfile.ZipFile") + def test_open_maybe_zipped_archive(self, mocked_ZipFile, mocked_is_zipfile): + mocked_is_zipfile.return_value = True + instance = mocked_ZipFile.return_value + instance.open.return_value = mock.mock_open(read_data="data") + + utils.open_maybe_zipped('/path/to/archive.zip/deep/path/to/file.txt') + + mocked_is_zipfile.assert_called_once() + (args, kwargs) = mocked_is_zipfile.call_args_list[0] + self.assertEqual('/path/to/archive.zip', args[0]) + + mocked_ZipFile.assert_called_once() + (args, kwargs) = mocked_ZipFile.call_args_list[0] + self.assertEqual('/path/to/archive.zip', args[0]) + + instance.open.assert_called_once() + (args, kwargs) = instance.open.call_args_list[0] + self.assertEqual('deep/path/to/file.txt', args[0]) + if __name__ == '__main__': unittest.main() diff --git a/tox.ini b/tox.ini index 73e3170ec8e69..c4b74a1e55345 100644 --- a/tox.ini +++ b/tox.ini @@ -33,7 +33,7 @@ ignore = E731,W503 [testenv] deps = wheel - coveralls + codecov basepython = py27: python2.7 @@ -52,18 +52,7 @@ setenv = backend_sqlite: AIRFLOW__CORE__SQL_ALCHEMY_CONN=sqlite:///{homedir}/airflow.db backend_sqlite: AIRFLOW__CORE__EXECUTOR=SequentialExecutor -passenv = - HOME - JAVA_HOME - USER - PATH - BOTO_CONFIG - TRAVIS - TRAVIS_BRANCH - TRAVIS_BUILD_DIR - TRAVIS_JOB_ID - TRAVIS_PULL_REQUEST - SLUGIFY_USES_TEXT_UNIDECODE +passenv = * commands = pip wheel --progress-bar off -w {homedir}/.wheelhouse -f {homedir}/.wheelhouse -e .[devel_ci] @@ -74,6 +63,7 @@ commands = {toxinidir}/scripts/ci/4-load-data.sh {toxinidir}/scripts/ci/5-run-tests.sh [] {toxinidir}/scripts/ci/6-check-license.sh + codecov -e TOXENV [testenv:flake8] basepython = python3