|
39 | 39 | import plotly.express as px |
40 | 40 | import plotly.graph_objects as go |
41 | 41 | import plotly.offline |
| 42 | +import xarray as xr |
42 | 43 | from plotly.exceptions import PlotlyError |
43 | 44 |
|
44 | 45 | if TYPE_CHECKING: |
@@ -326,143 +327,253 @@ def process_colors( |
326 | 327 |
|
327 | 328 |
|
328 | 329 | def with_plotly( |
329 | | - data: pd.DataFrame, |
| 330 | + data: pd.DataFrame | xr.DataArray | xr.Dataset, |
330 | 331 | mode: Literal['stacked_bar', 'line', 'area', 'grouped_bar'] = 'stacked_bar', |
331 | 332 | colors: ColorType = 'viridis', |
332 | 333 | title: str = '', |
333 | 334 | ylabel: str = '', |
334 | 335 | xlabel: str = 'Time in h', |
335 | 336 | fig: go.Figure | None = None, |
| 337 | + facet_by: str | list[str] | None = None, |
| 338 | + animate_by: str | None = None, |
| 339 | + facet_cols: int = 3, |
| 340 | + shared_yaxes: bool = True, |
| 341 | + shared_xaxes: bool = True, |
336 | 342 | ) -> go.Figure: |
337 | 343 | """ |
338 | | - Plot a DataFrame with Plotly, using either stacked bars or stepped lines. |
| 344 | + Plot data with Plotly using facets (subplots) and/or animation for multidimensional data. |
| 345 | +
|
| 346 | + Uses Plotly Express for convenient faceting and animation with automatic styling. |
| 347 | + A `fig` argument is accepted, but every mode currently builds a new Plotly Express figure. |
339 | 348 |
|
340 | 349 | Args: |
341 | | - data: A DataFrame containing the data to plot, where the index represents time (e.g., hours), |
342 | | - and each column represents a separate data series. |
343 | | - mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines, |
344 | | - or 'area' for stacked area charts. |
345 | | - colors: Color specification, can be: |
346 | | - - A string with a colorscale name (e.g., 'viridis', 'plasma') |
347 | | - - A list of color strings (e.g., ['#ff0000', '#00ff00']) |
348 | | - - A dictionary mapping column names to colors (e.g., {'Column1': '#ff0000'}) |
349 | | - title: The title of the plot. |
| 350 | + data: A DataFrame or xarray DataArray/Dataset to plot. |
| 351 | + mode: The plotting mode. Use 'stacked_bar' for stacked bar charts, 'line' for stepped lines, |
| 352 | + 'area' for stacked area charts, or 'grouped_bar' for grouped bar charts. |
| 353 | + colors: Color specification (colormap, list, or dict mapping labels to colors). |
| 354 | + title: The main title of the plot. |
350 | 355 | ylabel: The label for the y-axis. |
351 | 356 | xlabel: The label for the x-axis. |
352 | | - fig: A Plotly figure object to plot on. If not provided, a new figure will be created. |
| 357 | + fig: A Plotly figure object. NOTE: currently unused - every mode assigns a fresh Plotly Express |
| 358 | + figure to `fig`, so a new figure is always created and returned. |
| 359 | + facet_by: Dimension(s) to create facets for. Creates a subplot grid. |
| 360 | + Can be a single dimension name or list of dimensions (max 2 for facet_row and facet_col). |
| 361 | + If the dimension doesn't exist in the data, it will be silently ignored. |
| 362 | + animate_by: Dimension to animate over. Creates animation frames. |
| 363 | + If the dimension doesn't exist in the data, it will be silently ignored. |
| 364 | + facet_cols: Number of columns in the facet grid (used when facet_by is single dimension). |
| 365 | + shared_yaxes: Whether subplots share y-axes. |
| 366 | + shared_xaxes: Whether subplots share x-axes. |
353 | 367 |
|
354 | 368 | Returns: |
355 | | - A Plotly figure object containing the generated plot. |
| 369 | + A Plotly figure object containing the faceted/animated plot. |
| 370 | +
|
| 371 | + Examples: |
| 372 | + Simple plot: |
| 373 | +
|
| 374 | + ```python |
| 375 | + fig = with_plotly(df, mode='area', title='Energy Mix') |
| 376 | + ``` |
| 377 | +
|
| 378 | + Facet by scenario: |
| 379 | +
|
| 380 | + ```python |
| 381 | + fig = with_plotly(ds, facet_by='scenario', facet_cols=2) |
| 382 | + ``` |
| 383 | +
|
| 384 | + Animate by period: |
| 385 | +
|
| 386 | + ```python |
| 387 | + fig = with_plotly(ds, animate_by='period') |
| 388 | + ``` |
| 389 | +
|
| 390 | + Facet and animate: |
| 391 | +
|
| 392 | + ```python |
| 393 | + fig = with_plotly(ds, facet_by='scenario', animate_by='period') |
| 394 | + ``` |
356 | 395 | """ |
357 | 396 | if mode not in ('stacked_bar', 'line', 'area', 'grouped_bar'): |
358 | 397 | raise ValueError(f"'mode' must be one of {{'stacked_bar','line','area', 'grouped_bar'}}, got {mode!r}") |
359 | | - if data.empty: |
360 | | - return go.Figure() |
361 | 398 |
|
362 | | - processed_colors = ColorProcessor(engine='plotly').process_colors(colors, list(data.columns)) |
363 | | - |
364 | | - fig = fig if fig is not None else go.Figure() |
| 399 | + # Handle empty data |
| 400 | + if isinstance(data, pd.DataFrame) and data.empty: |
| 401 | + return go.Figure() |
| 402 | + elif isinstance(data, xr.DataArray) and data.size == 0: |
| 403 | + return go.Figure() |
| 404 | + elif isinstance(data, xr.Dataset) and len(data.data_vars) == 0: |
| 405 | + return go.Figure() |
365 | 406 |
|
366 | | - if mode == 'stacked_bar': |
367 | | - for i, column in enumerate(data.columns): |
368 | | - fig.add_trace( |
369 | | - go.Bar( |
370 | | - x=data.index, |
371 | | - y=data[column], |
372 | | - name=column, |
373 | | - marker=dict( |
374 | | - color=processed_colors[i], line=dict(width=0, color='rgba(0,0,0,0)') |
375 | | - ), # Transparent line with 0 width |
| 407 | + # Warn if fig parameter is used with faceting |
| 408 | + if fig is not None and (facet_by is not None or animate_by is not None): |
| 409 | + logger.warning('The fig parameter is ignored when using faceting or animation. Creating a new figure.') |
| 410 | + fig = None |
| 411 | + |
| 412 | + # Convert xarray to long-form DataFrame for Plotly Express |
| 413 | + if isinstance(data, (xr.DataArray, xr.Dataset)): |
| 414 | + # Convert to long-form (tidy) DataFrame |
| 415 | + # Structure: time, variable, value, scenario, period, ... (all dims as columns) |
| 416 | + if isinstance(data, xr.Dataset): |
| 417 | + # Stack all data variables into long format |
| 418 | + df_long = data.to_dataframe().reset_index() |
| 419 | + # Melt to get: time, scenario, period, ..., variable, value |
| 420 | + id_vars = [dim for dim in data.dims] |
| 421 | + value_vars = list(data.data_vars) |
| 422 | + df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value') |
| 423 | + else: |
| 424 | + # DataArray |
| 425 | + df_long = data.to_dataframe().reset_index() |
| 426 | + if data.name: |
| 427 | + df_long = df_long.rename(columns={data.name: 'value'}) |
| 428 | + else: |
| 429 | + # Unnamed DataArray, find the value column |
| 430 | + value_col = [col for col in df_long.columns if col not in data.dims][0] |
| 431 | + df_long = df_long.rename(columns={value_col: 'value'}) |
| 432 | + df_long['variable'] = data.name or 'data' |
| 433 | + else: |
| 434 | + # Already a DataFrame - convert to long format for Plotly Express |
| 435 | + df_long = data.reset_index() |
| 436 | + if 'time' not in df_long.columns: |
| 437 | + # First column is probably time |
| 438 | + df_long = df_long.rename(columns={df_long.columns[0]: 'time'}) |
| 439 | + # Melt to long format |
| 440 | + id_vars = [ |
| 441 | + col |
| 442 | + for col in df_long.columns |
| 443 | + if col in ['time', 'scenario', 'period'] |
| 444 | + or col in (facet_by if isinstance(facet_by, list) else [facet_by] if facet_by else []) |
| 445 | + ] |
| 446 | + value_vars = [col for col in df_long.columns if col not in id_vars] |
| 447 | + df_long = df_long.melt(id_vars=id_vars, value_vars=value_vars, var_name='variable', value_name='value') |
| 448 | + |
| 449 | + # Validate facet_by and animate_by dimensions exist in the data |
| 450 | + available_dims = [col for col in df_long.columns if col not in ['variable', 'value']] |
| 451 | + |
| 452 | + # Check facet_by dimensions |
| 453 | + if facet_by is not None: |
| 454 | + if isinstance(facet_by, str): |
| 455 | + if facet_by not in available_dims: |
| 456 | + logger.debug( |
| 457 | + f"Dimension '{facet_by}' not found in data. Available dimensions: {available_dims}. " |
| 458 | + f'Ignoring facet_by parameter.' |
376 | 459 | ) |
377 | | - ) |
378 | | - |
379 | | - fig.update_layout( |
380 | | - barmode='relative', |
381 | | - bargap=0, # No space between bars |
382 | | - bargroupgap=0, # No space between grouped bars |
| 460 | + facet_by = None |
| 461 | + elif isinstance(facet_by, list): |
| 462 | + # Filter out dimensions that don't exist |
| 463 | + missing_dims = [dim for dim in facet_by if dim not in available_dims] |
| 464 | + facet_by = [dim for dim in facet_by if dim in available_dims] |
| 465 | + if missing_dims: |
| 466 | + logger.debug( |
| 467 | + f'Dimensions {missing_dims} not found in data. Available dimensions: {available_dims}. ' |
| 468 | + f'Using only existing dimensions: {facet_by if facet_by else "none"}.' |
| 469 | + ) |
| 470 | + if len(facet_by) == 0: |
| 471 | + facet_by = None |
| 472 | + |
| 473 | + # Check animate_by dimension |
| 474 | + if animate_by is not None and animate_by not in available_dims: |
| 475 | + logger.debug( |
| 476 | + f"Dimension '{animate_by}' not found in data. Available dimensions: {available_dims}. " |
| 477 | + f'Ignoring animate_by parameter.' |
383 | 478 | ) |
384 | | - if mode == 'grouped_bar': |
385 | | - for i, column in enumerate(data.columns): |
386 | | - fig.add_trace(go.Bar(x=data.index, y=data[column], name=column, marker=dict(color=processed_colors[i]))) |
| 479 | + animate_by = None |
| 480 | + |
| 481 | + # Setup faceting parameters for Plotly Express |
| 482 | + facet_row = None |
| 483 | + facet_col = None |
| 484 | + if facet_by: |
| 485 | + if isinstance(facet_by, str): |
| 486 | + # Single facet dimension - use facet_col with facet_col_wrap |
| 487 | + facet_col = facet_by |
| 488 | + elif len(facet_by) == 1: |
| 489 | + facet_col = facet_by[0] |
| 490 | + elif len(facet_by) == 2: |
| 491 | + # Two facet dimensions - use facet_row and facet_col |
| 492 | + facet_row = facet_by[0] |
| 493 | + facet_col = facet_by[1] |
| 494 | + else: |
| 495 | + raise ValueError(f'facet_by can have at most 2 dimensions, got {len(facet_by)}') |
| 496 | + |
| 497 | + # Process colors |
| 498 | + all_vars = df_long['variable'].unique().tolist() |
| 499 | + processed_colors = ColorProcessor(engine='plotly').process_colors(colors, all_vars) |
| 500 | + color_discrete_map = {var: color for var, color in zip(all_vars, processed_colors, strict=False)} |
| 501 | + |
| 502 | + # Create plot using Plotly Express based on mode |
| 503 | + common_args = { |
| 504 | + 'data_frame': df_long, |
| 505 | + 'x': 'time', |
| 506 | + 'y': 'value', |
| 507 | + 'color': 'variable', |
| 508 | + 'facet_row': facet_row, |
| 509 | + 'facet_col': facet_col, |
| 510 | + 'animation_frame': animate_by, |
| 511 | + 'color_discrete_map': color_discrete_map, |
| 512 | + 'title': title, |
| 513 | + 'labels': {'value': ylabel, 'time': xlabel, 'variable': ''}, |
| 514 | + } |
387 | 515 |
|
388 | | - fig.update_layout( |
389 | | - barmode='group', |
390 | | - bargap=0.2, # No space between bars |
391 | | - bargroupgap=0, # space between grouped bars |
392 | | - ) |
| 516 | + # Add facet_col_wrap for single facet dimension |
| 517 | + if facet_col and not facet_row: |
| 518 | + common_args['facet_col_wrap'] = facet_cols |
| 519 | + |
| 520 | + if mode == 'stacked_bar': |
| 521 | + fig = px.bar(**common_args) |
| 522 | + fig.update_traces(marker_line_width=0) |
| 523 | + fig.update_layout(barmode='relative', bargap=0, bargroupgap=0) |
| 524 | + elif mode == 'grouped_bar': |
| 525 | + fig = px.bar(**common_args) |
| 526 | + fig.update_layout(barmode='group', bargap=0.2, bargroupgap=0) |
393 | 527 | elif mode == 'line': |
394 | | - for i, column in enumerate(data.columns): |
395 | | - fig.add_trace( |
396 | | - go.Scatter( |
397 | | - x=data.index, |
398 | | - y=data[column], |
399 | | - mode='lines', |
400 | | - name=column, |
401 | | - line=dict(shape='hv', color=processed_colors[i]), |
402 | | - ) |
403 | | - ) |
| 528 | + fig = px.line(**common_args, line_shape='hv') # Stepped lines |
404 | 529 | elif mode == 'area': |
405 | | - data = data.copy() |
406 | | - data[(data > -1e-5) & (data < 1e-5)] = 0 # Preventing issues with plotting |
407 | | - # Split columns into positive, negative, and mixed categories |
408 | | - positive_columns = list(data.columns[(data >= 0).where(~np.isnan(data), True).all()]) |
409 | | - negative_columns = list(data.columns[(data <= 0).where(~np.isnan(data), True).all()]) |
410 | | - negative_columns = [column for column in negative_columns if column not in positive_columns] |
411 | | - mixed_columns = list(set(data.columns) - set(positive_columns + negative_columns)) |
412 | | - |
413 | | - if mixed_columns: |
414 | | - logger.error( |
415 | | - f'Data for plotting stacked lines contains columns with both positive and negative values:' |
416 | | - f' {mixed_columns}. These can not be stacked, and are printed as simple lines' |
417 | | - ) |
| 530 | + # Use Plotly Express to create the area plot (preserves animation, legends, faceting) |
| 531 | + fig = px.area(**common_args, line_shape='hv') |
418 | 532 |
|
419 | | - # Get color mapping for all columns |
420 | | - colors_stacked = {column: processed_colors[i] for i, column in enumerate(data.columns)} |
421 | | - |
422 | | - for column in positive_columns + negative_columns: |
423 | | - fig.add_trace( |
424 | | - go.Scatter( |
425 | | - x=data.index, |
426 | | - y=data[column], |
427 | | - mode='lines', |
428 | | - name=column, |
429 | | - line=dict(shape='hv', color=colors_stacked[column]), |
430 | | - fill='tonexty', |
431 | | - stackgroup='pos' if column in positive_columns else 'neg', |
432 | | - ) |
433 | | - ) |
| 533 | + # Classify each variable based on its values |
| 534 | + variable_classification = {} |
| 535 | + for var in all_vars: |
| 536 | + var_data = df_long[df_long['variable'] == var]['value'] |
| 537 | + var_data_clean = var_data[(var_data < -1e-5) | (var_data > 1e-5)] |
434 | 538 |
|
435 | | - for column in mixed_columns: |
436 | | - fig.add_trace( |
437 | | - go.Scatter( |
438 | | - x=data.index, |
439 | | - y=data[column], |
440 | | - mode='lines', |
441 | | - name=column, |
442 | | - line=dict(shape='hv', color=colors_stacked[column], dash='dash'), |
| 539 | + if len(var_data_clean) == 0: |
| 540 | + variable_classification[var] = 'zero' |
| 541 | + else: |
| 542 | + has_pos, has_neg = (var_data_clean > 0).any(), (var_data_clean < 0).any() |
| 543 | + variable_classification[var] = ( |
| 544 | + 'mixed' if has_pos and has_neg else ('negative' if has_neg else 'positive') |
443 | 545 | ) |
444 | | - ) |
445 | 546 |
|
446 | | - # Update layout for better aesthetics |
| 547 | + # Log warning for mixed variables |
| 548 | + mixed_vars = [v for v, c in variable_classification.items() if c == 'mixed'] |
| 549 | + if mixed_vars: |
| 550 | + logger.warning(f'Variables with both positive and negative values: {mixed_vars}. Plotted as dashed lines.') |
| 551 | + |
| 552 | + all_traces = list(fig.data) |
| 553 | + for frame in fig.frames: |
| 554 | + all_traces.extend(frame.data) |
| 555 | + |
| 556 | + for trace in all_traces: |
| 557 | + trace.stackgroup = variable_classification.get(trace.name, None) |
| 558 | + # No opacity and no line for stacked areas |
| 559 | + if trace.stackgroup is not None: |
| 560 | + if hasattr(trace, 'line') and trace.line.color: |
| 561 | + trace.fillcolor = trace.line.color # Will be solid by default |
| 562 | + trace.line.width = 0 |
| 563 | + |
| 564 | + # Update layout with basic styling (Plotly Express handles sizing automatically) |
447 | 565 | fig.update_layout( |
448 | | - title=title, |
449 | | - yaxis=dict( |
450 | | - title=ylabel, |
451 | | - showgrid=True, # Enable grid lines on the y-axis |
452 | | - gridcolor='lightgrey', # Customize grid line color |
453 | | - gridwidth=0.5, # Customize grid line width |
454 | | - ), |
455 | | - xaxis=dict( |
456 | | - title=xlabel, |
457 | | - showgrid=True, # Enable grid lines on the x-axis |
458 | | - gridcolor='lightgrey', # Customize grid line color |
459 | | - gridwidth=0.5, # Customize grid line width |
460 | | - ), |
461 | | - plot_bgcolor='rgba(0,0,0,0)', # Transparent background |
462 | | - paper_bgcolor='rgba(0,0,0,0)', # Transparent paper background |
463 | | - font=dict(size=14), # Increase font size for better readability |
| 566 | + plot_bgcolor='rgba(0,0,0,0)', |
| 567 | + paper_bgcolor='rgba(0,0,0,0)', |
| 568 | + font=dict(size=12), |
464 | 569 | ) |
465 | 570 |
|
| 571 | + # Unlink subplot axes when sharing is disabled (Plotly Express matches facet axes by default) |
| 572 | + if not shared_yaxes: |
| 573 | + fig.update_yaxes(matches=None) |
| 574 | + if not shared_xaxes: |
| 575 | + fig.update_xaxes(matches=None) |
| 576 | + |
466 | 577 | return fig |
467 | 578 |
|
468 | 579 |
|
|
0 commit comments