Skip to content

Commit b4a0575

Browse files
committed
Updated the Matplotlib plots and added logic to handle negative points
1 parent a37b6d6 commit b4a0575

1 file changed

Lines changed: 69 additions & 26 deletions

File tree

RAT/plotting.py

Lines changed: 69 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from RAT.rat_core import PlotEventData, makeSLDProfileXY
22

3+
import pyqtgraph as pg
4+
35
from plotly.subplots import make_subplots
46
import plotly.graph_objects as go
57

68
import matplotlib.pyplot as plt
79
from matplotlib.pyplot import draw, show
10+
import numpy as np
811

912

1013
def plot_ref_SLD_helper_plotly(data: PlotEventData):
@@ -17,16 +20,17 @@ def plot_ref_SLD_helper_plotly(data: PlotEventData):
1720
----------
1821
data : PlotEventData
1922
The plot event data that contains all the information
20-
to generate the ref and sld plots
23+
to generate the ref and sld plots.
2124
"""
2225

2326
# Create the figure with 2 sub plots
2427
rat_plot = go.FigureWidget(
2528
make_subplots(
2629
rows=1,
2730
cols=2,
28-
subplot_titles=("Reflectivity Plot",
29-
"Scattering Lenght Density Plot")))
31+
subplot_titles=(
32+
"Reflectivity Plot",
33+
"Scattering Lenght Density Plot")))
3034

3135
for i, (r, sd, sld, layer) in enumerate(zip(data.reflectivity,
3236
data.shiftedData,
@@ -47,12 +51,21 @@ def plot_ref_SLD_helper_plotly(data: PlotEventData):
4751

4852
# Plot the errors on plot (1,1)
4953
if data.dataPresent[i]:
50-
rat_plot.add_trace(go.Scatter(x=sd[0],
51-
y=sd[1]/div,
54+
55+
sd_x = sd[0]
56+
sd_y, sd_e = map(lambda x: x/div, (sd[1], sd[2]))
57+
58+
# Remove values where data - error will be negative
59+
indices_to_remove = np.flip(np.nonzero(0 > sd_y - sd_e)[0])
60+
sd_x, sd_y, sd_e = map(lambda x: np.delete(x, indices_to_remove),
61+
(sd_x, sd_y, sd_e))
62+
63+
rat_plot.add_trace(go.Scatter(x=sd_x,
64+
y=sd_y,
5265
mode='markers',
5366
error_y=dict(
5467
type='data',
55-
array = sd[2]/div),
68+
array = sd_e),
5669
showlegend=False))
5770

5871
# Plot the scattering lenght densities (slds) on plot (1,2)
@@ -65,10 +78,13 @@ def plot_ref_SLD_helper_plotly(data: PlotEventData):
6578
legendgroup = '2')
6679

6780
if data.resample[i] == 1 or data.modelType == 'custom xy':
81+
new_layer = [[a, b, c] for a, b, c in zip(layer[0],
82+
layer[1],
83+
layer[2])]
6884
new = makeSLDProfileXY(layer[1][1],
6985
layer[1][-1],
7086
data.ssubs[i],
71-
layer,
87+
new_layer,
7288
len(layer[0]),
7389
1.0)
7490

@@ -105,7 +121,7 @@ def plot_ref_SLD_helper_matplotlib(data: PlotEventData):
105121
"""
106122

107123
# Create the figure with 2 sub plots
108-
fig, (ref_plot, sld_plot) = plt.subplots(1, 2)
124+
_, (ref_plot, sld_plot) = plt.subplots(1, 2)
109125
draw()
110126

111127
for i, (r, sd, sld, layer) in enumerate(zip(data.reflectivity,
@@ -119,51 +135,78 @@ def plot_ref_SLD_helper_matplotlib(data: PlotEventData):
119135
# Plot the reflectivity on plot (1,1)
120136
ref_plot.plot(r[0],
121137
r[1]/div,
122-
label=f'ref {i+1}')
123-
138+
label=f'ref {i+1}',
139+
linewidth=2)
124140

125141
# Plot the errors on plot (1,1)
126142
if data.dataPresent[i]:
127-
ref_plot.errorbar(x=sd[0],
128-
y=sd[1]/div,
129-
yerr=sd[2]/div)
143+
144+
sd_x = sd[0]
145+
sd_y, sd_e = map(lambda x: x/div, (sd[1], sd[2]))
146+
147+
# Remove values where data - error will be negative
148+
indices_to_remove = np.flip(np.nonzero(0 > sd_y - sd_e)[0])
149+
sd_x, sd_y, sd_e = map(lambda x: np.delete(x, indices_to_remove),
150+
(sd_x, sd_y, sd_e))
151+
152+
ref_plot.errorbar(x=sd_x,
153+
y=sd_y,
154+
yerr=sd_e,
155+
fmt='none',
156+
color='red',
157+
ecolor='red',
158+
elinewidth=1,
159+
capsize=2)
130160

131161
# Plot the scattering lenght densities (slds) on plot (1,2)
132162
for j in range(1, sld.shape[0]):
133-
sld_plot.scatter(y=sld[j],
134-
x=sld[0],
135-
label=f'sld {i+1}')
163+
sld_plot.plot(sld[0],
164+
sld[j],
165+
label=f'sld {i+1}')
136166

137167
if data.resample[i] == 1 or data.modelType == 'custom xy':
168+
new_layer = [[a, b, c] for a, b, c in zip(layer[0],
169+
layer[1],
170+
layer[2])]
138171
new = makeSLDProfileXY(layer[1][1],
139172
layer[1][-1],
140173
data.ssubs[i],
141-
layer,
174+
new_layer,
142175
len(layer[0]),
143176
1.0)
144177

145-
sld_plot.scatter(y=[row[1] for row in new],
146-
x=[row[0]-49 for row in new])
178+
sld_plot.plot([row[0]-49 for row in new],
179+
[row[1] for row in new])
147180

148181
# Convert the axis to log
149182
ref_plot.set_yscale('log')
150183
ref_plot.set_xscale('log')
151184
ref_plot.set_xlabel('Qz')
152185
ref_plot.set_ylabel('Ref')
153186
ref_plot.legend()
187+
ref_plot.grid()
154188

155189
# Label the axis and disable legend
156190
sld_plot.set_xlabel('Z')
157191
sld_plot.set_ylabel('SLD')
158192
sld_plot.legend()
193+
sld_plot.grid()
159194

160195
# Show plot
161196
show()
162197

163198

164-
def plot_ref_SLD_helper_pyqtgraph(data: PlotEventData, noDelay: bool = True):
199+
def plot_ref_SLD_helper_pyqtgraph(data: PlotEventData):
165200
"""
166-
Helper function to make it eaier to plot from event
201+
Helper function to make it easier to plot from event.
202+
Uses the pyqt library to plot the reflectivity and the
203+
SLD profiles.
204+
205+
Parameters
206+
----------
207+
data : PlotEventData
208+
The plot event data that contains all the information
209+
to generate the ref and sld plots
167210
"""
168211

169212
# Plot the reflectivity
@@ -179,7 +222,7 @@ def plot_ref_SLD_helper_pyqtgraph(data: PlotEventData, noDelay: bool = True):
179222
# layout.addWidget(label, 1, 0)
180223

181224

182-
rat_plot = plotly.tools.make_subplots(rows=1, cols=2)
225+
rat_plot = pg.tools.make_subplots(rows=1, cols=2)
183226

184227
plotWidget = pg.plot(title="Reflectivity Algorithms Toolbox (RAT) - plots")
185228
plotWidget.setLogMode(True, True)
@@ -189,10 +232,10 @@ def plot_ref_SLD_helper_pyqtgraph(data: PlotEventData, noDelay: bool = True):
189232

190233
# refplot = plotWidget.AddPlot(r[0], list(np.divide(r[1], div)), symbol='o')
191234

192-
scatter = plotly.graph_objs.Scatter(
193-
y=r[1]/div,
194-
x=r[0])
195-
rat_plot.append_trace(scatter, 1, 1)
235+
# scatter = plotly.graph_objs.Scatter(
236+
# y=r[1]/div,
237+
# x=r[0])
238+
# rat_plot.append_trace(scatter, 1, 1)
196239

197240

198241
# if data.dataPresent[i]:

0 commit comments

Comments
 (0)