Skip to content

Commit 1d4f09c

Browse files
committed
update center for DotClustermapPlotter
1 parent b597d77 commit 1d4f09c

File tree

5 files changed

+92
-89
lines changed

5 files changed

+92
-89
lines changed

PyComplexHeatmap/clustermap.py

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
_index_to_label,
1919
_index_to_ticklabels,
2020
plot_legend_list,
21-
get_colormap,
21+
get_colormap,adjust_cmap,
2222
evaluate_bezier,getControlPoints
2323
)
2424

@@ -615,34 +615,8 @@ def plot_heatmap(
615615
else:
616616
vmax = np.nanmax(calc_data)
617617

618-
# Choose default colormaps if not provided
619-
if isinstance(cmap, str):
620-
try:
621-
cmap = get_colormap(cmap).copy()
622-
except:
623-
cmap = get_colormap(cmap)
624-
625-
cmap.set_bad(color=na_col) # set the color for NaN values
626-
# Recenter a divergent colormap
627-
if center is not None:
628-
# bad = cmap(np.ma.masked_invalid([np.nan]))[0] # set the first color as the na_color
629-
under = cmap(-np.inf)
630-
over = cmap(np.inf)
631-
under_set = under != cmap(0)
632-
over_set = over != cmap(cmap.N - 1)
633-
634-
vrange = max(vmax - center, center - vmin)
635-
normlize = matplotlib.colors.Normalize(center - vrange, center + vrange)
636-
cmin, cmax = normlize([vmin, vmax])
637-
cc = np.linspace(cmin, cmax, 256)
638-
cmap = matplotlib.colors.ListedColormap(cmap(cc))
639-
# self.cmap.set_bad(bad)
640-
if under_set:
641-
cmap.set_under(
642-
under
643-
) # set the color of -np.inf as the color for low out-of-range values.
644-
if over_set:
645-
cmap.set_over(over)
618+
cmap=adjust_cmap(cmap,vmin=vmin,vmax=vmax,center=center,
619+
na_col=na_col)
646620

647621
# Sort out the annotations
648622
if annot is None or annot is False:
@@ -2326,12 +2300,6 @@ def collect_legends(self):
23262300
if annotation.label_max_width > self.label_max_width:
23272301
self.label_max_width = annotation.label_max_width
23282302
if self.legend:
2329-
# vmax = self.kwargs.get(
2330-
# "vmax", np.nanmax(self.data2d[self.data2d != np.inf])
2331-
# )
2332-
# vmin = self.kwargs.get(
2333-
# "vmin", np.nanmin(self.data2d[self.data2d != -np.inf])
2334-
# )
23352303
self.legend_kws.setdefault("vmin", self.kwargs.get('vmin')) #round(vmin, 2))
23362304
self.legend_kws.setdefault("vmax", self.kwargs.get('vmax')) #round(vmax, 2))
23372305
self.legend_kws.setdefault("center", self.kwargs.get('center',None))

PyComplexHeatmap/dotHeatmap.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
import numpy as np
66
import matplotlib
77
import matplotlib.pylab as plt
8-
from .utils import mm2inch, plot_legend_list, despine, get_colormap
8+
from .utils import mm2inch, despine, get_colormap,adjust_cmap
99
from .clustermap import ClusterMapPlotter
10-
from matplotlib.ticker import MultipleLocator, FormatStrFormatter
1110

1211

1312
# =============================================================================
@@ -21,13 +20,13 @@ def scale(values, vmin=None, vmax=None):
2120
delta = vmax - vmin
2221
return [(j - vmin) / delta for j in values]
2322

24-
2523
# =============================================================================
2624
def dotHeatmap2d(
2725
data,
2826
hue=None,
2927
vmin=None,
3028
vmax=None,
29+
center=None,
3130
ax=None,
3231
colors=None,
3332
cmap=None,
@@ -154,6 +153,7 @@ def dotHeatmap2d(
154153
kwargs.setdefault(
155154
"norm", matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=True)
156155
)
156+
157157
kwargs["cmap"] = cmap
158158
if hue is None:
159159
#plot using c
@@ -163,6 +163,8 @@ def dotHeatmap2d(
163163
if df1.shape[0] == 0:
164164
continue
165165
kwargs["marker"] = mk
166+
if isinstance(kwargs["cmap"],str) and not center is None:
167+
kwargs["cmap"]=adjust_cmap(kwargs["cmap"],vmin=vmin,vmax=vmax,center=center)
166168
ax.scatter(
167169
x=df1.X.values,
168170
y=df1.Y.values,
@@ -177,6 +179,8 @@ def dotHeatmap2d(
177179
df1 = df.loc[df.Hue == h].copy()
178180
if df1.shape[0] == 0:
179181
continue
182+
if isinstance(cmap[h], str) and not center is None:
183+
cmap[h] = adjust_cmap(cmap[h], vmin=vmin, vmax=vmax, center=center)
180184
kwargs["cmap"] = cmap[h]
181185
for mk in df1.Markers.unique():
182186
# df2 = df1.query("Markers==@mk").copy()

PyComplexHeatmap/utils.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,36 @@ def cluster_labels(labels=None, xticks=None, majority=True):
395395
x = [np.mean(clusters_x[i]) for i in clusters_x]
396396
return labels, x
397397

398+
def adjust_cmap(cmap,vmin,vmax,center=None,na_col='white'):
399+
# Choose default colormaps if not provided
400+
if isinstance(cmap, str):
401+
try:
402+
cmap = get_colormap(cmap).copy()
403+
except:
404+
cmap = get_colormap(cmap)
405+
406+
cmap.set_bad(color=na_col) # set the color for NaN values
407+
# Recenter a divergent colormap
408+
if center is not None:
409+
# bad = cmap(np.ma.masked_invalid([np.nan]))[0] # set the first color as the na_color
410+
under = cmap(-np.inf)
411+
over = cmap(np.inf)
412+
under_set = under != cmap(0)
413+
over_set = over != cmap(cmap.N - 1)
398414

415+
vrange = max(vmax - center, center - vmin)
416+
normlize = matplotlib.colors.Normalize(center - vrange, center + vrange)
417+
cmin, cmax = normlize([vmin, vmax])
418+
cc = np.linspace(cmin, cmax, 256)
419+
cmap = matplotlib.colors.ListedColormap(cmap(cc))
420+
# self.cmap.set_bad(bad)
421+
if under_set:
422+
cmap.set_under(
423+
under
424+
) # set the color of -np.inf as the color for low out-of-range values.
425+
if over_set:
426+
cmap.set_over(over)
427+
return cmap
399428
# =============================================================================
400429
def plot_color_dict_legend(
401430
D=None, ax=None, title=None, color_text=True, label_side="right", kws=None
@@ -518,17 +547,19 @@ def plot_cmap_legend(
518547
vcenter= (vmax + vmin) / 2
519548
center=cbar_kws.pop("center",None)
520549
if center is None:
521-
cbar_kws.setdefault("ticks", [vmin, vcenter, vmax])
550+
center=vcenter
522551
m = plt.cm.ScalarMappable(
523552
norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap
524553
)
525554
else:
526555
m = plt.cm.ScalarMappable(
527-
norm=matplotlib.colors.CenteredNorm(vcenter=center), cmap=cmap
556+
norm=matplotlib.colors.TwoSlopeNorm(center,vmin=vmin, vmax=vmax), cmap=cmap
528557
)
558+
cbar_kws.setdefault("ticks", [vmin, center, vmax])
529559
cax.yaxis.set_label_position(label_side)
530560
cax.yaxis.set_ticks_position(label_side)
531561
cbar = ax.figure.colorbar(m, cax=cax, **cbar_kws) # use_gridspec=True
562+
# cbar.set_ticks([vmin,center,vmax])
532563
# cbar.outline.set_color('white')
533564
# cbar.outline.set_linewidth(2)
534565
# cbar.dividers.set_color('red')

notebooks/clustermap.ipynb

Lines changed: 6 additions & 6 deletions
Large diffs are not rendered by default.

notebooks/dotHeatmap.ipynb

Lines changed: 43 additions & 43 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)