Skip to content

Commit cc7de38

Browse files
authored
Feature/398 feature facet plots in results (#419)
* Add animation and faceting options to plots * Adjust size of the frame * Utilize plotly express directly * Rmeocve old class * Use plotly express and modify stackgroup afterwards * Add modifications also to animations * Mkae more compact * Remove height stuff * Remove line and make set opacity =0 for area * Integrate faceting and animating into existing with_plotly method * Improve results.py * Improve results.py * Move check if dims are found to plotting.py * Fix usage of indexer * Change selection string with indexer * Change behaviout of parameter "indexing" * Update CHANGELOG.md * Add new selection parameter to plotting methods * deprectae old indexer parameter * deprectae old indexer parameter * Add test * Add test * Add test * Add test * Fix not supportet check for matplotlib * Typo in CHANGELOG.md
1 parent 8323b27 commit cc7de38

File tree

7 files changed

+758
-164
lines changed

7 files changed

+758
-164
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,12 @@ If upgrading from v2.x, see the [Migration Guide](https://flixopt.github.io/flix
5454
5555
5656
### ✨ Added
57+
- Added faceting and animation options to plotting methods
5758
5859
### 💥 Breaking Changes
5960
6061
### ♻️ Changed
62+
- Changed indexer behaviour. Defaults to not indexing instead of the first value except for time. Also changed naming when indexing.
6163
6264
### 🗑️ Deprecated
6365

flixopt/plotting.py

Lines changed: 217 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import plotly.express as px
4040
import plotly.graph_objects as go
4141
import plotly.offline
42+
import xarray as xr
4243
from plotly.exceptions import PlotlyError
4344

4445
if TYPE_CHECKING:
@@ -326,143 +327,253 @@ def process_colors(
326327

327328

328329
def with_plotly(
329-
data: pd.DataFrame,
330+
data: pd.DataFrame | xr.DataArray | xr.Dataset,
330331
mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar',
331332
colors: ColorType = 'viridis',
332333
title: str = '',
333334
ylabel: str = '',
334335
xlabel: str = 'Time in h',
335336
fig: go.Figure | None = None,
337+
facet_by: str | list[str] | None = None,
338+
animate_by: str | None = None,
339+
facet_cols: int = 3,
340+
shared_yaxes: bool = True,
341+
shared_xaxes: bool = True,
336342
) -> go.Figure:
337343
"""
338-
Plot a DataFrame with Plotly, using either stacked bars or stepped lines.
344+
Plot data with Plotly using facets (subplots) and/or animation for multidimensional data.
345+
346+
Uses Plotly Express for convenient faceting and animation with automatic styling.
347+
For simple plots without faceting, can optionally add to an existing figure.
339348
340349
Args:
341-
data: A DataFrame containing the data to plot, where the index represents time (e.g., hours),
342-
and each column represents a separate data series.
343-
mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines,
344-
or 'area' for stacked area charts.
345-
colors: Color specification, can be:
346-
- A string with a colorscale name (e.g., 'viridis', 'plasma')
347-
- A list of color strings (e.g., ['#ff0000', '#00ff00'])
348-
- A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'})
349-
title: The title of the plot.
350+
data: A DataFrame or xarray DataArray/Dataset to plot.
351+
mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for lines,
352+
'area' for stacked area charts, or 'grouped_bar' for grouped bar charts.
353+
colors: Color specification (colormap, list, or dict mapping labels to colors).
354+
title: The main title of the plot.
350355
ylabel: The label for the y-axis.
351356
xlabel: The label for the x-axis.
352-
fig: A Plotly figure object to plot on. If not provided, a new figure will be created.
357+
fig: A Plotly figure object to plot on (only for simple plots without faceting).
358+
If not provided, a new figure will be created.
359+
facet_by: Dimension(s) to create facets for. Creates a subplot grid.
360+
Can be a single dimension name or list of dimensions (max 2 for facet_row and facet_col).
361+
If the dimension doesn't exist in the data, it will be silently ignored.
362+
animate_by: Dimension to animate over. Creates animation frames.
363+
If the dimension doesn't exist in the data, it will be silently ignored.
364+
facet_cols: Number of columns in the facet grid (used when facet_by is single dimension).
365+
shared_yaxes: Whether subplots share y-axes.
366+
shared_xaxes: Whether subplots share x-axes.
353367
354368
Returns:
355-
A Plotly figure object containing the generated plot.
369+
A Plotly figure object containing the faceted/animated plot.
370+
371+
Examples:
372+
Simple plot:
373+
374+
```python
375+
fig = with_plotly(df, mode='area', title='Energy Mix')
376+
```
377+
378+
Facet by scenario:
379+
380+
```python
381+
fig = with_plotly(ds, facet_by='scenario', facet_cols=2)
382+
```
383+
384+
Animate by period:
385+
386+
```python
387+
fig = with_plotly(ds, animate_by='period')
388+
```
389+
390+
Facet and animate:
391+
392+
```python
393+
fig = with_plotly(ds, facet_by='scenario', animate_by='period')
394+
```
356395
"""
357396
if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'):
358397
raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}")
359-
if data.empty:
360-
return go.Figure()
361398

362-
processed_colors = ColorProcessor(engine='plotly').process_colors(colors, list(data.columns))
363-
364-
fig = fig if fig is not None else go.Figure()
399+
# Handle empty data
400+
if isinstance(data, pd.DataFrame) and data.empty:
401+
return go.Figure()
402+
elif isinstance(data, xr.DataArray) and data.size == 0:
403+
return go.Figure()
404+
elif isinstance(data, xr.Dataset) and len(data.data_vars) == 0:
405+
return go.Figure()
365406

366-
if mode == 'stacked_bar':
367-
for i, column in enumerate(data.columns):
368-
fig.add_trace(
369-
go.Bar(
370-
x=data.index,
371-
y=data[column],
372-
name=column,
373-
marker=dict(
374-
color=processed_colors[i], line=dict(width=0, color='rgba(0,0,0,0)')
375-
), # Transparent line with 0 width
407+
# Warn if fig parameter is used with faceting
408+
if fig is not None and (facet_by is not None or animate_by is not None):
409+
logger.warning('The fig parameter is ignored when using faceting or animation. Creating a new figure.')
410+
fig = None
411+
412+
# Convert xarray to long-form DataFrame for Plotly Express
413+
if isinstance(data, (xr.DataArray, xr.Dataset)):
414+
# Convert to long-form (tidy) DataFrame
415+
# Structure: time, variable, value, scenario, period, ... (all dims as columns)
416+
if isinstance(data, xr.Dataset):
417+
# Stack all data variables into long format
418+
df_long = data.to_dataframe().reset_index()
419+
# Melt to get: time, scenario, period, ..., variable, value
420+
id_vars = [dim for dim in data.dims]
421+
value_vars = list(data.data_vars)
422+
df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value')
423+
else:
424+
# DataArray
425+
df_long = data.to_dataframe().reset_index()
426+
if data.name:
427+
df_long = df_long.rename(columns={data.name: 'value'})
428+
else:
429+
# Unnamed DataArray, find the value column
430+
value_col = [col for col in df_long.columns if col not in data.dims][0]
431+
df_long = df_long.rename(columns={value_col: 'value'})
432+
df_long['variable'] = data.name or 'data'
433+
else:
434+
# Already a DataFrame - convert to long format for Plotly Express
435+
df_long = data.reset_index()
436+
if 'time' not in df_long.columns:
437+
# First column is probably time
438+
df_long = df_long.rename(columns={df_long.columns[0]: 'time'})
439+
# Melt to long format
440+
id_vars = [
441+
col
442+
for col in df_long.columns
443+
if col in ['time', 'scenario', 'period']
444+
or col in (facet_by if isinstance(facet_by, list) else [facet_by] if facet_by else [])
445+
]
446+
value_vars = [col for col in df_long.columns if col not in id_vars]
447+
df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value')
448+
449+
# Validate facet_by and animate_by dimensions exist in the data
450+
available_dims = [col for col in df_long.columns if col not in ['variable', 'value']]
451+
452+
# Check facet_by dimensions
453+
if facet_by is not None:
454+
if isinstance(facet_by, str):
455+
if facet_by not in available_dims:
456+
logger.debug(
457+
f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. "
458+
f'Ignoring facet_by parameter.'
376459
)
377-
)
378-
379-
fig.update_layout(
380-
barmode='relative',
381-
bargap=0, # No space between bars
382-
bargroupgap=0, # No space between grouped bars
460+
facet_by = None
461+
elif isinstance(facet_by, list):
462+
# Filter out dimensions that don't exist
463+
missing_dims = [dim for dim in facet_by if dim not in available_dims]
464+
facet_by = [dim for dim in facet_by if dim in available_dims]
465+
if missing_dims:
466+
logger.debug(
467+
f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. '
468+
f'Using only existing dimensions: {facet_by if facet_by else "none"}.'
469+
)
470+
if len(facet_by) == 0:
471+
facet_by = None
472+
473+
# Check animate_by dimension
474+
if animate_by is not None and animate_by not in available_dims:
475+
logger.debug(
476+
f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. "
477+
f'Ignoring animate_by parameter.'
383478
)
384-
if mode == 'grouped_bar':
385-
for i, column in enumerate(data.columns):
386-
fig.add_trace(go.Bar(x=data.index, y=data[column], name=column, marker=dict(color=processed_colors[i])))
479+
animate_by = None
480+
481+
# Setup faceting parameters for Plotly Express
482+
facet_row = None
483+
facet_col = None
484+
if facet_by:
485+
if isinstance(facet_by, str):
486+
# Single facet dimension - use facet_col with facet_col_wrap
487+
facet_col = facet_by
488+
elif len(facet_by) == 1:
489+
facet_col = facet_by[0]
490+
elif len(facet_by) == 2:
491+
# Two facet dimensions - use facet_row and facet_col
492+
facet_row = facet_by[0]
493+
facet_col = facet_by[1]
494+
else:
495+
raise ValueError(f'facet_by can have at most 2 dimensions, got {len(facet_by)}')
496+
497+
# Process colors
498+
all_vars = df_long['variable'].unique().tolist()
499+
processed_colors = ColorProcessor(engine='plotly').process_colors(colors, all_vars)
500+
color_discrete_map = {var: color for var, color in zip(all_vars, processed_colors, strict=False)}
501+
502+
# Create plot using Plotly Express based on mode
503+
common_args = {
504+
'data_frame': df_long,
505+
'x': 'time',
506+
'y': 'value',
507+
'color': 'variable',
508+
'facet_row': facet_row,
509+
'facet_col': facet_col,
510+
'animation_frame': animate_by,
511+
'color_discrete_map': color_discrete_map,
512+
'title': title,
513+
'labels': {'value': ylabel, 'time': xlabel, 'variable': ''},
514+
}
387515

388-
fig.update_layout(
389-
barmode='group',
390-
bargap=0.2, # No space between bars
391-
bargroupgap=0, # space between grouped bars
392-
)
516+
# Add facet_col_wrap for single facet dimension
517+
if facet_col and not facet_row:
518+
common_args['facet_col_wrap'] = facet_cols
519+
520+
if mode == 'stacked_bar':
521+
fig = px.bar(**common_args)
522+
fig.update_traces(marker_line_width=0)
523+
fig.update_layout(barmode='relative', bargap=0, bargroupgap=0)
524+
elif mode == 'grouped_bar':
525+
fig = px.bar(**common_args)
526+
fig.update_layout(barmode='group', bargap=0.2, bargroupgap=0)
393527
elif mode == 'line':
394-
for i, column in enumerate(data.columns):
395-
fig.add_trace(
396-
go.Scatter(
397-
x=data.index,
398-
y=data[column],
399-
mode='lines',
400-
name=column,
401-
line=dict(shape='hv', color=processed_colors[i]),
402-
)
403-
)
528+
fig = px.line(**common_args, line_shape='hv') # Stepped lines
404529
elif mode == 'area':
405-
data = data.copy()
406-
data[(data > -1e-5) & (data < 1e-5)] = 0 # Preventing issues with plotting
407-
# Split columns into positive, negative, and mixed categories
408-
positive_columns = list(data.columns[(data >= 0).where(~np.isnan(data), True).all()])
409-
negative_columns = list(data.columns[(data <= 0).where(~np.isnan(data), True).all()])
410-
negative_columns = [column for column in negative_columns if column not in positive_columns]
411-
mixed_columns = list(set(data.columns) - set(positive_columns + negative_columns))
412-
413-
if mixed_columns:
414-
logger.error(
415-
f'Data for plotting stacked lines contains columns with both positive and negative values:'
416-
f' {mixed_columns}. These can not be stacked, and are printed as simple lines'
417-
)
530+
# Use Plotly Express to create the area plot (preserves animation, legends, faceting)
531+
fig = px.area(**common_args, line_shape='hv')
418532

419-
# Get color mapping for all columns
420-
colors_stacked = {column: processed_colors[i] for i, column in enumerate(data.columns)}
421-
422-
for column in positive_columns + negative_columns:
423-
fig.add_trace(
424-
go.Scatter(
425-
x=data.index,
426-
y=data[column],
427-
mode='lines',
428-
name=column,
429-
line=dict(shape='hv', color=colors_stacked[column]),
430-
fill='tonexty',
431-
stackgroup='pos' if column in positive_columns else 'neg',
432-
)
433-
)
533+
# Classify each variable based on its values
534+
variable_classification = {}
535+
for var in all_vars:
536+
var_data = df_long[df_long['variable'] == var]['value']
537+
var_data_clean = var_data[(var_data < -1e-5) | (var_data > 1e-5)]
434538

435-
for column in mixed_columns:
436-
fig.add_trace(
437-
go.Scatter(
438-
x=data.index,
439-
y=data[column],
440-
mode='lines',
441-
name=column,
442-
line=dict(shape='hv', color=colors_stacked[column], dash='dash'),
539+
if len(var_data_clean) == 0:
540+
variable_classification[var] = 'zero'
541+
else:
542+
has_pos, has_neg = (var_data_clean > 0).any(), (var_data_clean < 0).any()
543+
variable_classification[var] = (
544+
'mixed' if has_pos and has_neg else ('negative' if has_neg else 'positive')
443545
)
444-
)
445546

446-
# Update layout for better aesthetics
547+
# Log warning for mixed variables
548+
mixed_vars = [v for v, c in variable_classification.items() if c == 'mixed']
549+
if mixed_vars:
550+
logger.warning(f'Variables with both positive and negative values: {mixed_vars}. Plotted as dashed lines.')
551+
552+
all_traces = list(fig.data)
553+
for frame in fig.frames:
554+
all_traces.extend(frame.data)
555+
556+
for trace in all_traces:
557+
trace.stackgroup = variable_classification.get(trace.name, None)
558+
# No opacity and no line for stacked areas
559+
if trace.stackgroup is not None:
560+
if hasattr(trace, 'line') and trace.line.color:
561+
trace.fillcolor = trace.line.color # Will be solid by default
562+
trace.line.width = 0
563+
564+
# Update layout with basic styling (Plotly Express handles sizing automatically)
447565
fig.update_layout(
448-
title=title,
449-
yaxis=dict(
450-
title=ylabel,
451-
showgrid=True, # Enable grid lines on the y-axis
452-
gridcolor='lightgrey', # Customize grid line color
453-
gridwidth=0.5, # Customize grid line width
454-
),
455-
xaxis=dict(
456-
title=xlabel,
457-
showgrid=True, # Enable grid lines on the x-axis
458-
gridcolor='lightgrey', # Customize grid line color
459-
gridwidth=0.5, # Customize grid line width
460-
),
461-
plot_bgcolor='rgba(0,0,0,0)', # Transparent background
462-
paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background
463-
font=dict(size=14), # Increase font size for better readability
566+
plot_bgcolor='rgba(0,0,0,0)',
567+
paper_bgcolor='rgba(0,0,0,0)',
568+
font=dict(size=12),
464569
)
465570

571+
# Update axes to share if requested (Plotly Express already handles this, but we can customize)
572+
if not shared_yaxes:
573+
fig.update_yaxes(matches=None)
574+
if not shared_xaxes:
575+
fig.update_xaxes(matches=None)
576+
466577
return fig
467578

468579

0 commit comments

Comments
 (0)