Skip to content

Add test files for Find Nearest Neighbors filter#565

Open
ReubenHill wants to merge 6 commits intodevelopfrom
feature/use_nearest_neighbors
Open

Add test files for Find Nearest Neighbors filter#565
ReubenHill wants to merge 6 commits intodevelopfrom
feature/use_nearest_neighbors

Conversation

@ReubenHill
Copy link
Copy Markdown
Contributor

@ReubenHill ReubenHill commented Apr 20, 2026

Description

This adds 3 test files for the find nearest neigobors filter introduced in https://github.com/JCSDA-internal/ufo/pull/4114. Scripts used to generate the test files are in the comments below this description.

Issue(s) addressed

Partially https://github.com/JCSDA-internal/ufo/issues/4063

Dependencies

List the other PRs that this PR is dependent on:

build-group=https://github.com/JCSDA-internal/ufo/pull/4114

Impact

Expected impact on downstream repositories: None

Checklist

  • I have performed a self-review of my own code
  • I have made corresponding changes to the documentation
  • I have run the unit tests before creating the PR

@ReubenHill
Copy link
Copy Markdown
Contributor Author

ReubenHill commented Apr 20, 2026

Script for generating use_nearest_neighbors_obs_reference_point_variable_mean.nc4 and use_nearest_neighbors_obs_gather_and_match.nc4

import netCDF4 as nc
import numpy as np

MISSING_FLOAT = -3.3687952621450501176e38  # jedi's missing float value
MISSING_STR = "MISSING*"  # jedi's missing string value
MISSING_DATETIME = 253281254337  # 23:58:57 on 29 February 9996
MISSING_INT32 = -(2**31) + 5  # jedi's missing int32 value
MISSING_INT64 = -(2**63) + 7  # jedi's missing int64 value


# Generate files for testing the Use Nearest Neighbors UFO filter
def gather_and_match_timestamp_test_files_gen(
    name,
    t,
    t_query,
    t_ref,
    first_nearest_ref_station_id,
    second_nearest_ref_station_id,
    third_nearest_ref_station_id,
    first_nearest_distance,
    second_nearest_distance,
    third_nearest_distance,
    ob_type_num,
    station_id,
    lat,
    lon,
    datetime,
    binned_datetime,
    matched_first_nearest_temperatures,
    matched_second_nearest_temperatures,
    matched_third_nearest_temperatures,
    query_station_id,
    ref_station_id,
):
    assert (
        len(t)
        == len(t_query)
        == len(t_ref)
        == len(first_nearest_ref_station_id)
        == len(second_nearest_ref_station_id)
        == len(third_nearest_ref_station_id)
        == len(first_nearest_distance)
        == len(second_nearest_distance)
        == len(third_nearest_distance)
        == len(ob_type_num)
        == len(station_id)
        == len(lat)
        == len(lon)
        == len(datetime)
        == len(binned_datetime)
        == len(matched_first_nearest_temperatures)
        == len(matched_second_nearest_temperatures)
        == len(matched_third_nearest_temperatures)
        == len(query_station_id)
        == len(ref_station_id)
    )
    nlocs = len(station_id)
    err = np.ones(nlocs)  # arbitrary

    # Write observation file
    file = nc.Dataset(name, "w")
    file._ioda_layout = "ObsGroup"
    file._ioda_layout_version = 0
    file.date_time = "20200101T0000Z"
    file.createDimension("Location", nlocs)

    loc_var = file.createVariable(
        "Location", "f4", ("Location"), fill_value=MISSING_FLOAT
    )
    loc_var[:] = 0
    datetime_var = file.createVariable(
        "MetaData/dateTime", "i8", ("Location"), fill_value=MISSING_DATETIME
    )
    datetime_var.units = "seconds since 1970-01-01T00:00:00Z"
    datetime_var[:] = datetime

    # Variables
    lat_var = file.createVariable(
        "MetaData/latitude", "f4", ("Location"), fill_value=MISSING_FLOAT
    )
    lat_var[:] = lat
    lon_var = file.createVariable(
        "MetaData/longitude", "f4", ("Location"), fill_value=MISSING_FLOAT
    )
    lon_var[:] = lon
    station_id_var = file.createVariable(
        "MetaData/stationIdentification", str, ("Location"), fill_value=MISSING_STR
    )
    query_station_id_var = file.createVariable(
        "MetaData/queryStationIdentification", str, ("Location"), fill_value=MISSING_STR
    )
    query_station_id_var[:] = query_station_id
    ref_station_id_var = file.createVariable(
        "MetaData/referenceStationIdentification",
        str,
        ("Location"),
        fill_value=MISSING_STR,
    )
    ref_station_id_var[:] = ref_station_id
    station_id_var[:] = station_id
    ob_type_num_var = file.createVariable(
        "MetaData/observationTypeNum",
        "i4",
        ("Location"),
        fill_value=MISSING_INT32,
    )
    ob_type_num_var[:] = ob_type_num
    t_var = file.createVariable(
        "ObsValue/airTemperature", "f4", ("Location"), fill_value=MISSING_FLOAT
    )
    t_var[:] = t
    t_err_var = file.createVariable(
        "ObsError/airTemperature", "f4", ("Location"), fill_value=MISSING_FLOAT
    )
    t_err_var[:] = err
    t_query_var = file.createVariable(
        "DerivedObsValue/airTemperatureQuery",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    t_query_var[:] = t_query
    t_ref_var = file.createVariable(
        "DerivedObsValue/airTemperatureReference",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    t_ref_var[:] = t_ref
    first_nearest_ref_station_id_var = file.createVariable(
        "DerivedMetaData/firstNearestReferenceStationID",
        str,
        ("Location"),
        fill_value=MISSING_STR,
    )
    first_nearest_ref_station_id_var[:] = first_nearest_ref_station_id
    second_nearest_ref_station_id_var = file.createVariable(
        "DerivedMetaData/secondNearestReferenceStationID",
        str,
        ("Location"),
        fill_value=MISSING_STR,
    )
    second_nearest_ref_station_id_var[:] = second_nearest_ref_station_id
    third_nearest_ref_station_id_var = file.createVariable(
        "DerivedMetaData/thirdNearestReferenceStationID",
        str,
        ("Location"),
        fill_value=MISSING_STR,
    )
    third_nearest_ref_station_id_var[:] = third_nearest_ref_station_id
    first_nearest_distance_var = file.createVariable(
        "TestReference/firstNearestDistance",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    first_nearest_distance_var[:] = first_nearest_distance
    second_nearest_distance_var = file.createVariable(
        "TestReference/secondNearestDistance",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    second_nearest_distance_var[:] = second_nearest_distance
    third_nearest_distance_var = file.createVariable(
        "TestReference/thirdNearestDistance",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    third_nearest_distance_var[:] = third_nearest_distance
    binned_datetime_var = file.createVariable(
        "DerivedMetaData/binnedDateTime",
        "i8",
        ("Location"),
        fill_value=MISSING_DATETIME,
    )
    binned_datetime_var.units = "seconds since 1970-01-01T00:00:00Z"
    # don't set the units - when superobbing datetime they come out like this.
    # Also panoply will use whichever datetime last had units set as the datetime
    # to display, so if the units are set here it will try to display the
    # binned datetime as a date which is not helpful.
    # binned_datetime_var.units = "seconds since 1970-01-01T00:00:00Z"
    binned_datetime_var[:] = binned_datetime
    matched_first_nearest_temperatures_var = file.createVariable(
        "TestReference/matchedFirstNearestTemperature",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    matched_first_nearest_temperatures_var[:] = matched_first_nearest_temperatures
    matched_second_nearest_temperatures_var = file.createVariable(
        "TestReference/matchedSecondNearestTemperature",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    matched_second_nearest_temperatures_var[:] = matched_second_nearest_temperatures
    matched_third_nearest_temperatures_var = file.createVariable(
        "TestReference/matchedThirdNearestTemperature",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    matched_third_nearest_temperatures_var[:] = matched_third_nearest_temperatures
    missing_floats_var = file.createVariable(
        "TestReference/missingFloats",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    missing_floats_var[:] = np.asarray([MISSING_FLOAT] * nlocs)
    file.close()


def reference_point_variable_mean_test_files_gen(
    name,
    t,
    t_query,
    t_ref,
    first_nearest_ref_station_id,
    second_nearest_ref_station_id,
    third_nearest_ref_station_id,
    ob_type_num,
    station_id,
    lat,
    lon,
    datetime_query,
    datetime_ref,
    datetime_averaging_bin,
    average_three_closest_ref_temps_most_recent,
    average_three_closest_ref_temps_previous,
    average_three_closest_ref_datetimes_most_recent,
    average_three_closest_ref_datetimes_previous,
    average_two_closest_ref_temps_most_recent,
    average_two_closest_ref_temps_previous,
    average_two_closest_ref_datetimes_most_recent,
    average_two_closest_ref_datetimes_previous,
    closest_ref_temps_most_recent,
    closest_ref_temps_previous,
    closest_ref_datetimes_most_recent,
    closest_ref_datetimes_previous,
    query_station_id,
    ref_station_id,
):
    assert (
        len(t)
        == len(t_query)
        == len(t_ref)
        == len(first_nearest_ref_station_id)
        == len(second_nearest_ref_station_id)
        == len(third_nearest_ref_station_id)
        == len(ob_type_num)
        == len(station_id)
        == len(lat)
        == len(lon)
        == len(datetime_query)
        == len(datetime_ref)
        == len(datetime_averaging_bin)
        == len(average_three_closest_ref_temps_most_recent)
        == len(average_three_closest_ref_temps_previous)
        == len(average_three_closest_ref_datetimes_most_recent)
        == len(average_three_closest_ref_datetimes_previous)
        == len(average_two_closest_ref_temps_most_recent)
        == len(average_two_closest_ref_temps_previous)
        == len(average_two_closest_ref_datetimes_most_recent)
        == len(average_two_closest_ref_datetimes_previous)
        == len(closest_ref_temps_most_recent)
        == len(closest_ref_temps_previous)
        == len(closest_ref_datetimes_most_recent)
        == len(closest_ref_datetimes_previous)
        == len(query_station_id)
        == len(ref_station_id)
    )
    nlocs = len(station_id)
    err = np.ones(nlocs)  # arbitrary

    # Write observation file
    file = nc.Dataset(name, "w")
    file._ioda_layout = "ObsGroup"
    file._ioda_layout_version = 0
    file.date_time = "20200101T0000Z"
    file.createDimension("Location", nlocs)

    loc_var = file.createVariable(
        "Location", "f4", ("Location"), fill_value=MISSING_FLOAT
    )
    loc_var[:] = 0
    datetime_var = file.createVariable(
        "MetaData/dateTime", "i8", ("Location"), fill_value=MISSING_DATETIME
    )
    datetime_var.units = "seconds since 1970-01-01T00:00:00Z"
    datetime_var[:] = datetime

    # Variables
    lat_var = file.createVariable(
        "MetaData/latitude", "f4", ("Location"), fill_value=MISSING_FLOAT
    )
    lat_var[:] = lat
    lon_var = file.createVariable(
        "MetaData/longitude", "f4", ("Location"), fill_value=MISSING_FLOAT
    )
    lon_var[:] = lon
    station_id_var = file.createVariable(
        "MetaData/stationIdentification", str, ("Location"), fill_value=MISSING_STR
    )
    station_id_var[:] = station_id
    query_station_id_var = file.createVariable(
        "MetaData/queryStationIdentification", str, ("Location"), fill_value=MISSING_STR
    )
    query_station_id_var[:] = query_station_id
    ref_station_id_var = file.createVariable(
        "MetaData/referenceStationIdentification",
        str,
        ("Location"),
        fill_value=MISSING_STR,
    )
    ref_station_id_var[:] = ref_station_id
    ob_type_num_var = file.createVariable(
        "MetaData/observationTypeNum",
        "i4",
        ("Location"),
        fill_value=MISSING_INT32,
    )
    ob_type_num_var[:] = ob_type_num
    t_var = file.createVariable(
        "ObsValue/airTemperature", "f4", ("Location"), fill_value=MISSING_FLOAT
    )
    t_var[:] = t
    t_err_var = file.createVariable(
        "ObsError/airTemperature", "f4", ("Location"), fill_value=MISSING_FLOAT
    )
    t_err_var[:] = err
    t_query_var = file.createVariable(
        "DerivedObsValue/airTemperatureQuery",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    t_query_var[:] = t_query
    t_ref_var = file.createVariable(
        "DerivedObsValue/airTemperatureReference",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    t_ref_var[:] = t_ref
    first_nearest_ref_station_id_var = file.createVariable(
        "DerivedMetaData/firstNearestReferenceStationID",
        str,
        ("Location"),
        fill_value=MISSING_STR,
    )
    first_nearest_ref_station_id_var[:] = first_nearest_ref_station_id
    second_nearest_ref_station_id_var = file.createVariable(
        "DerivedMetaData/secondNearestReferenceStationID",
        str,
        ("Location"),
        fill_value=MISSING_STR,
    )
    second_nearest_ref_station_id_var[:] = second_nearest_ref_station_id
    third_nearest_ref_station_id_var = file.createVariable(
        "DerivedMetaData/thirdNearestReferenceStationID",
        str,
        ("Location"),
        fill_value=MISSING_STR,
    )
    third_nearest_ref_station_id_var[:] = third_nearest_ref_station_id
    datetime_query_var = file.createVariable(
        "DerivedMetaData/dateTimeQuery",
        "i8",
        ("Location"),
        fill_value=MISSING_DATETIME,
    )
    datetime_query_var.units = "seconds since 1970-01-01T00:00:00Z"
    datetime_query_var[:] = datetime_query
    datetime_ref_var = file.createVariable(
        "DerivedMetaData/dateTimeReference",
        "i8",
        ("Location"),
        fill_value=MISSING_DATETIME,
    )
    datetime_ref_var.units = "seconds since 1970-01-01T00:00:00Z"
    datetime_ref_var[:] = datetime_ref
    datetime_averaging_bin_var = file.createVariable(
        "DerivedMetaData/dateTimeAveragingBin",
        "i8",
        ("Location"),
        fill_value=MISSING_INT64,
    )
    datetime_averaging_bin_var[:] = datetime_averaging_bin
    average_three_closest_ref_temps_most_recent_var = file.createVariable(
        "TestReference/averageThreeClosestReferenceAirTemperaturesMostRecent",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    average_three_closest_ref_temps_most_recent_var[:] = (
        average_three_closest_ref_temps_most_recent
    )
    average_three_closest_ref_temps_previous_var = file.createVariable(
        "TestReference/averageThreeClosestReferenceAirTemperaturesPrevious",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    average_three_closest_ref_temps_previous_var[:] = (
        average_three_closest_ref_temps_previous
    )
    average_three_closest_ref_datetimes_most_recent_var = file.createVariable(
        "TestReference/averageThreeClosestReferenceDateTimesMostRecent",
        "i8",
        ("Location"),
        fill_value=MISSING_DATETIME,
    )
    average_three_closest_ref_datetimes_most_recent_var.units = (
        "seconds since 1970-01-01T00:00:00Z"
    )
    average_three_closest_ref_datetimes_most_recent_var[:] = (
        average_three_closest_ref_datetimes_most_recent
    )
    average_three_closest_ref_datetimes_previous_var = file.createVariable(
        "TestReference/averageThreeClosestReferenceDateTimesPrevious",
        "i8",
        ("Location"),
        fill_value=MISSING_DATETIME,
    )
    average_three_closest_ref_datetimes_previous_var.units = (
        "seconds since 1970-01-01T00:00:00Z"
    )
    average_three_closest_ref_datetimes_previous_var[:] = (
        average_three_closest_ref_datetimes_previous
    )
    average_two_closest_ref_temps_most_recent_var = file.createVariable(
        "TestReference/averageTwoClosestReferenceAirTemperaturesMostRecent",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    average_two_closest_ref_temps_most_recent_var[:] = (
        average_two_closest_ref_temps_most_recent
    )
    average_two_closest_ref_temps_previous_var = file.createVariable(
        "TestReference/averageTwoClosestReferenceAirTemperaturesPrevious",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    average_two_closest_ref_temps_previous_var[:] = (
        average_two_closest_ref_temps_previous
    )
    average_two_closest_ref_datetimes_most_recent_var = file.createVariable(
        "TestReference/averageTwoClosestReferenceDateTimesMostRecent",
        "i8",
        ("Location"),
        fill_value=MISSING_DATETIME,
    )
    average_two_closest_ref_datetimes_most_recent_var.units = (
        "seconds since 1970-01-01T00:00:00Z"
    )
    average_two_closest_ref_datetimes_most_recent_var[:] = (
        average_two_closest_ref_datetimes_most_recent
    )
    average_two_closest_ref_datetimes_previous_var = file.createVariable(
        "TestReference/averageTwoClosestReferenceDateTimesPrevious",
        "i8",
        ("Location"),
        fill_value=MISSING_DATETIME,
    )
    average_two_closest_ref_datetimes_previous_var.units = (
        "seconds since 1970-01-01T00:00:00Z"
    )
    average_two_closest_ref_datetimes_previous_var[:] = (
        average_two_closest_ref_datetimes_previous
    )
    closest_ref_temps_most_recent_var = file.createVariable(
        "TestReference/closestReferenceAirTemperaturesMostRecent",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    closest_ref_temps_most_recent_var[:] = closest_ref_temps_most_recent
    closest_ref_temps_previous_var = file.createVariable(
        "TestReference/closestReferenceAirTemperaturesPrevious",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    closest_ref_temps_previous_var[:] = closest_ref_temps_previous
    closest_ref_datetimes_most_recent_var = file.createVariable(
        "TestReference/closestReferenceDateTimesMostRecent",
        "i8",
        ("Location"),
        fill_value=MISSING_DATETIME,
    )
    closest_ref_datetimes_most_recent_var.units = "seconds since 1970-01-01T00:00:00Z"
    closest_ref_datetimes_most_recent_var[:] = closest_ref_datetimes_most_recent
    closest_ref_datetimes_previous_var = file.createVariable(
        "TestReference/closestReferenceDateTimesPrevious",
        "i8",
        ("Location"),
        fill_value=MISSING_DATETIME,
    )
    closest_ref_datetimes_previous_var.units = "seconds since 1970-01-01T00:00:00Z"
    closest_ref_datetimes_previous_var[:] = closest_ref_datetimes_previous
    missing_floats_var = file.createVariable(
        "TestReference/missingFloats",
        "f4",
        ("Location"),
        fill_value=MISSING_FLOAT,
    )
    missing_floats_var[:] = np.asarray([MISSING_FLOAT] * nlocs)
    file.close()


def haversine(lat1, lon1, lat2, lon2):
    """
    Calculate the great circle distance between two points
    on the earth (specified in decimal degrees) - adapted from
    https://stackoverflow.com/questions/29545704/fast-haversine-approximation-python-pandas/29546836#29546836
    """
    lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
    # haversine formula to get central angle between two points in radians
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat / 2.0) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2.0) ** 2
    c = 2.0 * np.arcsin(np.sqrt(a))
    # convert to distance in kilometers
    r = 6371.0  # Mean radius of earth
    return c * r


if __name__ == "__main__":

    # 3 temperature obs per station, 6 stations (18 obs in total), split over 2
    # obs type numbers. Station 4 is closest to station 1, station 5 to station 2,
    # station 6 to station 3.
    num_obs = 18
    station_1_loc = (40.0, -1.0)
    station_2_loc = (41.0, 0.0)
    station_3_loc = (42.0, 1.0)
    station_4_loc = (40.1, -0.9)  # closest to station 1
    station_5_loc = (41.1, 0.1)  # closest to station 2
    station_6_loc = (42.1, 1.1)  # closest to station 3
    station_1_id = "station1"
    station_2_id = "station2"
    station_3_id = "station3"
    station_4_id = "station4"
    station_5_id = "station5"
    station_6_id = "station6"
    station_1_station_4_dist = haversine(
        station_1_loc[0], station_1_loc[1], station_4_loc[0], station_4_loc[1]
    )
    station_1_station_5_dist = haversine(
        station_1_loc[0], station_1_loc[1], station_5_loc[0], station_5_loc[1]
    )
    station_1_station_6_dist = haversine(
        station_1_loc[0], station_1_loc[1], station_6_loc[0], station_6_loc[1]
    )
    station_2_station_4_dist = haversine(
        station_2_loc[0], station_2_loc[1], station_4_loc[0], station_4_loc[1]
    )
    station_2_station_5_dist = haversine(
        station_2_loc[0], station_2_loc[1], station_5_loc[0], station_5_loc[1]
    )
    station_2_station_6_dist = haversine(
        station_2_loc[0], station_2_loc[1], station_6_loc[0], station_6_loc[1]
    )
    station_3_station_4_dist = haversine(
        station_3_loc[0], station_3_loc[1], station_4_loc[0], station_4_loc[1]
    )
    station_3_station_5_dist = haversine(
        station_3_loc[0], station_3_loc[1], station_5_loc[0], station_5_loc[1]
    )
    station_3_station_6_dist = haversine(
        station_3_loc[0], station_3_loc[1], station_6_loc[0], station_6_loc[1]
    )
    ob_type_num = np.asarray([1] * 9 + [2] * 9)
    station_id = np.asarray(
        [station_1_id] * 3
        + [station_2_id] * 3
        + [station_3_id] * 3
        + [station_4_id] * 3
        + [station_5_id] * 3
        + [station_6_id] * 3
    )
    lat = np.asarray(
        [station_1_loc[0]] * 3
        + [station_2_loc[0]] * 3
        + [station_3_loc[0]] * 3
        + [station_4_loc[0]] * 3
        + [station_5_loc[0]] * 3
        + [station_6_loc[0]] * 3
    )
    # Lon crosses the meridian to make sure distance calculations are correct
    lon = np.asarray(
        [station_1_loc[1]] * 3
        + [station_2_loc[1]] * 3
        + [station_3_loc[1]] * 3
        + [station_4_loc[1]] * 3
        + [station_5_loc[1]] * 3
        + [station_6_loc[1]] * 3
    )
    # Stations with ob type number 1 get measurements every 5 minutes from
    # 1am to 1:10am (3 measurements), stations with ob type number 2 get
    # measurements on the hour from 1am to 3am (3 measurements). All are taken
    # on 2018-04-14 and are in units of seconds since 1970-01-01T00:00:00Z
    datetime_station_1 = np.asarray([1523667600, 1523667900, 1523668200])
    datetime_station_2 = np.asarray([1523667600, 1523667900, 1523668200])
    datetime_station_3 = np.asarray([1523667600, 1523667900, 1523668200])
    datetime_station_4 = np.asarray([1523667600, 1523671200, 1523674800])
    datetime_station_5 = np.asarray([1523667600, 1523671200, 1523674800])
    datetime_station_6 = np.asarray([1523667600, 1523671200, 1523674800])
    datetime = np.concatenate(
        (
            datetime_station_1,
            datetime_station_2,
            datetime_station_3,
            datetime_station_4,
            datetime_station_5,
            datetime_station_6,
        )
    )
    t_station_1 = np.asarray([280.0, 281.0, 282.0])  # ob type number 1
    t_station_2 = np.asarray([290.0, 291.0, 292.0])  # ob type number 1

    t_station_3 = np.asarray([300.0, 301.0, 302.0])  # ob type number 1
    t_station_4 = np.asarray([310.0, 311.0, 312.0])  # ob type number 2
    t_station_5 = np.asarray([320.0, 321.0, 322.0])  # ob type number 2
    t_station_6 = np.asarray(
        [MISSING_FLOAT, 331.0, 332.0]
    )  # ob type number 2, one missing value
    t = np.concatenate(
        (t_station_1, t_station_2, t_station_3, t_station_4, t_station_5, t_station_6)
    )
    # The stations are split into two observation type numbers, causing the
    # vector of t to be split into two vectors of 18 elements, where the values
    # are missing for the other observation type number. These become the query
    # and reference points for the nearest neighbor search.
    t_query = np.concatenate((t[0:9], np.asarray([MISSING_FLOAT] * 9)))
    t_ref = np.concatenate((np.asarray([MISSING_FLOAT] * 9), t[9:18]))
    query_station_id = np.concatenate(
        (station_id[0:9], np.asarray([MISSING_FLOAT] * 9))
    )
    ref_station_id = np.concatenate((np.asarray([MISSING_FLOAT] * 9), station_id[9:18]))
    # When the the nearest neighbors are found, the lats and lons for each non-
    # missing value in t_query is compared to all the lats and lons for each
    # non-missing value in t_ref.
    # first nearest neighbors should be:
    # station 1 -> station 4
    # station 2 -> station 5
    # station 3 -> station 6
    # second nearest neighbors should be:
    # station 1 -> station 5
    # station 2 -> station 4
    # station 3 -> station 5
    # third nearest neighbors should be:
    # station 1 -> station 6
    # station 2 -> station 6
    # station 3 -> station 4
    # If station ID is attached to the reference points, then the expected
    # output vectors are
    first_nearest_ref_station_id = np.asarray(
        [station_4_id] * 3 + [station_5_id] * 3 + [station_6_id] * 3 + [MISSING_STR] * 9
    )
    second_nearest_ref_station_id = np.asarray(
        [station_5_id] * 3 + [station_4_id] * 3 + [station_5_id] * 3 + [MISSING_STR] * 9
    )
    third_nearest_ref_station_id = np.asarray(
        [station_6_id] * 3 + [station_6_id] * 3 + [station_4_id] * 3 + [MISSING_STR] * 9
    )
    # The nearest neighbor distances in kilometers should be
    first_nearest_dists = np.asarray(
        [station_1_station_4_dist] * 3
        + [station_2_station_5_dist] * 3
        + [station_3_station_6_dist] * 3
        + [MISSING_FLOAT] * 9
    )
    second_nearest_dists = np.asarray(
        [station_1_station_5_dist] * 3
        + [station_2_station_4_dist] * 3
        + [station_3_station_5_dist] * 3
        + [MISSING_FLOAT] * 9
    )
    third_nearest_dists = np.asarray(
        [station_1_station_6_dist] * 3
        + [station_2_station_6_dist] * 3
        + [station_3_station_4_dist] * 3
        + [MISSING_FLOAT] * 9
    )
    # Create a timestamp which the "gather and match timestamp" algorithm
    # can use to match to - these exist for all stations and both obs types such
    # that the existing datetimes are binned with T-30M <= t <= T where T is the
    # hour that is assigned. If a datetime is inside these bounds it is set to
    # the corresponding hour. If it is outside of then it is set to missing.
    binned_datetime = np.asarray(
        [1523667600, MISSING_DATETIME, MISSING_DATETIME] * 3  # 1 am, others missing
        + [1523667600, 1523671200, 1523674800] * 3  # 1 am, 2 am, 3 am
    )  # todo: do this dynamically based on datetime
    # The algorithm matches the identified nearest neighbor station IDs to
    # this timestamp and returns the temperature at that time. The expected
    # output is therefore
    matched_first_nearest_temperatures = np.asarray(
        [
            t_station_4[0],
            MISSING_FLOAT,
            MISSING_FLOAT,
        ]  # Matches for station 1 - the measurement taken at 1 am
        + [
            t_station_5[0],
            MISSING_FLOAT,
            MISSING_FLOAT,
        ]  # Matches for station 2 - the measurement taken at 1 am
        + [
            t_station_6[0],
            MISSING_FLOAT,
            MISSING_FLOAT,
        ]  # Matches for station 3 - the measurement taken at 1 am
        + [MISSING_FLOAT]
        * 9  # No matches for stations 4, 5, 6 as they are reference points
    )
    matched_second_nearest_temperatures = np.asarray(
        [t_station_5[0], MISSING_FLOAT, MISSING_FLOAT]
        + [t_station_4[0], MISSING_FLOAT, MISSING_FLOAT]
        + [t_station_5[0], MISSING_FLOAT, MISSING_FLOAT]
        + [MISSING_FLOAT] * 9
    )
    matched_third_nearest_temperatures = np.asarray(
        [t_station_6[0], MISSING_FLOAT, MISSING_FLOAT]
        + [t_station_6[0], MISSING_FLOAT, MISSING_FLOAT]
        + [t_station_4[0], MISSING_FLOAT, MISSING_FLOAT]
        + [MISSING_FLOAT] * 9
    )

    # permute the obs types to make sure that the code does not rely on the
    # ordering of the observation type numbers
    permutation = np.random.permutation(num_obs)
    # for testing, can use no permutation with permutation = np.arange(num_obs)
    gather_and_match_timestamp_test_files_gen(
        "use_nearest_neighbors_obs_gather_and_match.nc4",
        t[permutation],
        t_query[permutation],
        t_ref[permutation],
        first_nearest_ref_station_id[permutation],
        second_nearest_ref_station_id[permutation],
        third_nearest_ref_station_id[permutation],
        first_nearest_dists[permutation],
        second_nearest_dists[permutation],
        third_nearest_dists[permutation],
        ob_type_num[permutation],
        station_id[permutation],
        lat[permutation],
        lon[permutation],
        datetime[permutation],
        binned_datetime[permutation],
        matched_first_nearest_temperatures[permutation],
        matched_second_nearest_temperatures[permutation],
        matched_third_nearest_temperatures[permutation],
        query_station_id[permutation],
        ref_station_id[permutation],
    )

    # for the "reference point variables mean" algorithm, we need to define
    # averaging bins and the expected output of the averaging. The averaging
    # bins are integers 0, 1, 2, 3 etc. These are assigned to the datetimes and
    # are assigned by binning from 30 minutes before the hour (inclusive) to
    # 29M59S after the hour (inclusive). The bins are in reverse chronological
    # order such that bin 0 is the most recent hour, bin 1 the hour before that
    # etc.
    datetime_averaging_bin = np.asarray(
        [2] * 9  # ob type number 1, all in bin 2 (00:30 to 01:29:59)
        + [2, 1, 0]
        * 3  # ob type number 2, in bins 2 (00:30 to 01:29:59), 1 (01:30 to 02:29:59), 0 (02:30 to 03:29:59)
    )  # todo: do this dynamically based on datetime

    # The "gather variable" is the

    # For each query point, the expected output is the mean of the three closest
    # reference point measurements at each averaging bin.
    average_three_closest_ref_temps_most_recent = np.asarray(
        [np.mean([t_station_4[2], t_station_5[2], t_station_6[2]])]
        * 3  # station 1, bin 0
        + [np.mean([t_station_5[2], t_station_4[2], t_station_6[2]])]
        * 3  # station 2, bin 0
        + [np.mean([t_station_6[2], t_station_5[2], t_station_4[2]])]
        * 3  # station 3, bin 0
        + [MISSING_FLOAT] * 9  # stations 4, 5, 6 are reference points
    )
    average_three_closest_ref_temps_previous = np.asarray(
        [np.mean([t_station_4[1], t_station_5[1], t_station_6[1]])]
        * 3  # station 1, bin 1
        + [np.mean([t_station_5[1], t_station_4[1], t_station_6[1]])]
        * 3  # station 2, bin 1
        + [np.mean([t_station_6[1], t_station_5[1], t_station_4[1]])]
        * 3  # station 3, bin 1
        + [MISSING_FLOAT] * 9  # stations 4, 5, 6 are reference points
    )
    # Separate the timestamps by observation type number
    datetime_query = np.asarray([MISSING_DATETIME] * num_obs)
    datetime_query[ob_type_num == 1] = datetime[ob_type_num == 1]
    datetime_ref = np.asarray([MISSING_DATETIME] * num_obs)
    datetime_ref[ob_type_num == 2] = datetime[ob_type_num == 2]
    # For each query point, we also get the average of the reference point
    # datetimes about the averaging bin
    average_three_closest_ref_datetimes_most_recent = np.asarray(
        [np.mean([datetime_station_4[2], datetime_station_5[2], datetime_station_6[2]])]
        * 3  # station 1, bin 0
        + [
            np.mean(
                [datetime_station_5[2], datetime_station_4[2], datetime_station_6[2]]
            )
        ]
        * 3  # station 2, bin 0
        + [
            np.mean(
                [datetime_station_6[2], datetime_station_5[2], datetime_station_4[2]]
            )
        ]
        * 3  # station 3, bin 0
        + [MISSING_DATETIME] * 9  # stations 4, 5, 6 are reference points
    )
    average_three_closest_ref_datetimes_previous = np.asarray(
        [np.mean([datetime_station_4[1], datetime_station_5[1], datetime_station_6[1]])]
        * 3  # station 1, bin 1
        + [
            np.mean(
                [datetime_station_5[1], datetime_station_4[1], datetime_station_6[1]]
            )
        ]
        * 3  # station 2, bin 1
        + [
            np.mean(
                [datetime_station_6[1], datetime_station_5[1], datetime_station_4[1]]
            )
        ]
        * 3  # station 3, bin 1
        + [MISSING_DATETIME] * 9  # stations 4, 5, 6 are reference points
    )
    # do the same for the 2 closest reference points
    average_two_closest_ref_temps_most_recent = np.asarray(
        [np.mean([t_station_4[2], t_station_5[2]])] * 3  # station 1, bin 0
        + [np.mean([t_station_5[2], t_station_4[2]])] * 3  # station 2, bin 0
        + [np.mean([t_station_6[2], t_station_5[2]])] * 3  # station 3, bin 0
        + [MISSING_FLOAT] * 9  # stations 4, 5, 6 are reference points
    )
    average_two_closest_ref_temps_previous = np.asarray(
        [np.mean([t_station_4[1], t_station_5[1]])] * 3  # station 1, bin 1
        + [np.mean([t_station_5[1], t_station_4[1]])] * 3  # station 2, bin 1
        + [np.mean([t_station_6[1], t_station_5[1]])] * 3  # station 3, bin 1
        + [MISSING_FLOAT] * 9  # stations 4, 5, 6 are reference points
    )
    average_two_closest_ref_datetimes_most_recent = np.asarray(
        [np.mean([datetime_station_4[2], datetime_station_5[2]])]
        * 3  # station 1, bin 0
        + [np.mean([datetime_station_5[2], datetime_station_4[2]])]
        * 3  # station 2, bin 0
        + [np.mean([datetime_station_6[2], datetime_station_5[2]])]
        * 3  # station 3, bin 0
        + [MISSING_DATETIME] * 9  # stations 4, 5, 6 are reference points
    )
    average_two_closest_ref_datetimes_previous = np.asarray(
        [np.mean([datetime_station_4[1], datetime_station_5[1]])]
        * 3  # station 1, bin 1
        + [np.mean([datetime_station_5[1], datetime_station_4[1]])]
        * 3  # station 2, bin 1
        + [np.mean([datetime_station_6[1], datetime_station_5[1]])]
        * 3  # station 3, bin 1
        + [MISSING_DATETIME] * 9  # stations 4, 5, 6 are reference points
    )
    # and for the closest reference point
    closest_ref_temps_most_recent = np.asarray(
        [t_station_4[2]] * 3  # station 1, bin 0
        + [t_station_5[2]] * 3  # station 2, bin 0
        + [t_station_6[2]] * 3  # station 3, bin 0
        + [MISSING_FLOAT] * 9  # stations 4, 5, 6 are reference points
    )
    closest_ref_temps_previous = np.asarray(
        [t_station_4[1]] * 3  # station 1, bin 1
        + [t_station_5[1]] * 3  # station 2, bin 1
        + [t_station_6[1]] * 3  # station 3, bin 1
        + [MISSING_FLOAT] * 9  # stations 4, 5, 6 are reference points
    )
    closest_ref_datetimes_most_recent = np.asarray(
        [datetime_station_4[2]] * 3  # station 1, bin 0
        + [datetime_station_5[2]] * 3  # station 2, bin 0
        + [datetime_station_6[2]] * 3  # station 3, bin 0
        + [MISSING_DATETIME] * 9  # stations 4, 5, 6 are reference points
    )
    closest_ref_datetimes_previous = np.asarray(
        [datetime_station_4[1]] * 3  # station 1, bin 1
        + [datetime_station_5[1]] * 3  # station 2, bin 1
        + [datetime_station_6[1]] * 3  # station 3, bin 1
        + [MISSING_DATETIME] * 9  # stations 4, 5, 6 are reference points
    )

    reference_point_variable_mean_test_files_gen(
        "use_nearest_neighbors_obs_reference_point_variable_mean.nc4",
        t[permutation],
        t_query[permutation],
        t_ref[permutation],
        first_nearest_ref_station_id[permutation],
        second_nearest_ref_station_id[permutation],
        third_nearest_ref_station_id[permutation],
        ob_type_num[permutation],
        station_id[permutation],
        lat[permutation],
        lon[permutation],
        datetime_query[permutation],
        datetime_ref[permutation],
        datetime_averaging_bin[permutation],
        average_three_closest_ref_temps_most_recent[permutation],
        average_three_closest_ref_temps_previous[permutation],
        average_three_closest_ref_datetimes_most_recent[permutation],
        average_three_closest_ref_datetimes_previous[permutation],
        average_two_closest_ref_temps_most_recent[permutation],
        average_two_closest_ref_temps_previous[permutation],
        average_two_closest_ref_datetimes_most_recent[permutation],
        average_two_closest_ref_datetimes_previous[permutation],
        closest_ref_temps_most_recent[permutation],
        closest_ref_temps_previous[permutation],
        closest_ref_datetimes_most_recent[permutation],
        closest_ref_datetimes_previous[permutation],
        query_station_id[permutation],
        ref_station_id[permutation],
    )

@ReubenHill
Copy link
Copy Markdown
Contributor Author

ReubenHill commented Apr 20, 2026

Script for generating use_nearest_neighbors_obs_interpolate_with_local_plane_fit.nc4

#!/usr/bin/env python3
"""
Generate test file for UseNearestNeighbors local plane fit algorithm.

This script generates test data for the "local plane fit" algorithm that:
- Uses up to 4 nearest neighbors
- Tests 5 scenarios: 4-neighbor plane fit, 3-neighbor exact fit, 4-neighbor forced IDW,
  2-neighbor IDW fallback, and 1-neighbor IDW
- Includes missing values and a colocated query/reference point pair
"""

import netCDF4 as nc
import numpy as np
from scipy.linalg import ldl
from datetime import datetime as dt
import pytz as tz

MISSING_FLOAT = -3.3687952621450501176e38  # JEDI's missing float value
MISSING_STR = "MISSING*"  # JEDI's missing string value
MISSING_DATETIME = 253281254337  # 23:58:57 on 29 February 9996
MISSING_INT32 = -(2**31) + 5  # JEDI's missing int32 value
MISSING_INT64 = -(2**63) + 7  # JEDI's missing int64 value


def haversine(lat1, lon1, lat2, lon2):
    """
    Calculate the great circle distance between two points
    on the earth (specified in decimal degrees).

    Returns distance in kilometers.
    """
    lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
    # Haversine formula to get central angle between two points in radians
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat / 2.0) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2.0) ** 2
    c = 2.0 * np.arcsin(np.sqrt(a))
    # Convert to distance in kilometers
    r = 6371.0087714  # Earth radius matching C++ code
    return c * r


def griddata_knn(
    points,
    values,
    xi,
    k=6,
    fill_value=np.nan,
    power=2,
    relative_error_threshold=0.25,
):
    """
    Interpolate using k nearest neighbors with local plane fitting and inverse
    distance weighting (brute force haversine distances).

    This function performs linear interpolation using a weighted plane
    fitted to the k nearest neighbors of each query point using inverse
    distance weighting (least squares fit for k >= 3). By default,
    extrapolation is allowed.

    Parameters
    ----------
    points : tuple of ndarray
        (lon_coords, lat_coords) of data points in degrees
    values : ndarray
        Values at the data points
    xi : tuple of float or ndarray
        (lon, lat) coordinates where to interpolate in degrees
    k : int, optional
        Number of nearest neighbors to use (default: 6)
    fill_value : float, optional
        Value to use for points with missing data (default: nan)
    power : float, optional
        Power parameter for inverse distance weighting (default: 2)
        Weights are computed as 1/(distance^power)
    relative_error_threshold : float, optional
        Relative error threshold for least squares fit (default: 0.25)
        If the relative error of the fitted plane exceeds this value,
        interpolation falls back to inverse distance weighting

    Returns
    -------
    float or ndarray
        Interpolated values at xi, or fill_value for points with missing data
    """
    # Convert inputs to arrays
    lon_data, lat_data = points
    lon_data = np.asarray(lon_data)
    lat_data = np.asarray(lat_data)
    values = np.asarray(values)

    # Raise if any two input points share the same (lat, lon) coordinates,
    # consistent with scipy's griddata which does not handle duplicates.
    coords = list(zip(lat_data, lon_data))
    if len(coords) != len(set(coords)):
        raise ValueError(
            "Duplicate coordinates detected in input points. "
            "griddata_knn does not support duplicate coordinates, "
            "consistent with scipy.interpolate.griddata."
        )

    # Handle single point query
    lon_query, lat_query = xi
    is_single = np.isscalar(lon_query)
    if is_single:
        lon_query = np.array([lon_query])
        lat_query = np.array([lat_query])
    else:
        lon_query = np.asarray(lon_query)
        lat_query = np.asarray(lat_query)

    # Perform plane fitting for each query point
    result = np.full(len(lon_query), fill_value)

    for i in range(len(lon_query)):

        # Brute force: compute distances to all reference points
        distances = np.array(
            [
                haversine(lat_query[i], lon_query[i], lat_data[j], lon_data[j])
                for j in range(len(lat_data))
            ]
        )

        # Get k nearest neighbors
        idx = np.argsort(distances)[:k]
        dist = distances[idx]

        # Get the k nearest neighbor coordinates and values
        lon_neighbors = lon_data[idx]
        lat_neighbors = lat_data[idx]
        z_neighbors = values[idx]

        # Query point
        lon_q = lon_query[i]
        lat_q = lat_query[i]

        # Skip if data are missing
        if (
            np.any(z_neighbors == MISSING_FLOAT)
            or np.any(lon_neighbors == MISSING_FLOAT)
            or np.any(lat_neighbors == MISSING_FLOAT)
            or lon_q == MISSING_FLOAT
            or lat_q == MISSING_FLOAT
            or lat_q == MISSING_FLOAT
        ):
            continue

        # If query point coincides with a data point, use that value
        min_dist = np.min(dist)
        if min_dist < 1e-10:
            result[i] = z_neighbors[np.argmin(dist)]
            continue

        # Compute inverse distance weights using haversine distances
        weights = 1.0 / (dist**power + 1e-10)
        weights /= weights.sum()

        # For k >= 3, try plane fitting
        if k >= 3:
            # Center the local plane at the query point and use
            # equirectangular approximation for local distances
            # x_local = (lon - lon_q) * cos(lat_q) * R
            # y_local = (lat - lat_q) * R
            # where R is Earth radius in km
            R = 6371.0087714
            lat_q_rad = np.radians(lat_q)
            cos_lat_q = np.cos(lat_q_rad)

            # Compute local coordinates centered at query point
            dx = np.radians(lon_neighbors - lon_q) * cos_lat_q * R
            dy = np.radians(lat_neighbors - lat_q) * R

            # Fit plane z = a*dx + b*dy + c using weighted least squares
            # Build design matrix A = [dx, dy, 1]
            A = np.column_stack([dx, dy, np.ones(k)])

            # Create weight matrix (diagonal)
            W = np.diag(weights)

            # Solve using LDLT decomposition
            try:
                # (A^T W A) * coeffs = A^T W * z
                AtW = A.T @ W
                AtWA = AtW @ A
                AtWz = AtW @ z_neighbors

                # Use LDLT decomposition (works for symmetric matrices - e.g.
                # when points are colocated for k=3)
                L, D, perm = ldl(AtWA, lower=True)
                # Solve L * D * L^T * coeffs = AtWz
                # Forward: L * y = AtWz[perm]
                y = np.linalg.solve(L, AtWz[perm])
                # Scale: D * z = y
                z = y / np.diag(D)
                # Backward: L^T * coeffs_perm = z
                coeffs_perm = np.linalg.solve(L.T, z)
                # Reverse permutation
                coeffs = np.empty_like(coeffs_perm)
                coeffs[perm] = coeffs_perm

                a, b, c = coeffs
                print(f"Query {i}, k={k}: LDLT succeeded, coeffs = {coeffs}")

                # Check errors to ensure plane fit is reasonable
                # Compute predicted values at neighbor locations
                z_predicted = A @ coeffs
                errors = z_neighbors - z_predicted  # residual

                # Compute weighted RMS error
                weighted_rms_error = np.sqrt(np.sum(weights * errors**2))

                # Use weighted RMS of neighbor values as scale
                weighted_rms_neighbors = np.sqrt(np.sum(weights * z_neighbors**2))

                relative_error = weighted_rms_error / (weighted_rms_neighbors + 1e-10)
                print(
                    f"  relative_error = {relative_error:.6e}, threshold = {relative_error_threshold}"
                )

                # If relative error exceeds threshold, fall back to IDW
                if relative_error > relative_error_threshold:
                    print("  -> Falling back to IDW (relative error too high)")
                    result[i] = np.sum(weights * z_neighbors)
                    continue
                else:
                    print("  -> Using plane fit result")

                # Query point is at origin of local coordinate system
                # so intercept is result (dx_q = 0, dy_q = 0)
                result[i] = c

            except (np.linalg.LinAlgError, ValueError) as e:
                # LDLT failed, fall back to IDW
                print(f"Query {i}, k={k}: LDLT failed: {e}")
                print("  -> Falling back to IDW (LDLT error)")
                result[i] = np.sum(weights * z_neighbors)
        else:
            # For k < 3, use inverse distance weighting
            result[i] = np.sum(weights * z_neighbors)

    return result[0] if is_single else result


def local_plane_fit_test_files_gen(
    name,
    t,
    t_query,
    t_ref,
    first_nearest_ref_station_id,
    second_nearest_ref_station_id,
    third_nearest_ref_station_id,
    fourth_nearest_ref_station_id,
    first_nearest_distance,
    second_nearest_distance,
    third_nearest_distance,
    fourth_nearest_distance,
    ob_type_num,
    station_id,
    lat,
    lon,
    datetime,
    match_variable,
    match_timestamps_idxs,
    query_station_id,
    ref_station_id,
    interpolated_temp_4nn,
    interpolated_temp_3nn,
    interpolated_temp_4nn_idw2,
    interpolated_temp_2nn_idw2,
    interpolated_temp_1nn_idw1,
):
    """Write NetCDF4 test file for local plane fit algorithm."""
    assert (
        len(t)
        == len(t_query)
        == len(t_ref)
        == len(first_nearest_ref_station_id)
        == len(second_nearest_ref_station_id)
        == len(third_nearest_ref_station_id)
        == len(fourth_nearest_ref_station_id)
        == len(first_nearest_distance)
        == len(second_nearest_distance)
        == len(third_nearest_distance)
        == len(fourth_nearest_distance)
        == len(ob_type_num)
        == len(station_id)
        == len(lat)
        == len(lon)
        == len(datetime)
        == len(query_station_id)
        == len(ref_station_id)
        == len(interpolated_temp_4nn)
        == len(interpolated_temp_3nn)
        == len(interpolated_temp_4nn_idw2)
        == len(interpolated_temp_2nn_idw2)
        == len(interpolated_temp_1nn_idw1)
    )
    nlocs = len(station_id)
    err = np.ones(nlocs)  # arbitrary

    # Write observation file
    file = nc.Dataset(name, "w")
    file._ioda_layout = "ObsGroup"
    file._ioda_layout_version = 0
    file.date_time = "20200101T0000Z"
    file.createDimension("Location", nlocs)

    loc_var = file.createVariable(
        "Location", "f4", ("Location",), fill_value=MISSING_FLOAT
    )
    loc_var[:] = 0
    datetime_var = file.createVariable(
        "MetaData/dateTime", "i8", ("Location",), fill_value=MISSING_DATETIME
    )
    datetime_var.units = "seconds since 1970-01-01T00:00:00Z"
    datetime_var[:] = datetime

    # Metadata variables
    lat_var = file.createVariable(
        "MetaData/latitude", "f4", ("Location",), fill_value=MISSING_FLOAT
    )
    lat_var[:] = lat
    lon_var = file.createVariable(
        "MetaData/longitude", "f4", ("Location",), fill_value=MISSING_FLOAT
    )
    lon_var[:] = lon
    station_id_var = file.createVariable(
        "MetaData/stationIdentification", str, ("Location",), fill_value=MISSING_STR
    )
    station_id_var[:] = station_id
    query_station_id_var = file.createVariable(
        "MetaData/queryStationIdentification",
        str,
        ("Location",),
        fill_value=MISSING_STR,
    )
    query_station_id_var[:] = query_station_id
    ref_station_id_var = file.createVariable(
        "MetaData/referenceStationIdentification",
        str,
        ("Location",),
        fill_value=MISSING_STR,
    )
    ref_station_id_var[:] = ref_station_id
    ob_type_num_var = file.createVariable(
        "MetaData/observationTypeNum",
        "i4",
        ("Location",),
        fill_value=MISSING_INT32,
    )
    ob_type_num_var[:] = ob_type_num

    # ObsValue variables
    t_var = file.createVariable(
        "ObsValue/airTemperature",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    t_var[:] = t
    t_err_var = file.createVariable(
        "ObsError/airTemperature",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    t_err_var[:] = err
    t_query_var = file.createVariable(
        "ObsValue/airTemperatureObservation",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    t_query_var[:] = t_query
    t_ref_var = file.createVariable(
        "ObsValue/airTemperatureReference",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    t_ref_var[:] = t_ref

    # DerivedMetaData: Nearest neighbor station IDs
    first_nearest_ref_station_id_var = file.createVariable(
        "DerivedMetaData/firstNearestReferenceStationID",
        str,
        ("Location",),
        fill_value=MISSING_STR,
    )
    first_nearest_ref_station_id_var[:] = first_nearest_ref_station_id
    second_nearest_ref_station_id_var = file.createVariable(
        "DerivedMetaData/secondNearestReferenceStationID",
        str,
        ("Location",),
        fill_value=MISSING_STR,
    )
    second_nearest_ref_station_id_var[:] = second_nearest_ref_station_id
    third_nearest_ref_station_id_var = file.createVariable(
        "DerivedMetaData/thirdNearestReferenceStationID",
        str,
        ("Location",),
        fill_value=MISSING_STR,
    )
    third_nearest_ref_station_id_var[:] = third_nearest_ref_station_id
    fourth_nearest_ref_station_id_var = file.createVariable(
        "DerivedMetaData/fourthNearestReferenceStationID",
        str,
        ("Location",),
        fill_value=MISSING_STR,
    )
    fourth_nearest_ref_station_id_var[:] = fourth_nearest_ref_station_id

    # MetaData: Distance variables
    first_nearest_distance_var = file.createVariable(
        "MetaData/firstNearestSynopTempDistance",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    first_nearest_distance_var[:] = first_nearest_distance
    second_nearest_distance_var = file.createVariable(
        "MetaData/secondNearestSynopTempDistance",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    second_nearest_distance_var[:] = second_nearest_distance
    third_nearest_distance_var = file.createVariable(
        "MetaData/thirdNearestSynopTempDistance",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    third_nearest_distance_var[:] = third_nearest_distance
    fourth_nearest_distance_var = file.createVariable(
        "MetaData/fourthNearestSynopTempDistance",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    fourth_nearest_distance_var[:] = fourth_nearest_distance

    # TestReference: Expected outputs
    interpolated_temp_4nn_var = file.createVariable(
        "TestReference/interpolatedTemperatureFromFourNearestNeighbors",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    interpolated_temp_4nn_var[:] = interpolated_temp_4nn
    interpolated_temp_3nn_var = file.createVariable(
        "TestReference/interpolatedTemperatureFromThreeNearestNeighbors",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    interpolated_temp_3nn_var[:] = interpolated_temp_3nn
    interpolated_temp_4nn_idw2_var = file.createVariable(
        "TestReference/interpolatedTemperatureFromFourNearestNeighborsIDW2",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    interpolated_temp_4nn_idw2_var[:] = interpolated_temp_4nn_idw2
    interpolated_temp_2nn_idw2_var = file.createVariable(
        "TestReference/interpolatedTemperatureFromTwoNearestNeighborsIDW2",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    interpolated_temp_2nn_idw2_var[:] = interpolated_temp_2nn_idw2
    interpolated_temp_1nn_idw1_var = file.createVariable(
        "TestReference/interpolatedTemperatureFromOneNearestNeighborIDW1",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    interpolated_temp_1nn_idw1_var[:] = interpolated_temp_1nn_idw1
    missing_floats_var = file.createVariable(
        "TestReference/missingFloats",
        "f4",
        ("Location",),
        fill_value=MISSING_FLOAT,
    )
    missing_floats_var[:] = np.asarray([MISSING_FLOAT] * nlocs)
    match_datetime_var = file.createVariable(
        "MetaData/matchDateTime",
        "i8",
        ("Location",),
        fill_value=MISSING_DATETIME,
    )
    match_datetime_var.units = "seconds since 1970-01-01T00:00:00Z"
    match_datetime_var[:] = match_variable
    match_timestamps_idxs_var = file.createVariable(
        "MetaData/matchTimestampsIdxs",
        "i4",
        ("Location",),
        fill_value=MISSING_INT32,
    )
    match_timestamps_idxs_var[:] = match_timestamps_idxs

    file.close()


if __name__ == "__main__":
    # Setup: 5 query stations (type 1) and 6 reference stations (type 2)
    # Query station 5 is colocated with reference station 8 (to test coincident points)
    # One reference station has a missing temperature value

    # Station locations (lat, lon)
    # Query stations (first 8)
    station_1_loc = (40.0, -1.0)
    station_2_loc = (41.0, 0.0)
    station_3_loc = (42.0, 1.0)
    station_4_loc = (43.0, 2.0)
    station_5_loc = (40.15, -0.85)  # Will be colocated with station 11
    station_6_loc = (40.3, -1.3)  # New query off diagonal
    station_7_loc = (50.0, 10.0)  # Far away query (missing ID test)
    station_8_loc = (30.0, -11.0)  # Far away query (missing dist test)

    # Reference stations (9-17)
    station_9_loc = (40.1, -0.9)  # reference
    station_10_loc = (41.1, 0.1)  # reference
    station_11_loc = (40.15, -0.85)  # reference, colocated with station 5
    station_12_loc = (42.1, 1.1)  # reference (will have missing temp)
    station_13_loc = (40.5, -0.5)  # reference (on diagonal)
    station_14_loc = (41.5, 0.5)  # reference
    station_15_loc = (40.6, -1.4)  # reference (off diagonal)
    station_16_loc = (50.1, 10.1)  # reference near station 7
    station_17_loc = (30.1, -11.1)  # reference near station 8

    # Station IDs
    station_ids = [f"station{i}" for i in range(1, 18)]

    # Temperature values at reference stations
    # Explicitly chosen to create realistic but non-planar temperature field
    t_station_9 = [282.0, 282.1]  # at (40.1, -0.9)
    t_station_10 = [296.5, 296.6]  # at (41.1, 0.1)
    t_station_11 = [283.2, 283.3]  # at (40.15, -0.85)
    t_station_12 = [MISSING_FLOAT, MISSING_FLOAT]  # Missing value for testing
    t_station_13 = [288.0, 288.1]  # at (40.5, -0.5)
    t_station_14 = [292.0, 292.1]  # at (41.5, 0.5)
    t_station_15 = [285.5, 285.6]  # at (40.6, -1.4)
    t_station_16 = [275.0, 275.1]  # at (50.1, 10.1)
    t_station_17 = [276.0, 276.1]  # at (50.2, 9.9)

    # Create observation arrays (2 obs per station = 34 total)
    n_obs_query = 16
    n_obs_ref = 18
    num_obs = 34
    ob_type_num = np.array([1] * n_obs_query + [2] * n_obs_ref, dtype=np.int64)
    assert len(ob_type_num) == num_obs
    station_id = np.array([station_ids[i // 2] for i in range(num_obs)], dtype=object)
    lat = np.array(
        [
            station_1_loc[0],
            station_1_loc[0],
            station_2_loc[0],
            station_2_loc[0],
            station_3_loc[0],
            station_3_loc[0],
            station_4_loc[0],
            station_4_loc[0],
            station_5_loc[0],
            station_5_loc[0],
            station_6_loc[0],
            station_6_loc[0],
            station_7_loc[0],
            station_7_loc[0],
            station_8_loc[0],
            station_8_loc[0],
            station_9_loc[0],
            station_9_loc[0],
            station_10_loc[0],
            station_10_loc[0],
            station_11_loc[0],
            station_11_loc[0],
            station_12_loc[0],
            station_12_loc[0],
            station_13_loc[0],
            station_13_loc[0],
            station_14_loc[0],
            station_14_loc[0],
            station_15_loc[0],
            station_15_loc[0],
            station_16_loc[0],
            station_16_loc[0],
            station_17_loc[0],
            station_17_loc[0],
        ]
    )
    lon = np.array(
        [
            station_1_loc[1],
            station_1_loc[1],
            station_2_loc[1],
            station_2_loc[1],
            station_3_loc[1],
            station_3_loc[1],
            station_4_loc[1],
            station_4_loc[1],
            station_5_loc[1],
            station_5_loc[1],
            station_6_loc[1],
            station_6_loc[1],
            station_7_loc[1],
            station_7_loc[1],
            station_8_loc[1],
            station_8_loc[1],
            station_9_loc[1],
            station_9_loc[1],
            station_10_loc[1],
            station_10_loc[1],
            station_11_loc[1],
            station_11_loc[1],
            station_12_loc[1],
            station_12_loc[1],
            station_13_loc[1],
            station_13_loc[1],
            station_14_loc[1],
            station_14_loc[1],
            station_15_loc[1],
            station_15_loc[1],
            station_16_loc[1],
            station_16_loc[1],
            station_17_loc[1],
            station_17_loc[1],
        ]
    )

    # Datetime (seconds since 1970-01-01T00:00:00Z) - arbitrary
    datetime = np.array([1523667600] * num_obs)

    # Temperature values
    t = np.array(
        [
            280.0,
            290.0,
            300.0,
            310.0,
            320.0,
            330.0,
            340.0,
            350.0,
            280.0,
            290.0,
            300.0,
            310.0,
            320.0,
            330.0,
            340.0,
            350.0,  # query station obs (arbitrary)
            t_station_9[0],
            t_station_9[1],
            t_station_10[0],
            t_station_10[1],
            t_station_11[0],
            t_station_11[1],
            t_station_12[0],
            t_station_12[1],
            t_station_13[0],
            t_station_13[1],
            t_station_14[0],
            t_station_14[1],
            t_station_15[0],
            t_station_15[1],
            t_station_16[0],
            t_station_16[1],
            t_station_17[0],
            t_station_17[1],  # reference stations
        ]
    )
    assert len(t) == num_obs

    # Need a variable to indicate where query and reference obs are treated as
    # "matched" (usually having the same timestamp)
    start_time = int(dt(2018, 4, 14, 1, 0, 0, tzinfo=tz.utc).timestamp())
    match_timestamps = [start_time, start_time + 3600, start_time + 7200]
    match_timestamps_idxs = np.array(
        [
            0,
            1,
            0,
            1,
            0,
            1,
            0,
            1,
            0,
            1,
            0,
            1,
            0,
            1,
            0,
            1,  # query station obs (arbitrary)
            0,
            2,  # This station's second observation is not a match!
            0,
            1,
            0,
            1,
            0,
            1,
            0,
            1,
            0,
            1,
            0,
            1,
            0,
            1,
            0,
            1,  # reference stations
        ]
    )
    match_variable = np.array([match_timestamps[i] for i in match_timestamps_idxs])

    # Split into query and reference
    t_query = np.concatenate((t[0:n_obs_query], np.array([MISSING_FLOAT] * n_obs_ref)))
    t_ref = np.concatenate((np.array([MISSING_FLOAT] * n_obs_query), t[n_obs_query:]))
    query_station_id = np.concatenate(
        (station_id[0:n_obs_query], np.array([MISSING_STR] * n_obs_ref, dtype=object))
    )
    ref_station_id = np.concatenate(
        (np.array([MISSING_STR] * n_obs_query, dtype=object), station_id[n_obs_query:])
    )
    assert len(t_query) == num_obs
    assert len(t_ref) == num_obs
    assert len(query_station_id) == num_obs
    assert len(ref_station_id) == num_obs

    # Shortened arrays for easier indexing:
    query_lats = lat[0:n_obs_query]
    query_lons = lon[0:n_obs_query]
    query_temps = t_query[0:n_obs_query]
    query_match = match_variable[0:n_obs_query]  # Timestamp group labels for query obs
    ref_lats = lat[n_obs_query:]
    ref_lons = lon[n_obs_query:]
    ref_temps = t_ref[n_obs_query:]
    ref_ids = station_id[n_obs_query:]
    ref_match = match_variable[n_obs_query:]  # Timestamp group labels for reference obs

    # Compute nearest neighbors for each query station
    unique_ordered_ref_ids = ref_ids[::2]  # Unique reference station IDs (every 2 obs)

    # Arrays to store nearest neighbor info
    first_nearest_ref_station_id = np.array([MISSING_STR] * num_obs, dtype=object)
    second_nearest_ref_station_id = np.array([MISSING_STR] * num_obs, dtype=object)
    third_nearest_ref_station_id = np.array([MISSING_STR] * num_obs, dtype=object)
    fourth_nearest_ref_station_id = np.array([MISSING_STR] * num_obs, dtype=object)
    first_nearest_distance = np.array([MISSING_FLOAT] * num_obs)
    second_nearest_distance = np.array([MISSING_FLOAT] * num_obs)
    third_nearest_distance = np.array([MISSING_FLOAT] * num_obs)
    fourth_nearest_distance = np.array([MISSING_FLOAT] * num_obs)

    # Compute nearest neighbors for query stations (first in list) - jump by 2
    # since each station has 2 obs with same location
    for i in range(0, n_obs_query, 2):
        distances = np.array(
            [
                haversine(query_lats[i], query_lons[i], ref_lats[j], ref_lons[j])
                for j in range(0, len(ref_lats), 2)
            ]
        )
        sorted_indices = np.argsort(distances)

        first_nearest_ref_station_id[i] = unique_ordered_ref_ids[sorted_indices[0]]
        first_nearest_ref_station_id[i + 1] = unique_ordered_ref_ids[sorted_indices[0]]
        second_nearest_ref_station_id[i] = unique_ordered_ref_ids[sorted_indices[1]]
        second_nearest_ref_station_id[i + 1] = unique_ordered_ref_ids[sorted_indices[1]]
        third_nearest_ref_station_id[i] = unique_ordered_ref_ids[sorted_indices[2]]
        third_nearest_ref_station_id[i + 1] = unique_ordered_ref_ids[sorted_indices[2]]
        fourth_nearest_ref_station_id[i] = unique_ordered_ref_ids[sorted_indices[3]]
        fourth_nearest_ref_station_id[i + 1] = unique_ordered_ref_ids[sorted_indices[3]]

        first_nearest_distance[i] = distances[sorted_indices[0]]
        first_nearest_distance[i + 1] = distances[sorted_indices[0]]
        second_nearest_distance[i] = distances[sorted_indices[1]]
        second_nearest_distance[i + 1] = distances[sorted_indices[1]]
        third_nearest_distance[i] = distances[sorted_indices[2]]
        third_nearest_distance[i + 1] = distances[sorted_indices[2]]
        fourth_nearest_distance[i] = distances[sorted_indices[3]]
        fourth_nearest_distance[i + 1] = distances[sorted_indices[3]]

    # Manually set missing values to test edge cases
    # Station 7 (indices 12 and 13): first nearest neighbor has missing station ID
    first_nearest_ref_station_id[12] = MISSING_STR
    first_nearest_ref_station_id[13] = MISSING_STR
    # Station 8 (indices 14 and 15): first nearest neighbor has missing distance
    first_nearest_distance[14] = MISSING_FLOAT
    first_nearest_distance[15] = MISSING_FLOAT

    # Compute reference values for the 5 test scenarios
    # Only compute for valid query stations, rest are MISSING_FLOAT
    interpolated_temp_4nn = np.array([MISSING_FLOAT] * num_obs)
    interpolated_temp_3nn = np.array([MISSING_FLOAT] * num_obs)
    interpolated_temp_4nn_idw2 = np.array([MISSING_FLOAT] * num_obs)
    interpolated_temp_2nn_idw2 = np.array([MISSING_FLOAT] * num_obs)
    interpolated_temp_1nn_idw1 = np.array([MISSING_FLOAT] * num_obs)

    # Compute interpolated values for each query station using the griddata_knn
    # function but ignore the results for stations 7 and 8 since they have
    # missing neighbor metadata.
    for i in range(n_obs_query - 4):
        query_lon = lon[i]
        query_lat = lat[i]

        # Each physical station appears twice in the reference arrays (once per
        # match group). Deduplicate by keeping only the first occurrence of each
        # unique (lat, lon) pair, using the obs whose match group label matches
        # the query obs where available, and MISSING_FLOAT otherwise.
        # This ensures griddata_knn sees one value per physical station.
        seen_locs = {}
        for j in range(len(ref_lats)):
            loc = (ref_lats[j], ref_lons[j])
            temp = ref_temps[j] if ref_match[j] == match_variable[i] else MISSING_FLOAT
            if loc not in seen_locs:
                seen_locs[loc] = temp
            elif seen_locs[loc] == MISSING_FLOAT:
                # Prefer a matched value over a previously stored missing one
                seen_locs[loc] = temp

        dedup_locs = np.array(list(seen_locs.keys()))
        dedup_lats = dedup_locs[:, 0]
        dedup_lons = dedup_locs[:, 1]
        dedup_temps = np.array(list(seen_locs.values()))

        # Scenario 1: 4-neighbor plane fit (threshold=0.25, power=2.0)
        interpolated_temp_4nn[i] = griddata_knn(
            (dedup_lons, dedup_lats),
            dedup_temps,
            (query_lon, query_lat),
            k=4,
            fill_value=MISSING_FLOAT,
            power=2.0,
            relative_error_threshold=0.25,
        )

        # Scenario 2: 3-neighbor plane fit
        interpolated_temp_3nn[i] = griddata_knn(
            (dedup_lons, dedup_lats),
            dedup_temps,
            (query_lon, query_lat),
            k=3,
            fill_value=MISSING_FLOAT,
            power=2.0,
            relative_error_threshold=0.25,
        )

        # Scenario 3: 4-neighbor forced IDW
        interpolated_temp_4nn_idw2[i] = griddata_knn(
            (dedup_lons, dedup_lats),
            dedup_temps,
            (query_lon, query_lat),
            k=4,
            fill_value=MISSING_FLOAT,
            power=2.0,
            relative_error_threshold=0.0,
        )

        # Scenario 4: 2-neighbor IDW
        interpolated_temp_2nn_idw2[i] = griddata_knn(
            (dedup_lons, dedup_lats),
            dedup_temps,
            (query_lon, query_lat),
            k=2,
            fill_value=MISSING_FLOAT,
            power=2.0,
            relative_error_threshold=0.25,
        )

        # Scenario 5: 1-neighbor IDW
        interpolated_temp_1nn_idw1[i] = griddata_knn(
            (dedup_lons, dedup_lats),
            dedup_temps,
            (query_lon, query_lat),
            k=1,
            fill_value=MISSING_FLOAT,
            power=1.0,
            relative_error_threshold=0.25,
        )

    # Permute observations for ordering independence
    permutation = np.random.permutation(num_obs)
    # permutation = np.arange(num_obs)  # No permutation for easier debugging

    # Write test file
    local_plane_fit_test_files_gen(
        "use_nearest_neighbors_obs_interpolate_with_local_plane_fit.nc4",
        t[permutation],
        t_query[permutation],
        t_ref[permutation],
        first_nearest_ref_station_id[permutation],
        second_nearest_ref_station_id[permutation],
        third_nearest_ref_station_id[permutation],
        fourth_nearest_ref_station_id[permutation],
        first_nearest_distance[permutation],
        second_nearest_distance[permutation],
        third_nearest_distance[permutation],
        fourth_nearest_distance[permutation],
        ob_type_num[permutation],
        station_id[permutation],
        lat[permutation],
        lon[permutation],
        datetime[permutation],
        match_variable[permutation],
        match_timestamps_idxs[permutation],
        query_station_id[permutation],
        ref_station_id[permutation],
        interpolated_temp_4nn[permutation],
        interpolated_temp_3nn[permutation],
        interpolated_temp_4nn_idw2[permutation],
        interpolated_temp_2nn_idw2[permutation],
        interpolated_temp_1nn_idw1[permutation],
    )

    print(
        "Test file generated: use_nearest_neighbors_obs_interpolate_with_local_plane_fit.nc4"
    )

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant