diff --git a/ternary/heatmapping.py b/ternary/heatmapping.py index 7487f69..68146a4 100644 --- a/ternary/heatmapping.py +++ b/ternary/heatmapping.py @@ -186,7 +186,7 @@ def polygon_generator(data, scale, style, permutation=None): yield map(project, vertices), value -def heatmap(data, scale, vmin=None, vmax=None, cmap=None, ax=None, +def heatmap(data, scale, vmin=None, vmax=None, adj_vlims=False, cmap=None, ax=None, scientific=False, style='triangular', colorbar=True, permutation=None, use_rgba=False, cbarlabel=None, cb_kwargs=None): """ @@ -203,6 +203,8 @@ def heatmap(data, scale, vmin=None, vmax=None, cmap=None, ax=None, The minimum color value, used to normalize colors. Computed if absent. vmax: float, None The maximum color value, used to normalize colors. Computed if absent. + adj_vlims: bool, False + Redefine min and max color values based on computed averages. cmap: String or matplotlib.colors.Colormap, None The name of the Matplotlib colormap to use. ax: Matplotlib AxesSubplot, None @@ -229,6 +231,12 @@ def heatmap(data, scale, vmin=None, vmax=None, cmap=None, ax=None, if not ax: fig, ax = plt.subplots() + + if vmax or vmin: + vlims_defined = True + else: + vlims_defined = False + # If use_rgba, make the RGBA values numpy arrays so that they can # be averaged. if use_rgba: @@ -247,6 +255,18 @@ def heatmap(data, scale, vmin=None, vmax=None, cmap=None, ax=None, vertices_values = polygon_generator(data, scale, style, permutation=permutation) + # adjust limits of the colormapper if requested, + # but only if user also didn't request specific vlims + if adj_vlims and not vlims_defined: + vals = [] + for _, val in vertices_values: + vals.append(val) + vmin = min(vals) + vmax = max(vals) + + vertices_values = polygon_generator(data, scale, style, + permutation=permutation) + # Draw the polygons and color them for vertices, value in vertices_values: if value is None: @@ -257,7 +277,7 @@ def heatmap(data, scale, vmin=None, vmax=None, cmap=None, ax=None, color = value # rgba tuple (r,g,b,a) all in [0,1] # Matplotlib wants a list of xs and a list of ys xs, ys = unzip(vertices) - ax.fill(xs, ys, facecolor=color, edgecolor=color) + ax.fill(xs, ys, facecolor=color, edgecolor=None) if not cb_kwargs: cb_kwargs = dict() @@ -272,8 +292,8 @@ def heatmap(data, scale, vmin=None, vmax=None, cmap=None, ax=None, def heatmapf(func, scale=10, boundary=True, cmap=None, ax=None, scientific=False, style='triangular', colorbar=True, - permutation=None, vmin=None, vmax=None, cbarlabel=None, - cb_kwargs=None): + permutation=None, vmin=None, vmax=None, adj_vlims=False, + cbarlabel=None, cb_kwargs=None): """ Computes func on heatmap partition coordinates and plots heatmap. In other words, computes the function on lattice points of the simplex (normalized @@ -303,6 +323,8 @@ def heatmapf(func, scale=10, boundary=True, cmap=None, ax=None, The minimum color value, used to normalize colors. vmax: float The maximum color value, used to normalize colors. + adj_vlims: bool, False + Redefine min and max color values based on computed averages. cb_kwargs: dict dict of kwargs to pass to colorbar @@ -318,7 +340,8 @@ def heatmapf(func, scale=10, boundary=True, cmap=None, ax=None, # Pass everything to the heatmapper ax = heatmap(data, scale, cmap=cmap, ax=ax, style=style, scientific=scientific, colorbar=colorbar, - permutation=permutation, vmin=vmin, vmax=vmax, + permutation=permutation, vmin=vmin, vmax=vmax, + adj_vlims=adj_vlims, cbarlabel=cbarlabel, cb_kwargs=cb_kwargs) return ax @@ -347,8 +370,8 @@ def svg_polygon(coordinates, color): return polygon -def svg_heatmap(data, scale, filename, vmax=None, vmin=None, style='h', - permutation=None, cmap=None): +def svg_heatmap(data, scale, filename, vmax=None, vmin=None, adj_vlims=False, + style='h', permutation=None, cmap=None): """ Create a heatmap in SVG format. Intended for use with very large datasets, which would require large amounts of RAM using matplotlib. You can convert @@ -370,6 +393,8 @@ def svg_heatmap(data, scale, filename, vmax=None, vmin=None, style='h', The minimum color value, used to normalize colors. vmax: float The maximum color value, used to normalize colors. + adj_vlims: bool, False + Redefine min and max color values based on computed averages. cmap: String or matplotlib.colors.Colormap, None The name of the Matplotlib colormap to use. style: String, "h" diff --git a/ternary/ternary_axes_subplot.py b/ternary/ternary_axes_subplot.py index dd56f9b..f28a814 100644 --- a/ternary/ternary_axes_subplot.py +++ b/ternary/ternary_axes_subplot.py @@ -436,7 +436,7 @@ def plot_colored_trajectory(self, points, cmap=None, **kwargs): def heatmap(self, data, scale=None, cmap=None, scientific=False, style='triangular', colorbar=True, use_rgba=False, - vmin=None, vmax=None, cbarlabel=None, cb_kwargs=None): + vmin=None, vmax=None, adj_vlims=False, cbarlabel=None, cb_kwargs=None): permutation = self._permutation if not scale: scale = self.get_scale() @@ -446,12 +446,12 @@ def heatmap(self, data, scale=None, cmap=None, scientific=False, heatmapping.heatmap(data, scale, cmap=cmap, style=style, ax=ax, scientific=scientific, colorbar=colorbar, permutation=permutation, use_rgba=use_rgba, - vmin=vmin, vmax=vmax, cbarlabel=cbarlabel, + vmin=vmin, vmax=vmax, adj_vlims=adj_vlims, cbarlabel=cbarlabel, cb_kwargs=cb_kwargs) def heatmapf(self, func, scale=None, cmap=None, boundary=True, style='triangular', colorbar=True, scientific=False, - vmin=None, vmax=None, cbarlabel=None, cb_kwargs=None): + vmin=None, vmax=None, adj_vlims=False, cbarlabel=None, cb_kwargs=None): if not scale: scale = self.get_scale() if style.lower()[0] == 'd':