Skip to content
14 changes: 0 additions & 14 deletions src/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,20 +374,6 @@ class ServiceNotFoundError(ProblemDetailError):
_default_status_code = HTTPStatus.NOT_FOUND


# =============================================================================
# Quality Errors
# =============================================================================


class NoQualitiesError(ProblemDetailError):
    """Problem-detail error for a dataset that has no stored quality values."""

    # Default error code and HTTP status (412 Precondition Failed) for this problem.
    _default_code = 362
    _default_status_code = HTTPStatus.PRECONDITION_FAILED

    # Problem type identifier and its short human-readable title.
    title = "No Qualities Found"
    uri = "https://openml.org/problems/quality-no-qualities"


# =============================================================================
# Internal Errors
# =============================================================================
Expand Down
133 changes: 64 additions & 69 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Annotated, Any, Literal, NamedTuple

from fastapi import APIRouter, Body, Depends
from sqlalchemy import bindparam, text
from sqlalchemy import text
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection

Expand Down Expand Up @@ -73,26 +73,9 @@ class DatasetStatusFilter(StrEnum):
ALL = "all"


def _quality_clause(quality: str, range_: str | None) -> str:
    """Build an SQL fragment that filters datasets on one quality's value.

    Args:
        quality: Name of the quality to filter on. It is interpolated into
            the SQL as-is, so it must never come from user input.
        range_: Either a single value or a range, in the format accepted by
            ``integer_range_regex``. A falsy value disables the filter.

    Returns:
        An empty string when ``range_`` is falsy, otherwise an ``AND ...``
        clause restricting ``d.did`` to datasets whose quality value matches.

    Raises:
        ValueError: If ``range_`` does not match ``integer_range_regex``.
    """
    if not range_:
        return ""

    parsed = re.match(integer_range_regex, range_)
    if parsed is None:
        msg = f"`range_` not a valid range: {range_}"
        raise ValueError(msg)

    lower, upper = parsed.groups()
    if upper:
        # The second group carries a two-character prefix (presumably the
        # `..` range separator - confirm against the regex) that is dropped.
        condition = f"`value` BETWEEN {lower} AND {upper[2:]}"
    else:
        condition = f"`value`={lower}"

    return f""" AND
d.`did` IN (
SELECT `data`
FROM data_quality
WHERE `quality`='{quality}' AND {condition}
)
"""  # noqa: S608 - `quality` is not user provided, value is filtered with regex


@router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.")
@router.get(path="/list")
async def list_datasets( # noqa: PLR0913, C901
async def list_datasets( # noqa: PLR0913
pagination: Annotated[Pagination, Body(default_factory=Pagination)],
data_name: Annotated[str | None, CasualString128] = None,
tag: Annotated[str | None, SystemString64] = None,
Expand Down Expand Up @@ -120,7 +103,7 @@ async def list_datasets( # noqa: PLR0913, C901
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
) -> list[dict[str, Any]]:
assert expdb_db is not None # noqa: S101
status_subquery = text(
current_status = text(
"""
SELECT ds1.`did`, ds1.`status`
FROM dataset_status as ds1
Expand All @@ -132,78 +115,90 @@ async def list_datasets( # noqa: PLR0913, C901
""",
)

clauses = []
parameters: dict[str, Any] = {
"offset": pagination.offset,
"limit": pagination.limit,
}
if status != DatasetStatusFilter.ALL:
clauses.append("AND IFNULL(cs.`status`, 'in_preparation') = :status")
parameters["status"] = status
if status == DatasetStatusFilter.ALL:
statuses = [
DatasetStatusFilter.ACTIVE,
DatasetStatusFilter.DEACTIVATED,
DatasetStatusFilter.IN_PREPARATION,
]
else:
statuses = [status]

where_status = ",".join(f"'{status}'" for status in statuses)
if user is None:
clauses.append("AND `visibility`='public'")
elif UserGroup.ADMIN not in await user.get_groups():
clauses.append("AND (`visibility`='public' OR `uploader`=:user_id)")
parameters["user_id"] = user.user_id

if uploader:
clauses.append("AND `uploader`=:uploader")
parameters["uploader"] = uploader

if data_name:
clauses.append("AND `name`=:data_name")
parameters["data_name"] = data_name

if data_version:
clauses.append("AND `version`=:data_version")
parameters["data_version"] = data_version
visible_to_user = "`visibility`='public'"
elif UserGroup.ADMIN in await user.get_groups():
visible_to_user = "TRUE"
else:
visible_to_user = f"(`visibility`='public' OR `uploader`={user.user_id})"

if data_id:
clauses.append("AND d.`did` IN :data_ids")
parameters["data_ids"] = data_id
where_name = "" if data_name is None else "AND `name`=:data_name"
where_version = "" if data_version is None else "AND `version`=:data_version"
where_uploader = "" if uploader is None else "AND `uploader`=:uploader"
data_id_str = ",".join(str(did) for did in data_id) if data_id else ""
where_data_id = "" if not data_id else f"AND d.`did` IN ({data_id_str})"

# requires some benchmarking on whether e.g., IN () is more efficient.
if tag:
clauses.append(
matching_tag = (
text(
"""
AND d.`did` IN (
SELECT `id`
FROM dataset_tag as dt
WHERE dt.`tag`=:tag
)
""",
AND d.`did` IN (
SELECT `id`
FROM dataset_tag as dt
WHERE dt.`tag`=:tag
)
""",
)
parameters["tag"] = tag
if tag
else ""
)

number_instances_filter = _quality_clause("NumberOfInstances", number_instances)
number_classes_filter = _quality_clause("NumberOfClasses", number_classes)
number_features_filter = _quality_clause("NumberOfFeatures", number_features)
number_missing_values_filter = _quality_clause("NumberOfMissingValues", number_missing_values)
def quality_clause(quality: str, range_: str | None) -> str:
    """Build an SQL fragment that filters datasets on one quality's value.

    Args:
        quality: Name of the quality to filter on. It is interpolated into
            the SQL as-is, so it must never come from user input.
        range_: Either a single value or a range, in the format accepted by
            ``integer_range_regex``. A falsy value disables the filter.

    Returns:
        An empty string when ``range_`` is falsy, otherwise an ``AND ...``
        clause restricting ``d.did`` to datasets whose quality value matches.

    Raises:
        ValueError: If ``range_`` does not match ``integer_range_regex``.
    """
    if not range_:
        return ""

    parsed = re.match(integer_range_regex, range_)
    if parsed is None:
        msg = f"`range_` not a valid range: {range_}"
        raise ValueError(msg)

    lower, upper = parsed.groups()
    if upper:
        # The second group carries a two-character prefix (presumably the
        # `..` range separator - confirm against the regex) that is dropped.
        condition = f"`value` BETWEEN {lower} AND {upper[2:]}"
    else:
        condition = f"`value`={lower}"

    return f""" AND
d.`did` IN (
SELECT `data`
FROM data_quality
WHERE `quality`='{quality}' AND {condition}
)
"""  # noqa: S608 - `quality` is not user provided, value is filtered with regex

columns = ["did", "name", "version", "format", "file_id", "status"]
number_instances_filter = quality_clause("NumberOfInstances", number_instances)
number_classes_filter = quality_clause("NumberOfClasses", number_classes)
number_features_filter = quality_clause("NumberOfFeatures", number_features)
number_missing_values_filter = quality_clause("NumberOfMissingValues", number_missing_values)
matching_filter = text(
f"""
SELECT d.`did`,d.`name`,d.`version`,d.`format`,d.`file_id`,
IFNULL(cs.`status`, 'in_preparation')
FROM dataset AS d
LEFT JOIN ({status_subquery}) AS cs ON d.`did`=cs.`did`
WHERE 1=1 {number_instances_filter} {number_features_filter}
LEFT JOIN ({current_status}) AS cs ON d.`did`=cs.`did`
WHERE {visible_to_user} {where_name} {where_version} {where_uploader}
{where_data_id} {matching_tag} {number_instances_filter} {number_features_filter}
{number_classes_filter} {number_missing_values_filter}
{" ".join(clauses)}
LIMIT :limit OFFSET :offset
AND IFNULL(cs.`status`, 'in_preparation') IN ({where_status})
LIMIT {pagination.limit} OFFSET {pagination.offset}
""", # noqa: S608
# I am not sure how to do this correctly without an error from Bandit here.
# However, the `status` input is already checked by FastAPI to be from a set
# of given options, so no injection is possible (I think). The `current_status`
# subquery also has no user input. So I think this should be safe.
)

if data_id:
matching_filter.bindparams(bindparam("data_ids", expanding=True))
columns = ["did", "name", "version", "format", "file_id", "status"]
result = await expdb_db.execute(
matching_filter,
parameters=parameters,
parameters={
"tag": tag,
"data_name": data_name,
"data_version": data_version,
"uploader": uploader,
},
)
rows = result.all()
datasets: dict[int, dict[str, Any]] = {
Expand Down
39 changes: 15 additions & 24 deletions src/routers/openml/qualities.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from http import HTTPStatus
from typing import Annotated, Literal

from fastapi import APIRouter, Depends
Expand All @@ -6,12 +7,7 @@
import database.datasets
import database.qualities
from core.access import _user_has_access
from core.errors import (
DatasetNotFoundError,
DatasetNotProcessedError,
DatasetProcessingError,
NoQualitiesError,
)
from core.errors import DatasetNotFoundError
from database.users import User
from routers.dependencies import expdb_connection, fetch_user
from schemas.datasets.openml import Quality
Expand Down Expand Up @@ -39,24 +35,19 @@ async def get_qualities(
) -> list[Quality]:
dataset = await database.datasets.get(dataset_id, expdb)
if not dataset or not await _user_has_access(dataset, user):
# Backwards compatibility: PHP API returns 412 with code 113
msg = f"Dataset with id {dataset_id} not found."
no_data_file = 113
raise DatasetNotFoundError(
msg,
code=361,
) from None

processing = await database.datasets.get_latest_processing_update(dataset_id, expdb)
if processing is None:
msg = f"Dataset not processed yet for dataset {dataset_id}."
raise DatasetNotProcessedError(msg, code=363)

if processing.error:
msg = processing.error.strip() or "Error occurred during processing."
raise DatasetProcessingError(msg, code=364)

qualities = await database.qualities.get_for_dataset(dataset_id, expdb)
if not qualities:
msg = f"No qualities found for dataset {dataset_id}."
raise NoQualitiesError(msg)

return qualities
code=no_data_file,
status_code=HTTPStatus.PRECONDITION_FAILED,
)
return await database.qualities.get_for_dataset(dataset_id, expdb)
# The PHP API provided (sometimes) helpful error messages
# if not qualities:
# check if dataset exists: error 360
# check if user has access: error 361
# check if there is a data processed entry and forward the error: 364
# if nothing in process table: 363
# otherwise: error 362
46 changes: 32 additions & 14 deletions src/routers/openml/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Annotated, cast

import xmltodict
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import RowMapping, text
from sqlalchemy.ext.asyncio import AsyncConnection

Expand All @@ -18,6 +18,12 @@

type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None

# Tables that a `[LOOKUP:<table>.<column>]` template directive may read from;
# any other table name is rejected before a query is built.
ALLOWED_LOOKUP_TABLES = {
    "dataset",
    "estimation_procedure",
    "evaluation_measure",
    "task_type",
}
# Primary-key column per lookup table; tables not listed here are queried by `id`.
PK_MAPPING = {"task_type": "ttid", "dataset": "did"}


def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]:
json_template = xmltodict.parse(xml_template.replace("oml:", ""))
Expand Down Expand Up @@ -95,7 +101,7 @@ async def fill_template(
)


async def _fill_json_template( # noqa: C901
async def _fill_json_template( # noqa: C901, PLR0912
template: JSON,
task: RowMapping,
task_inputs: dict[str, str | int],
Expand Down Expand Up @@ -123,33 +129,45 @@ async def _fill_json_template( # noqa: C901
if match.string == template:
# How do we know the default value? probably ttype_io table?
return task_inputs.get(field, [])
template = template.replace(match.group(), str(task_inputs[field]))
template = template.replace(match.group(), str(task_inputs.get(field, "")))
if match := re.search(r"\[LOOKUP:(.*)]", template):
(field,) = match.groups()
if field not in fetched_data:
table, _ = field.split(".")
if table not in ALLOWED_LOOKUP_TABLES:
msg = f"Table {table} is not allowed for lookup."
raise HTTPException(status_code=400, detail=msg)
if table not in task_inputs or not task_inputs[table]:
msg = f"Missing or empty input for lookup table: {table}"
raise HTTPException(status_code=400, detail=msg)

try:
id_val = int(task_inputs[table])
except ValueError:
msg = f"Invalid integer id for table {table}: {task_inputs[table]}"
raise HTTPException(status_code=400, detail=msg) from None

pk = PK_MAPPING.get(table, "id")
result = await connection.execute(
text(
f"""
SELECT *
FROM {table}
WHERE `id` = :id_
WHERE `{pk}` = :id_
""", # noqa: S608
),
# Not sure how to parametrize table names, as the parametrization adds
# quotes, which is not legal.
parameters={"id_": int(task_inputs[table])},
parameters={"id_": id_val},
)
rows = result.mappings()
row_data = next(rows, None)
row_data = result.mappings().one_or_none()
if row_data is None:
msg = f"No data found for table {table} with id {task_inputs[table]}"
raise ValueError(msg)
msg = f"No data found for table {table} with id {id_val}"
raise HTTPException(status_code=400, detail=msg)
for column, value in row_data.items():
fetched_data[f"{table}.{column}"] = value
fetched_data[f"{table}.{column}"] = str(value)

if match.string == template:
return fetched_data[field]
template = template.replace(match.group(), fetched_data[field])
return fetched_data.get(field, "")
template = template.replace(match.group(), fetched_data.get(field, ""))
# I believe that the operations below are always part of string output, so
# we don't need to be careful to avoid losing typedness
template = template.replace("[TASK:id]", str(task.task_id))
Expand Down
Loading