diff --git a/glidertest/plots.py b/glidertest/plots.py index cf90ac3..3626f28 100644 --- a/glidertest/plots.py +++ b/glidertest/plots.py @@ -6,6 +6,8 @@ import xarray as xr from matplotlib.dates import DateFormatter from scipy import stats +from scipy.interpolate import interp1d +import cmocean.cm as cmo import matplotlib.colors as mcolors import gsw import cartopy.crs as ccrs @@ -1570,88 +1572,94 @@ def plot_max_depth_per_profile(ds: xr.Dataset, bins= 20, ax = None, **kw: dict) plt.show() return fig, ax - -def plot_profile(ds: xr.Dataset, profile_num: int, vars: list = ['TEMP','PSAL','DENSITY'], use_bins: bool = False, binning: float = 2) -> tuple: +def plot_profile(ds: xr.Dataset, profile_num: int = None, vars: list = ['TEMP','PSAL','DENSITY'], use_bins: bool = False, binning: float = 2) -> tuple: """ - Plots binned temperature, salinity, and density against depth on a single plot with three x-axes. + Plots the specified variables against depth for a given profile number from an OG1-format dataset. + If no profile number is provided, it plots the mean profile for the specified variables. Parameters ---------- ds: xarray.Dataset - Xarray dataset in OG1 format with at least PROFILE_NUMBER, DEPTH, TEMPERATURE, SALINITY, and DENSITY. - profile_num: int - The profile number to plot. + Xarray dataset in OG1 format with at least PROFILE_NUMBER, DEPTH and the specified variables. + profile_num: int or None + Profile number to plot. If None, plots mean profile. vars: list - The variables to plot. Default is ['TEMP','PSAL','DENSITY']. - binning: int - The depth resolution for binning. + Variables to plot. Default is ['TEMP','PSAL','DENSITY']. use_bins: bool - If True, use binned data instead of raw data. + Whether to bin the data vertically. + binning: float + Bin size in meters if use_bins is True. Returns ------- fig: matplotlib.figure.Figure The figure object containing the plot. - ax1: matplotlib.axes.Axes - The axis object containing the primary plot. 
+ ax1: matplotlib.axes.Axes + The primary axis object; additional variables are drawn on twinned x-axes created from it. Notes ----- Original Author: Till Moritz """ - # Remove empty strings from vars - vars = [v for v in vars if v] - # If vars is empty, show an empty plot + vars = [v for v in vars if v] if not vars: fig, ax1 = plt.subplots(figsize=(12, 9)) - ax1.set_title(f'Profile {profile_num} (No Variables Selected)') + ax1.set_title('No Variables Selected') ax1.set_ylabel('Depth (m)') ax1.invert_yaxis() ax1.grid(True) return fig, ax1 - + if len(vars) > 3: - raise ValueError("Only three variables can be plotted at once, chose less variables") - + raise ValueError("Only three variables can be plotted at once, chose fewer variables") + with plt.style.context(glidertest_style_file): fig, ax1 = plt.subplots(figsize=(12, 9)) + axs = [ax1, ax1.twiny(), ax1.twiny()] + colors = ['red', 'blue', 'grey'] + s = 10 + binning - profile = ds.where(ds.PROFILE_NUMBER == profile_num, drop=True) - if use_bins: - profile = utilities.bin_profile(profile, vars, binning) + if profile_num is not None: + profile = ds.where(ds.PROFILE_NUMBER == profile_num, drop=True) + if use_bins: + profile = utilities.bin_profile(profile, vars, binning) + title_str = f'Profile {profile_num}' + else: + # Mean profile logic + data = {} + for var in vars: + df = tools.mean_profile(ds, var=var, v_res=binning) + data[var] = df['mean'].values + depth = df['depth'].values + profile = xr.Dataset({var: (['depth'], data[var]) for var in vars}) + profile['DEPTH'] = (['depth'], depth) + title_str = 'Mean Profile' - # Plot binned data mission = ds.id.split('_')[1][0:8] glider = ds.id.split('_')[0] - s=10+binning - - axs = [ax1, ax1.twiny(), ax1.twiny()] - colors = ['red', 'blue', 'grey'] for i, var in enumerate(vars): ax = axs[i] - ### add the long_name to the label, if it exists - long_name = getattr(profile[var], 'long_name', '') - ax.plot(profile[var], profile['DEPTH'], color=colors[i], label=f'{var} - {long_name}', ls='-') - ax.scatter(profile[var], 
profile['DEPTH'], color=colors[i], marker='o',s=s) - unit = getattr(profile[var], 'units', '') - ax.set_xlabel(f'{var} [{unit}]', color=colors[i]) - ax.tick_params(axis='x', colors=colors[i], bottom=True, top=False, labelbottom=True, labeltop=False) + unit = utilities.plotting_units(ds, var) + label = utilities.plotting_labels(var) + ax.plot(profile[var], profile['DEPTH'], color=colors[i], label=label) + ax.scatter(profile[var], profile['DEPTH'], color=colors[i], marker='o', s=s) + ax.set_xlabel(f'{label} [{unit}]', color=colors[i]) + ax.tick_params(axis='x', colors=colors[i]) if i > 0: ax.xaxis.set_ticks_position('bottom') ax.spines['top'].set_visible(False) ax.spines['bottom'].set_position(('axes', -0.09*i)) ax.xaxis.set_label_coords(0.5, -0.05-0.105*i) - fig.legend(loc='upper right',fontsize=10) - # Set pressure as y-axis (Increasing Downward) ax1.grid(True) ax1.set_ylabel('Depth (m)') - ax1.invert_yaxis() # Pressure increases downward - ax1.set_title(f'Profile {profile_num} ({glider} on mission: {mission})') + ax1.invert_yaxis() + ax1.set_title(f'{title_str} ({glider} on mission: {mission})') return fig, ax1 + def plot_CR(ds: xr.Dataset, profile_num: int, use_bins: bool = False, binning: float = 2): """ Plots the convective resistance (CR) of a profile against depth based on calculate_CR_for_all_depth function. @@ -1705,3 +1713,140 @@ def plot_CR(ds: xr.Dataset, profile_num: int, use_bins: bool = False, binning: f plt.show() return fig, ax + +def plot_section(ds, vars=['PSAL', 'TEMP', 'DENSITY'], v_res=1, start=None, end=None, mld_df = None): + """ + Plots a section of the dataset with PROFILE_NUMBER on the x-axis, DEPTH on the y-axis, + and mean TIME per profile as secondary x-axis (automatically spaced). + + Parameters + ---------- + ds : xarray.Dataset + Dataset with at least PROFILE_NUMBER, TIME, DEPTH, and target variables. + vars : str or list of str + Variables to visualize. If a single variable is provided, it will be converted to a list. 
+ v_res : float + Vertical resolution (DEPTH binning). + start : int or None + Start PROFILE_NUMBER (inclusive). + end : int or None + End PROFILE_NUMBER (inclusive). + mld_df : pd.DataFrame + MLD as a pandas Dataframe, which is the result of the MLD calculation compute_mld(). The dataframe should contain the profile number, MLD and the mean time profile. + + Returns + ------- + fig : matplotlib.figure.Figure + ax : list of matplotlib.axes.Axes + + Notes + ----- + Original Author: Till Moritz + """ + if not isinstance(vars, list): + vars = [vars] + utilities._check_necessary_variables(ds, vars + ['PROFILE_NUMBER', 'TIME', 'DEPTH']) + + if start is not None or end is not None: + if start is None: + start = ds.PROFILE_NUMBER.min().values + if end is None: + end = ds.PROFILE_NUMBER.max().values + mask = (ds.PROFILE_NUMBER >= start) & (ds.PROFILE_NUMBER <= end) + ds = ds.sel(N_MEASUREMENTS=mask) + if mld_df is not None: + mld_df = mld_df[(mld_df['PROFILE_NUMBER'] >= start) & (mld_df['PROFILE_NUMBER'] <= end)] + + num_vars = len(vars) + fig, ax = plt.subplots(num_vars, 1, figsize=(20, 5 * num_vars), sharex=True, gridspec_kw={'height_ratios': [8] * num_vars}) + if num_vars == 1: + ax = [ax] + + x_plot = ds['PROFILE_NUMBER'].values + + has_density_plot = any(utilities.plotting_colormap(var) == cmo.dense for var in vars) + + # Compute mean time per profile + df_time = ds[['TIME', 'PROFILE_NUMBER']].to_dataframe().dropna() + mean_times = df_time.groupby('PROFILE_NUMBER')['TIME'].mean() + + for i, var in enumerate(vars): + if var not in ds: + raise ValueError(f'Variable "{var}" not found in dataset.') + + values = ds[var].values + depth = ds['DEPTH'].values + + p = 1 + z = v_res + + varG, profG, depthG = utilities.construct_2dgrid(x_plot, depth, values, p, z, x_bin_center=False) + + cmap = utilities.plotting_colormap(var) + if cmap == cmo.delta and np.any(values < 0) and np.any(values > 0): + norm = mcolors.TwoSlopeNorm( + vmin=np.nanpercentile(values, 0.5), + vcenter=0, + 
vmax=np.nanpercentile(values, 99.5) + ) + im = ax[i].pcolormesh(profG, depthG, varG, cmap=cmap, norm=norm) + else: + im = ax[i].pcolormesh(profG, depthG, varG, cmap=cmap, + vmin=np.nanpercentile(values, 0.5), + vmax=np.nanpercentile(values, 99.5)) + if mld_df is not None: + if (has_density_plot and cmap == cmo.dense) or (not has_density_plot and i == 0): + ax[i].plot(mld_df['PROFILE_NUMBER'], mld_df['MLD'], color='black', marker='o', linewidth=1, + label='Mixed Layer Depth', markersize=2) + ax[i].legend(loc='upper left', fontsize=8) + + unit = utilities.plotting_units(ds,var) + label = utilities.plotting_labels(var) + + total_profiles = x_plot[-1] - x_plot[0] + ax[i].invert_yaxis() + ax[i].set_ylabel('Depth (m)') + ax[i].grid(True) + ax[i].set_title(f'Section plot of {label}') + ax[i].set_xlim(np.min(x_plot)-total_profiles/50, np.max(x_plot)+total_profiles/50) + + cbar = plt.colorbar(im, ax=ax[i], pad=0.03) + cbar.set_label(f'{label} [{unit}]', labelpad=20, rotation=270) + + # Main x-axis: profile numbers + ax[-1].set_xlabel('Profile Number') + + # Get mean time per profile (datetime) and profile numbers + times = pd.to_datetime(mean_times) + profiles = mean_times.index.values + time_nums = mdates.date2num(times) # matplotlib float format for dates + + # Build interpolators + to_time = interp1d(profiles, time_nums, bounds_error=False, fill_value="extrapolate") + to_profile = interp1d(time_nums, profiles, bounds_error=False, fill_value="extrapolate") + + # Create a transform that maps profile numbers → time for the secondary x-axis + def forward(x): + return to_time(x) + + def inverse(x): + return to_profile(x) + + # Create the secondary time axis at the bottom, linked to the profile-number axis + time_ax = ax[-1].secondary_xaxis("bottom", functions=(forward, inverse)) + time_ax.set_xlabel("Mean Time per Profile") + time_ax.xaxis.set_major_locator(mdates.AutoDateLocator()) + time_delta = time_nums[-1] - time_nums[0] + if time_delta < 5: + 
time_ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M')) + else: + time_ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d")) + + time_ax.spines['bottom'].set_position(('outward', 40)) + time_ax.tick_params(rotation=35) + + plt.tight_layout() + plt.show() + + return fig, ax + diff --git a/glidertest/tools.py b/glidertest/tools.py index 4463fa2..d432efd 100644 --- a/glidertest/tools.py +++ b/glidertest/tools.py @@ -52,6 +52,52 @@ def quant_updown_bias(ds, var='PSAL', v_res=1): df = pd.DataFrame() return df +def mean_profile(ds, var='TEMP', v_res=1): + """ + This function computes the mean vertical profile for a specific variable. + + Parameters + ---------- + ds: xarray.Dataset + Dataset in **OG1 format**, containing at least **PROFILE_NUMBER, DEPTH**, and the selected variable. + Data should **not** be gridded. + var: str, optional, default='TEMP' + Selected variable to average. + v_res: float + Vertical resolution in meters for binning the profile. + + Returns + ------- + df: pandas.DataFrame + DataFrame containing the mean profile and corresponding depth bins. 
+ + Notes + ----- + Original Author: Till Moritz + """ + # Ensure required variables are in the dataset + utilities._check_necessary_variables(ds, ['PROFILE_NUMBER', 'DEPTH', var]) + + p = 1 # Horizontal resolution (not used here) + z = v_res # Vertical resolution + + if var in ds.variables: + # 2D gridding by profile and depth + varG, profG, depthG = utilities.construct_2dgrid( + ds.PROFILE_NUMBER, ds.DEPTH, ds[var], p, z, x_bin_center=False) + + # Compute mean across profiles for each depth level + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + mean_var = np.nanmean(varG, axis=0) + + df = pd.DataFrame(data={'mean': mean_var, 'depth': depthG[0, :]}) + else: + print(f'{var} is not in the dataset') + df = pd.DataFrame() + + return df + def compute_daynight_avg(ds, sel_var='CHLA', start_time=None, end_time=None, start_prof=None, end_prof=None): """ Computes day and night averages for a selected variable over a specified time period or range of dives. @@ -605,7 +651,7 @@ def add_sigma_1(ds: xr.Dataset, var_sigma_1: str = "SIGMA_1") -> xr.Dataset: return ds def compute_mld(ds: xr.Dataset, variable, method: str = 'threshold', threshold = 0.03, ref_depth = 10, - use_bins: bool = False, binning: float = 10): + use_bins: bool = True, binning: float = 10): """ Computes the mixed layer depth (MLD) for each profile in the dataset. Two methods are available: 1. **Threshold Method**: Computes MLD based on a density threshold (default is 0.03 kg/m³). 
diff --git a/glidertest/utilities.py b/glidertest/utilities.py index 292f112..e6cbc8b 100644 --- a/glidertest/utilities.py +++ b/glidertest/utilities.py @@ -7,6 +7,8 @@ import gsw from matplotlib.dates import DateFormatter import matplotlib.dates as mdates +import cmocean.cm as cmo + def _check_necessary_variables(ds: xr.Dataset, vars: list): @@ -413,45 +415,102 @@ def calc_DEPTH_Z(ds): } return ds -label_dict={ +label_dict = { "PSAL": { "label": "Practical salinity", - "units": "PSU"}, + "units": "PSU", + "colormap": cmo.haline + }, + "SA": { + "label": "Absolute salinity", + "units": "g kg⁻¹", + "colormap": cmo.haline # Added best-guess + }, "TEMP": { "label": "Temperature", - "units": "°C"}, - "DENSITY":{ + "units": "°C", + "colormap": cmo.thermal + }, + "THETA": { + "label": "Potential temperature", + "units": "°C", + "colormap": cmo.thermal + }, + "DENSITY": { "label": "In situ density", - "units": "kg m⁻³" + "units": "kg m⁻³", + "colormap": cmo.dense }, "SIGMA": { "label": "Sigma-t", - "units": "kg m⁻³" + "units": "kg m⁻³", + "colormap": cmo.dense + }, + "SIGMA_T": { + "label": "Sigma-t", + "units": "kg m⁻³", + "colormap": cmo.dense + }, + "POTDENS0": { + "label": "Potential density (σ₀)", + "units": "kg m⁻³", + "colormap": cmo.dense + }, + "SIGTHETA": { + "label": "Potential density (σθ)", + "units": "kg m⁻³", + "colormap": cmo.dense + }, + "PRES": { + "label": "Pressure", + "units": "dbar", + "colormap": cmo.deep + }, + "DEPTH_Z": { + "label": "Depth", + "units": "m", + "colormap": cmo.deep + }, + "DEPTH": { + "label": "Depth", + "units": "m", + "colormap": cmo.deep }, "DOXY": { "label": "Dissolved oxygen", - "units": "mmol m⁻³" - }, - "SA":{ - "label": "Absolute salinity", - "units": "g kg⁻¹" + "units": "mmol m⁻³", + "colormap": cmo.oxy }, - "CHLA":{ + "CHLA": { "label": "Chlorophyll", - "units": "mg m⁻³" + "units": "mg m⁻³", + "colormap": cmo.algae }, - "CNDC":{ + "CNDC": { "label": "Conductivity", - "units": "mS cm⁻¹" + "units": "mS cm⁻¹", + "colormap": 
cmo.haline }, - "DPAR":{ + "DPAR": { "label": "Irradiance PAR", - "units": "μE cm⁻² s⁻¹" + "units": "μE cm⁻² s⁻¹", + "colormap": cmo.solar # Best guess }, - "BBP700":{ + "BBP700": { "label": "Red backscatter, b${bp}$(700)", - "units": "m⁻¹" - } + "units": "m⁻¹", + "colormap": cmo.matter # Best guess + }, + "GLIDER_VERT_VELO_MODEL": { + "label": "Vertical glider velocity", + "units": "cm s⁻¹", + "colormap": cmo.delta + }, + "GLIDER_HORZ_VELO_MODEL": { + "label": "Horizontal glider velocity", + "units": "cm s⁻¹", + "colormap": cmo.delta + }, } def plotting_labels(var: str): @@ -481,6 +540,7 @@ def plotting_labels(var: str): else: label= f'{var}' return label + def plotting_units(ds: xr.Dataset,var: str): """ Retrieves the units associated with a variable from a dataset or a predefined dictionary. @@ -512,6 +572,35 @@ def plotting_units(ds: xr.Dataset,var: str): return f'{ds[var].units}' else: return "" + +def plotting_colormap(var: str): + """ + Retrieves the colormap associated with a variable from a predefined dictionary. + + This function checks if the given variable `var` exists as a key in the `label_dict` dictionary. + If found, it returns the associated colormap from `label_dict`. If not, it returns cmocean.cm.delta as a default colormap. + + Parameters + ---------- + var: str + The variable (key) whose colormap is to be retrieved. + + Returns + ------- + colormap: matplotlib colormap + The colormap corresponding to the variable `var`. If the variable is not found in `label_dict`, + the function returns cmocean.cm.delta. + + Notes + ----- + Original Author: Till Moritz + """ + if var in label_dict: + colormap = label_dict[var]["colormap"] + else: + colormap = cmo.delta + return colormap + def group_by_profiles(ds, variables=None): """ Group glider dataset by the dive profile number. 
diff --git a/requirements-dev.txt b/requirements-dev.txt index e15b6e8..f11631c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,6 +3,7 @@ pooch sgp4>=2.14 xarray numpy +cmocean matplotlib pandas seaborn diff --git a/requirements.txt b/requirements.txt index 01cdd87..f09987e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ netcdf4 pooch xarray numpy +cmocean matplotlib pandas seaborn diff --git a/tests/test_plots.py b/tests/test_plots.py index fe9e341..ef712b9 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -2,6 +2,7 @@ from glidertest import fetchers, tools, plots, utilities import matplotlib.pyplot as plt import numpy as np +import pandas as pd import matplotlib from ioos_qc import qartod @@ -137,4 +138,9 @@ def test_plot_CR(): ds = fetchers.load_sample_dataset() ds = tools.add_sigma_1(ds) prof_num = ds.PROFILE_NUMBER[0].values - plots.plot_CR(ds,profile_num=prof_num) \ No newline at end of file + plots.plot_CR(ds,profile_num=prof_num) + +#def test_plot_section(): +# ds = fetchers.load_sample_dataset() +# mld = tools.compute_mld(ds,variable = 'DENSITY', use_bins = True) +# plots.plot_section(ds,mld_df=mld) \ No newline at end of file diff --git a/tests/test_tools.py b/tests/test_tools.py index c02494f..4440ce0 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -13,6 +13,10 @@ def test_updown_bias(v_res=1): ncell = math.ceil(len(bins)/v_res) assert len(df) == ncell +def test_mean_profile(): + ds = fetchers.load_sample_dataset() + tools.mean_profile(ds, var='TEMP', v_res=1) + def test_daynight(): ds = fetchers.load_sample_dataset() if not "TIME" in ds.indexes.keys(): diff --git a/tests/test_utilities.py b/tests/test_utilities.py index ec28405..cc12887 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -2,6 +2,7 @@ from glidertest import fetchers, utilities import matplotlib matplotlib.use('agg') # use agg backend to prevent creating plot windows during tests +import cmocean.cm as cmo 
def test_utilitiesmix(): ds = fetchers.load_sample_dataset() @@ -30,11 +31,15 @@ def test_labels(): var = 'PITCH' label = utilities.plotting_labels(var) assert label == 'PITCH' + colormap = utilities.plotting_colormap(var) + assert colormap == cmo.delta var = 'TEMP' label = utilities.plotting_labels(var) assert label == 'Temperature' unit=utilities.plotting_units(ds, var) assert unit == "°C" + colormap = utilities.plotting_colormap(var) + assert colormap == cmo.thermal def test_bin_profile(): ds = fetchers.load_sample_dataset()