Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import TYPE_CHECKING
from unittest.mock import patch

import anyio
import httpx
import jwt
import pytest
Expand Down Expand Up @@ -213,10 +214,10 @@ async def test_jwt_generate_validate_roundtrip_with_jwks(private_key, algorithm,
jwk_content = json.dumps({"keys": [key_to_jwk_dict(private_key, "custom-kid")]})

jwks = tmp_path.joinpath("jwks.json")
jwks.write_text(jwk_content)
await anyio.Path(jwks).write_text(jwk_content)

priv_key = tmp_path.joinpath("key.pem")
priv_key.write_bytes(key_to_pem(private_key))
await anyio.Path(priv_key).write_bytes(key_to_pem(private_key))

with conf_vars(
{
Expand Down
37 changes: 21 additions & 16 deletions airflow-core/tests/unit/cli/commands/test_task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,27 +241,32 @@ def test_cli_test_with_env_vars(self):
assert "foo=bar" in output
assert "AIRFLOW_TEST_MODE=True" in output

@mock.patch("airflow.providers.standard.triggers.file.os.path.getmtime", return_value=0)
@mock.patch(
"airflow.providers.standard.triggers.file.glob", return_value=["/tmp/temporary_file_for_testing"]
)
@mock.patch("airflow.providers.standard.triggers.file.os")
@mock.patch("airflow.providers.standard.sensors.filesystem.FileSensor.poke", return_value=False)
def test_cli_test_with_deferrable_operator(self, mock_pock, mock_os, mock_glob, mock_getmtime, caplog):
mock_os.path.isfile.return_value = True
with caplog.at_level(level=logging.INFO):
task_command.task_test(
self.parser.parse_args(
[
"tasks",
"test",
"example_sensors",
"wait_for_file_async",
DEFAULT_DATE.isoformat(),
]
def test_cli_test_with_deferrable_operator(self, mock_poke, mock_glob, caplog):
mock_stat = mock.MagicMock()
mock_stat.st_mtime = 0
mock_path_instance = mock.MagicMock()
mock_path_instance.is_file = mock.AsyncMock(return_value=True)
mock_path_instance.stat = mock.AsyncMock(return_value=mock_stat)
mock_anyio_path = mock.MagicMock(return_value=mock_path_instance)

with mock.patch("airflow.providers.standard.triggers.file.anyio.Path", mock_anyio_path):
with caplog.at_level(level=logging.INFO):
task_command.task_test(
self.parser.parse_args(
[
"tasks",
"test",
"example_sensors",
"wait_for_file_async",
DEFAULT_DATE.isoformat(),
]
)
)
)
output = caplog.text
output = caplog.text
assert "Found File /tmp/temporary_file_for_testing" in output

def test_task_render(self):
Expand Down
12 changes: 9 additions & 3 deletions devel-common/src/sphinx_exts/pagefind_search/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from pathlib import Path
from typing import TYPE_CHECKING

import anyio
from pagefind.index import IndexConfig, PagefindIndex
from sphinx.util.fileutil import copy_asset

Expand Down Expand Up @@ -118,8 +119,13 @@ async def build_pagefind_index(app: Sphinx) -> dict[str, int]:
skipped = 0

async with PagefindIndex(config=config) as index:
for html_file in output_dir.glob(app.config.pagefind_glob):
if not html_file.is_file():

def _glob_html_files():
return list(output_dir.glob(app.config.pagefind_glob))

html_files = await anyio.to_thread.run_sync(_glob_html_files)
for html_file in html_files:
if not await anyio.Path(html_file).is_file():
continue

relative_path = html_file.relative_to(output_dir)
Expand All @@ -131,7 +137,7 @@ async def build_pagefind_index(app: Sphinx) -> dict[str, int]:
continue

try:
content = html_file.read_text(encoding="utf-8")
content = await anyio.Path(html_file).read_text(encoding="utf-8")
await index.add_html_file(
content=content,
source_path=str(html_file),
Expand Down
4 changes: 3 additions & 1 deletion providers/edge3/src/airflow/providers/edge3/cli/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pathlib import Path
from typing import TYPE_CHECKING

import anyio
from aiofiles import open as aio_open
from aiohttp import ClientResponseError
from lockfile.pidlockfile import remove_existing_pidfile
Expand Down Expand Up @@ -238,7 +239,8 @@ def _launch_job(self, workload: ExecuteTask) -> tuple[Process, Queue[Exception]]
return process, results_queue

async def _push_logs_in_chunks(self, job: Job):
if push_logs and job.logfile.exists() and job.logfile.stat().st_size > job.logsize:
aio_logfile = anyio.Path(job.logfile)
if push_logs and await aio_logfile.exists() and (await aio_logfile.stat()).st_size > job.logsize:
async with aio_open(job.logfile, mode="rb") as logf:
await logf.seek(job.logsize, os.SEEK_SET)
read_data = await logf.read()
Expand Down
10 changes: 6 additions & 4 deletions providers/edge3/tests/unit/edge3/cli/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from unittest import mock
from unittest.mock import call, patch

import anyio
import pytest
import time_machine
from aiohttp import ClientResponseError, RequestInfo
Expand Down Expand Up @@ -307,7 +308,7 @@ async def test_fetch_and_run_job_one_job_fail(
@pytest.mark.asyncio
async def test_push_logs_in_chunks(self, mock_logs_push, worker_with_job: EdgeWorker):
job = EdgeWorker.jobs[0]
job.logfile.write_text("some log content")
await anyio.Path(job.logfile).write_text("some log content")
with conf_vars({("edge", "api_url"): "https://invalid-api-test-endpoint"}):
await worker_with_job._push_logs_in_chunks(job)

Expand All @@ -321,9 +322,10 @@ async def test_push_logs_in_chunks(self, mock_logs_push, worker_with_job: EdgeWo
@pytest.mark.asyncio
async def test_check_running_jobs_log_push_increment(self, mock_logs_push, worker_with_job: EdgeWorker):
job = EdgeWorker.jobs[0]
job.logfile.write_text("hello ")
job.logsize = job.logfile.stat().st_size
job.logfile.write_text("hello world")
aio_logfile = anyio.Path(job.logfile)
await aio_logfile.write_text("hello ")
job.logsize = (await aio_logfile.stat()).st_size
await aio_logfile.write_text("hello world")
with conf_vars({("edge", "api_url"): "https://invalid-api-test-endpoint"}):
await worker_with_job._push_logs_in_chunks(job)
assert len(EdgeWorker.jobs) == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from glob import glob
from typing import Any

import anyio

from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
Expand Down Expand Up @@ -73,13 +75,13 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
"""Loop until the relevant files are found."""
while True:
for path in glob(self.filepath, recursive=self.recursive):
if os.path.isfile(path):
mod_time_f = os.path.getmtime(path)
if await anyio.Path(path).is_file():
mod_time_f = (await anyio.Path(path).stat()).st_mtime
mod_time = datetime.datetime.fromtimestamp(mod_time_f).strftime("%Y%m%d%H%M%S")
self.log.info("Found File %s last modified: %s", path, mod_time)
yield TriggerEvent(True)
return
for _, _, files in os.walk(path):
for _, _, files in await anyio.to_thread.run_sync(lambda: list(os.walk(path))):
if files:
yield TriggerEvent(True)
return
Expand Down Expand Up @@ -120,11 +122,12 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Loop until the relevant file is found."""
while True:
if os.path.isfile(self.filepath):
mod_time_f = os.path.getmtime(self.filepath)
filepath = anyio.Path(self.filepath)
if await filepath.is_file():
mod_time_f = (await filepath.stat()).st_mtime
mod_time = datetime.datetime.fromtimestamp(mod_time_f).strftime("%Y%m%d%H%M%S")
self.log.info("Found file %s last modified: %s", self.filepath, mod_time)
os.remove(self.filepath)
await filepath.unlink()
self.log.info("File %s has been deleted", self.filepath)
yield TriggerEvent(True)
return
Expand Down
7 changes: 4 additions & 3 deletions providers/standard/tests/unit/standard/triggers/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import asyncio

import anyio
import pytest

from airflow.providers.standard.triggers.file import FileDeleteTrigger, FileTrigger
Expand All @@ -43,7 +44,7 @@ def test_serialization(self):
async def test_task_file_trigger(self, tmp_path):
"""Asserts that the trigger only goes off on or after file is found"""
tmp_dir = tmp_path / "test_dir"
tmp_dir.mkdir()
await anyio.Path(tmp_dir).mkdir()
p = tmp_dir / "hello.txt"

trigger = FileTrigger(
Expand Down Expand Up @@ -84,7 +85,7 @@ def test_serialization(self):
async def test_file_delete_trigger(self, tmp_path):
"""Asserts that the trigger goes off on or after file is found and that the files gets deleted."""
tmp_dir = tmp_path / "test_dir"
tmp_dir.mkdir()
await anyio.Path(tmp_dir).mkdir()
p = tmp_dir / "hello.txt"

trigger = FileDeleteTrigger(
Expand All @@ -101,7 +102,7 @@ async def test_file_delete_trigger(self, tmp_path):
p.touch()

await asyncio.sleep(0.5)
assert p.exists() is False
assert await anyio.Path(p).exists() is False

# Prevents error when task is destroyed while in "pending" state
asyncio.get_event_loop().stop()
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,6 @@ ignore = [
"COM819",
"E501", # Formatted code may exceed the line length, leading to line-too-long (E501) errors.
"ASYNC110", # TODO: Use `anyio.Event` instead of awaiting `anyio.sleep` in a `while` loop
"ASYNC240", # TODO: Async functions should not use os.path methods, use trio.Path or anyio.path
"SIM105", # Use contextlib.suppress({exception}) instead of try-except-pass
]
unfixable = [
Expand Down