diff --git a/tests/hooks/test_hive_hook.py b/tests/hooks/test_hive_hook.py index 7d5d103d9b4ef..81270c635b042 100644 --- a/tests/hooks/test_hive_hook.py +++ b/tests/hooks/test_hive_hook.py @@ -375,8 +375,8 @@ def test_get_conn(self): def test_get_records(self): hook = HiveServer2Hook() query = "SELECT * FROM {}".format(self.table) - results = hook.get_pandas_df(query, schema=self.database) - self.assertEqual(len(results), 2) + results = hook.get_records(query, schema=self.database) + self.assertListEqual(results, [(1, 1), (2, 2)]) def test_get_pandas_df(self): hook = HiveServer2Hook() @@ -409,3 +409,13 @@ def test_to_csv(self): self.assertListEqual(df.columns.tolist(), self.columns) self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2]) self.assertEqual(len(df), 2) + + def test_multi_statements(self): + sqls = [ + "CREATE TABLE IF NOT EXISTS test_multi_statements (i INT)", + "SELECT * FROM {}".format(self.table), + "DROP TABLE test_multi_statements", + ] + hook = HiveServer2Hook() + results = hook.get_records(sqls, schema=self.database) + self.assertListEqual(results, [(1, 1), (2, 2)]) diff --git a/tests/operators/hive_operator.py b/tests/operators/hive_operator.py index 0914ba9539b60..1569e7f7e9538 100644 --- a/tests/operators/hive_operator.py +++ b/tests/operators/hive_operator.py @@ -114,119 +114,6 @@ def test_hiveconf(self): import airflow.hooks.hive_hooks import airflow.operators.presto_to_mysql - class HiveServer2Test(unittest.TestCase): - def setUp(self): - configuration.load_test_config() - self.nondefault_schema = "nondefault" - - def test_select_conn(self): - from airflow.hooks.hive_hooks import HiveServer2Hook - sql = "select 1" - hook = HiveServer2Hook() - hook.get_records(sql) - - def test_multi_statements(self): - from airflow.hooks.hive_hooks import HiveServer2Hook - sqls = [ - "CREATE TABLE IF NOT EXISTS test_multi_statements (i INT)", - "DROP TABLE test_multi_statements", - ] - hook = HiveServer2Hook() - hook.get_records(sqls) - - def test_get_metastore_databases(self): - if six.PY2: - from airflow.hooks.hive_hooks import HiveMetastoreHook - hook = HiveMetastoreHook() - hook.get_databases() - - def test_to_csv(self): - from airflow.hooks.hive_hooks import HiveServer2Hook - sql = "select 1" - hook = HiveServer2Hook() - hook.to_csv(hql=sql, csv_filepath="/tmp/test_to_csv") - - def connect_mock(self, host, port, - auth_mechanism, kerberos_service_name, - user, database): - self.assertEqual(database, self.nondefault_schema) - - @mock.patch('HiveServer2Hook.connect', return_value="foo") - def test_select_conn_with_schema(self, connect_mock): - from airflow.hooks.hive_hooks import HiveServer2Hook - - # Configure - hook = HiveServer2Hook() - - # Run - hook.get_conn(self.nondefault_schema) - - # Verify - self.assertTrue(connect_mock.called) - (args, kwargs) = connect_mock.call_args_list[0] - self.assertEqual(self.nondefault_schema, kwargs['database']) - - def test_get_results_with_schema(self): - from airflow.hooks.hive_hooks import HiveServer2Hook - from unittest.mock import MagicMock - - # Configure - sql = "select 1" - schema = "notdefault" - hook = HiveServer2Hook() - cursor_mock = MagicMock( - __enter__=cursor_mock, - __exit__=None, - execute=None, - fetchall=[], - ) - get_conn_mock = MagicMock( - __enter__=get_conn_mock, - __exit__=None, - cursor=cursor_mock, - ) - hook.get_conn = get_conn_mock - - # Run - hook.get_results(sql, schema) - - # Verify - get_conn_mock.assert_called_with(self.nondefault_schema) - - @mock.patch('HiveServer2Hook.get_results', return_value={'data': []}) - def test_get_records_with_schema(self, get_results_mock): - from airflow.hooks.hive_hooks import HiveServer2Hook - - # Configure - sql = "select 1" - hook = HiveServer2Hook() - - # Run - hook.get_records(sql, self.nondefault_schema) - - # Verify - self.assertTrue(self.connect_mock.called) - (args, kwargs) = self.connect_mock.call_args_list[0] - self.assertEqual(sql, args[0]) - self.assertEqual(self.nondefault_schema, kwargs['schema']) - - @mock.patch('HiveServer2Hook.get_results', return_value={'data': []}) - def test_get_pandas_df_with_schema(self, get_results_mock): - from airflow.hooks.hive_hooks import HiveServer2Hook - - # Configure - sql = "select 1" - hook = HiveServer2Hook() - - # Run - hook.get_pandas_df(sql, self.nondefault_schema) - - # Verify - self.assertTrue(self.connect_mock.called) - (args, kwargs) = self.connect_mock.call_args_list[0] - self.assertEqual(sql, args[0]) - self.assertEqual(self.nondefault_schema, kwargs['schema']) - class HivePrestoTest(HiveEnvironmentTest): def test_hive(self):