diff --git a/ternary/plotting.py b/ternary/plotting.py index 10550ea..12e539a 100644 --- a/ternary/plotting.py +++ b/ternary/plotting.py @@ -68,8 +68,7 @@ def plot(points, ax=None, permutation=None, **kwargs): if not ax: fig, ax = plt.subplots() xs, ys = project_sequence(points, permutation=permutation) - ax.plot(xs, ys, **kwargs) - return ax + return ax.plot(xs, ys, **kwargs) def plot_colored_trajectory(points, cmap=None, ax=None, permutation=None, @@ -111,9 +110,7 @@ def plot_colored_trajectory(points, cmap=None, ax=None, permutation=None, line_segments = matplotlib.collections.LineCollection(segments, cmap=cmap, **kwargs) line_segments.set_array(np.arange(len(segments))) - ax.add_collection(line_segments) - - return ax + return ax.add_collection(line_segments) def scatter(points, ax=None, permutation=None, colorbar=False, colormap=None, @@ -145,7 +142,7 @@ def scatter(points, ax=None, permutation=None, colorbar=False, colormap=None, if not ax: fig, ax = plt.subplots() xs, ys = project_sequence(points, permutation=permutation) - ax.scatter(xs, ys, vmin=vmin, vmax=vmax, **kwargs) + ax_points = ax.scatter(xs, ys, vmin=vmin, vmax=vmax, **kwargs) if colorbar and (colormap != None): if cb_kwargs != None: @@ -155,4 +152,4 @@ def scatter(points, ax=None, permutation=None, colorbar=False, colormap=None, colorbar_hack(ax, vmin, vmax, colormap, scientific=scientific, cbarlabel=cbarlabel) - return ax + return ax_points diff --git a/ternary/ternary_axes_subplot.py b/ternary/ternary_axes_subplot.py index dd56f9b..2572a97 100644 --- a/ternary/ternary_axes_subplot.py +++ b/ternary/ternary_axes_subplot.py @@ -320,7 +320,7 @@ def close(self): def legend(self, *args, **kwargs): ax = self.get_axes() - ax.legend(*args, **kwargs) + return ax.legend(*args, **kwargs) def savefig(self, filename, **kwargs): self._redraw_labels() @@ -425,14 +425,14 @@ def scatter(self, points, **kwargs): def plot(self, points, **kwargs): ax = self.get_axes() permutation = self._permutation - plotting.plot(points, ax=ax, permutation=permutation, - **kwargs) + return plotting.plot(points, ax=ax, permutation=permutation, + **kwargs) def plot_colored_trajectory(self, points, cmap=None, **kwargs): ax = self.get_axes() permutation = self._permutation - plotting.plot_colored_trajectory(points, cmap=cmap, ax=ax, - permutation=permutation, **kwargs) + return plotting.plot_colored_trajectory(points, cmap=cmap, ax=ax, + permutation=permutation, **kwargs) def heatmap(self, data, scale=None, cmap=None, scientific=False, style='triangular', colorbar=True, use_rgba=False,