diff --git a/autoarray/plot/wrap/two_d/grid_plot.py b/autoarray/plot/wrap/two_d/grid_plot.py index c63a99216..e415a7476 100644 --- a/autoarray/plot/wrap/two_d/grid_plot.py +++ b/autoarray/plot/wrap/two_d/grid_plot.py @@ -98,7 +98,11 @@ def plot_grid_list(self, grid_list: Union[List[Grid2D], List[Grid2DIrregular]]): config_dict.pop("c") try: - for grid in grid_list: - plt.plot(grid[:, 1], grid[:, 0], c=next(color), **config_dict) + # for critical curves/caustics, grid_list[0] corresponds to tangential curves + # and grid_list[1] corresponds to radial curves + for i in range(len(grid_list)): + color_i = next(color) # one color for all tangential curves and another color for all radial curves + for grid in grid_list[i]: + plt.plot(grid[:, 1], grid[:, 0], c=color_i, **config_dict) except IndexError: pass