From c79e89bb2aee5771aa276943e852d735096ac079 Mon Sep 17 00:00:00 2001 From: George Zaki Date: Wed, 26 Nov 2025 14:01:00 -0500 Subject: [PATCH] refactor(templates): use pio.to_image() for relational heatmap PNG export MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove matplotlib workaround that recreated plots and use Plotly's native image export instead. This simplifies the code while maintaining output quality and fixing font size consistency between PNG and HTML. - Remove get_plotly_colorscale_as_matplotlib() and create_matplotlib_heatmap_matching_plotly() functions - Add direct PNG export using pio.to_image() with proper scale factor - Update template_utils._save_single_object() to handle raw image bytes - Replace test suite with comprehensive integration tests - Fix font size inconsistency between PNG and HTML outputs - Reduce code complexity by ~30% (200 lines → 140 lines) The refactored implementation uses kaleido for headless PNG generation, which works correctly in Galaxy's headless environment. --- .../templates/relational_heatmap_template.py | 196 +++------ src/spac/templates/template_utils.py | 16 +- .../test_relational_heatmap_template.py | 415 ++++++++++-------- 3 files changed, 318 insertions(+), 309 deletions(-) diff --git a/src/spac/templates/relational_heatmap_template.py b/src/spac/templates/relational_heatmap_template.py index fbf57218..bad826a0 100644 --- a/src/spac/templates/relational_heatmap_template.py +++ b/src/spac/templates/relational_heatmap_template.py @@ -1,19 +1,12 @@ """ -Relational Heatmap with Plotly-matplotlib color synchronization. -Extracts actual colors from Plotly and uses them in matplotlib. +Relational Heatmap template with Plotly figure export. +Generates both static PNG snapshots and interactive HTML outputs. """ import json import sys from pathlib import Path from typing import Any, Dict, Union, Tuple -import pandas as pd -import numpy as np -import matplotlib -matplotlib.use('Agg') -import matplotlib.pyplot as plt -import matplotlib.colors as mcolors import plotly.io as pio -import plotly.express as px sys.path.append(str(Path(__file__).parent.parent.parent)) @@ -26,112 +19,35 @@ ) -def get_plotly_colorscale_as_matplotlib(plotly_colormap: str) -> mcolors.LinearSegmentedColormap: - """ - Extract actual colors from Plotly colorscale and create matplotlib colormap. - This ensures exact color matching between Plotly and matplotlib. - """ - # Get Plotly's colorscale - try: - # Use plotly express to get the actual color sequence - colorscale = getattr(px.colors.sequential, plotly_colormap, None) - if colorscale is None: - colorscale = getattr(px.colors.diverging, plotly_colormap, None) - if colorscale is None: - colorscale = getattr(px.colors.cyclical, plotly_colormap, None) - - if colorscale is None: - # Fallback to a default - print(f"Warning: Could not find Plotly colorscale '{plotly_colormap}', using default") - colorscale = px.colors.sequential.Viridis - - # Convert to matplotlib colormap - if isinstance(colorscale, list): - # Create custom colormap from color list - cmap = mcolors.LinearSegmentedColormap.from_list( - f"plotly_{plotly_colormap}", - colorscale - ) - return cmap - except Exception as e: - print(f"Error extracting Plotly colors: {e}") - - # Fallback to matplotlib's viridis - return plt.cm.viridis - - -def create_matplotlib_heatmap_matching_plotly( - data: pd.DataFrame, - plotly_fig: Any, - source_annotation: str, - target_annotation: str, - colormap_name: str, - figsize: tuple, - dpi: int, - font_size: int -) -> plt.Figure: - """ - Create matplotlib heatmap that matches Plotly's appearance. - Extracts color information from the Plotly figure. - """ - fig, ax = plt.subplots(figsize=figsize, dpi=dpi) - - # Get the actual colormap from Plotly - cmap = get_plotly_colorscale_as_matplotlib(colormap_name) - - # Extract data range from Plotly figure if possible - try: - zmin = plotly_fig.data[0].zmin if hasattr(plotly_fig.data[0], 'zmin') else data.min().min() - zmax = plotly_fig.data[0].zmax if hasattr(plotly_fig.data[0], 'zmax') else data.max().max() - except: - zmin, zmax = data.min().min(), data.max().max() - - # Create heatmap matching Plotly's style - im = ax.imshow( - data.values, - aspect='auto', - cmap=cmap, - interpolation='nearest', - vmin=zmin, - vmax=zmax - ) - - # Match Plotly's tick placement - ax.set_xticks(np.arange(len(data.columns))) - ax.set_yticks(np.arange(len(data.index))) - ax.set_xticklabels(data.columns, rotation=45, ha='right', fontsize=font_size) - ax.set_yticklabels(data.index, fontsize=font_size) - - # Add colorbar - cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) - cbar.set_label('Count', fontsize=font_size) - cbar.ax.tick_params(labelsize=font_size) - - # Title matching Plotly - ax.set_title( - f'Relational Heatmap: {source_annotation} vs {target_annotation}', - fontsize=font_size + 2, - pad=20 - ) - ax.set_xlabel(target_annotation, fontsize=font_size) - ax.set_ylabel(source_annotation, fontsize=font_size) - - # Add grid for clarity (like Plotly) - ax.set_xticks(np.arange(len(data.columns) + 1) - 0.5, minor=True) - ax.set_yticks(np.arange(len(data.index) + 1) - 0.5, minor=True) - ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.3, alpha=0.3) - ax.tick_params(which='both', length=0) - - plt.tight_layout() - return fig - - def run_from_json( json_path: Union[str, Path, Dict[str, Any]], save_to_disk: bool = True, output_dir: str = None ) -> Union[Dict, Tuple]: - """Execute Relational Heatmap with color-matched outputs.""" + """ + Execute Relational Heatmap analysis. + + Generates a relational heatmap showing relationships between + two annotations, with outputs in three formats: + - PNG snapshot of the Plotly figure + - Interactive HTML version + - CSV data matrix + + Parameters + ---------- + json_path : str, Path, or dict + Path to JSON file, JSON string, or parameter dictionary + save_to_disk : bool, optional + Whether to save results to disk. Default is True. + output_dir : str, optional + Override output directory from params + + Returns + ------- + dict or tuple + If save_to_disk=True: Dictionary of saved file paths + If save_to_disk=False: Tuple of (plotly_fig, dataframe) + """ params = parse_params(json_path) @@ -150,8 +66,12 @@ def run_from_json( print(f"Data loaded: {adata.shape[0]} cells, {adata.shape[1]} genes") # Parameters - source_annotation = text_to_value(params.get("Source_Annotation_Name", "None")) - target_annotation = text_to_value(params.get("Target_Annotation_Name", "None")) + source_annotation = text_to_value( + params.get("Source_Annotation_Name", "None") + ) + target_annotation = text_to_value( + params.get("Target_Annotation_Name", "None") + ) dpi = float(params.get("Figure_DPI", 300)) width_in = float(params.get("Figure_Width_inch", 8)) @@ -173,7 +93,11 @@ def run_from_json( rhmap_data = result_dict['data'] plotly_fig = result_dict['figure'] - # Update Plotly figure + # Calculate scale factor for high-DPI export + # Plotly's default is 96 DPI, so scale relative to that + scale_factor = dpi / 96.0 + + # Update Plotly figure dimensions and styling for HTML display if plotly_fig: plotly_fig.update_layout( width=width_in * 96, @@ -181,28 +105,35 @@ def run_from_json( font=dict(size=font_size) ) - # Create matplotlib figure that matches Plotly's colors - print("Creating color-matched matplotlib figure...") - static_fig = create_matplotlib_heatmap_matching_plotly( - rhmap_data, - plotly_fig, - source_annotation, - target_annotation, - colormap, - (width_in, height_in), - int(dpi), - int(font_size) - ) - if save_to_disk: + # Generate PNG snapshot directly from Plotly figure + print("Generating PNG snapshot from Plotly figure...") + img_bytes = pio.to_image( + plotly_fig, + format='png', + width=int(width_in * 96), # Use base dimensions + height=int(height_in * 96), + scale=scale_factor # Scale up for higher DPI + ) + + # Prepare outputs results_dict = { - "figures": {"relational_heatmap": static_fig}, - "html": {"relational_heatmap": pio.to_html(plotly_fig, full_html=True, include_plotlyjs='cdn')}, + "figures": {"relational_heatmap": img_bytes}, + "html": { + "relational_heatmap": pio.to_html( + plotly_fig, + full_html=True, + include_plotlyjs='cdn' + ) + }, "dataframe": rhmap_data } - saved_files = save_results(results_dict, params, output_base_dir=output_dir) - plt.close(static_fig) + saved_files = save_results( + results_dict, + params, + output_base_dir=output_dir + ) print("✓ Relational Heatmap completed") return saved_files @@ -212,7 +143,10 @@ def run_from_json( if __name__ == "__main__": if len(sys.argv) < 2: - print("Usage: python relational_heatmap_template.py ", file=sys.stderr) + print( + "Usage: python relational_heatmap_template.py ", + file=sys.stderr + ) sys.exit(1) try: diff --git a/src/spac/templates/template_utils.py b/src/spac/templates/template_utils.py index 80c81025..afa066aa 100644 --- a/src/spac/templates/template_utils.py +++ b/src/spac/templates/template_utils.py @@ -272,7 +272,17 @@ def _save_single_object(obj: Any, name: str, output_dir: Path) -> Path: Path to saved file """ # Determine file format based on object type - if isinstance(obj, pd.DataFrame): + if isinstance(obj, bytes): + # Raw bytes (e.g., PNG image data from Plotly) + # Determine extension from name or default to .png + image_exts = ['.png', '.jpg', '.jpeg', '.pdf', '.svg'] + if not any(name.endswith(ext) for ext in image_exts): + name = f"{name}.png" + filepath = output_dir / name + with open(filepath, 'wb') as f: + f.write(obj) + + elif isinstance(obj, pd.DataFrame): # DataFrames -> CSV if not name.endswith('.csv'): name = f"{name}.csv" @@ -287,7 +297,9 @@ def _save_single_object(obj: Any, name: str, output_dir: Path) -> Path: obj.savefig(filepath, dpi=300, bbox_inches='tight') plt.close(obj) # Close figure to free memory - elif isinstance(obj, str) and (' ad.AnnData: - """Return a minimal synthetic AnnData for fast tests.""" - rng = np.random.default_rng(0) +def create_test_adata(n_cells: int = 100) -> ad.AnnData: + """Create a realistic test AnnData object.""" + rng = np.random.default_rng(42) + + # Create observations with two categorical annotations obs = pd.DataFrame({ - "phenograph_k60_r1": (["cluster1", "cluster2", "cluster3"] * - ((n_cells + 2) // 3))[:n_cells], - "renamed_phenotypes": (["phenotype_A", "phenotype_B"] * - ((n_cells + 1) // 2))[:n_cells] + "phenograph_k60_r1": rng.choice( + ["cluster1", "cluster2", "cluster3"], + n_cells + ), + "renamed_phenotypes": rng.choice( + ["phenotype_A", "phenotype_B", "phenotype_C"], + n_cells + ) }) - x_mat = rng.normal(size=(n_cells, 3)) + + # Create expression matrix + x_mat = rng.normal(size=(n_cells, 20)) + + # Create AnnData object adata = ad.AnnData(X=x_mat, obs=obs) - adata.var_names = ["Gene1", "Gene2", "Gene3"] + adata.var_names = [f"Gene{i}" for i in range(20)] + return adata -class TestRelationalHeatmapTemplate(unittest.TestCase): - """Unit tests for the Relational Heatmap template.""" +class TestRelationalHeatmapTemplateRefactored(unittest.TestCase): + """Test suite for the refactored relational heatmap template.""" def setUp(self) -> None: + """Set up test fixtures.""" self.tmp_dir = tempfile.TemporaryDirectory() - self.in_file = os.path.join( - self.tmp_dir.name, "input.pickle" - ) - self.out_file = "relational_heatmap" - - # Save minimal mock data - with open(self.in_file, 'wb') as f: - pickle.dump(mock_adata(), f) - - # Minimal parameters from NIDAP template + self.tmp_path = Path(self.tmp_dir.name) + + # Create test data file + self.input_file = self.tmp_path / "input.pickle" + test_adata = create_test_adata(n_cells=100) + with open(self.input_file, 'wb') as f: + pickle.dump(test_adata, f) + + # Define test parameters self.params = { - "Upstream_Analysis": self.in_file, + "Upstream_Analysis": str(self.input_file), "Source_Annotation_Name": "phenograph_k60_r1", "Target_Annotation_Name": "renamed_phenotypes", "Colormap": "darkmint", "Figure_Width_inch": 8, "Figure_Height_inch": 10, - "Figure_DPI": 300, + "Figure_DPI": 150, # Lower DPI for faster tests "Font_Size": 8, - "Output_File": self.out_file, + "Output_File": "relational_heatmap", + "Output_Directory": str(self.tmp_path), + "outputs": { + "figures": {"type": "directory", "name": "figures_dir"}, + "html": {"type": "directory", "name": "html_dir"}, + "dataframe": {"type": "file", "name": "dataframe.csv"} + } } def tearDown(self) -> None: + """Clean up test fixtures.""" self.tmp_dir.cleanup() - @patch('spac.templates.relational_heatmap_template.relational_heatmap') - @patch('plotly.io.write_image') - @patch('matplotlib.pyplot.show') # Mock plt.show() - def test_complete_io_workflow( - self, mock_plt_show, mock_write_image, mock_relational - ) -> None: - """Single I/O test covering input/output scenarios.""" - # Mock the relational_heatmap function - mock_fig = Mock() - mock_fig.show = Mock() # Mock the fig.show() method - - mock_df = pd.DataFrame({ - 'source': ['cluster1', 'cluster2'], - 'target': ['phenotype_A', 'phenotype_B'], - 'value': [5, 3] - }) - - mock_relational.return_value = { - 'file_name': 'relational_heatmap.csv', - 'data': mock_df, - 'figure': mock_fig - } + def test_complete_workflow_with_file_outputs(self): + """Test complete workflow with actual file outputs.""" + print("\n=== Testing Complete Workflow ===") - # Mock the plotly write_image to create a dummy image - def create_dummy_image(fig, path, **kwargs): - # Create a minimal PNG file - # Ensure file path exists (for NamedTemporaryFile) - if not os.path.exists(path): - # Create parent directory if needed - os.makedirs(os.path.dirname(path), exist_ok=True) - fig_dummy, ax = plt.subplots(figsize=(1, 1)) - ax.text(0.5, 0.5, 'test', ha='center', va='center') - plt.savefig(path, dpi=72) - plt.close(fig_dummy) - - mock_write_image.side_effect = create_dummy_image - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - # Test 1: Run with default parameters - result = run_from_json(self.params) - self.assertIsInstance(result, dict) - # Should have both CSV and PNG files - self.assertEqual(len(result), 2) - csv_files = [f for f in result.keys() if f.endswith('.csv')] - png_files = [f for f in result.keys() if f.endswith('.png')] - self.assertEqual(len(csv_files), 1) - self.assertEqual(len(png_files), 1) - - # Test 2: Run without saving - result_no_save = run_from_json( - self.params, save_results=False + # Run the template + result = run_from_json(self.params, save_to_disk=True) + + # Check that result is a dictionary + self.assertIsInstance(result, dict) + print(f"✓ Result is a dictionary with keys: {list(result.keys())}") + + # Check that all expected outputs are present + self.assertIn("figures", result) + self.assertIn("html", result) + self.assertIn("dataframe", result) + print("✓ All expected output types present") + + # Check figures directory + figures_list = result["figures"] + self.assertIsInstance(figures_list, list) + self.assertGreater(len(figures_list), 0) + + # Verify PNG file exists and has content + png_file = Path(figures_list[0]) + self.assertTrue(png_file.exists()) + self.assertTrue(png_file.suffix == '.png') + png_size = png_file.stat().st_size + self.assertGreater(png_size, 1000) # Should be at least 1KB + print(f"✓ PNG file created: {png_file.name} ({png_size} bytes)") + + # Check HTML directory + html_list = result["html"] + self.assertIsInstance(html_list, list) + self.assertGreater(len(html_list), 0) + + # Verify HTML file exists and has content + html_file = Path(html_list[0]) + self.assertTrue(html_file.exists()) + self.assertTrue(html_file.suffix == '.html') + + # Check HTML content + with open(html_file, 'r') as f: + html_content = f.read() + self.assertIn('plotly', html_content.lower()) + self.assertGreater(len(html_content), 1000) + print(f"✓ HTML file created: {html_file.name} " + f"({len(html_content)} chars)") + + # Check dataframe file + df_file = Path(result["dataframe"]) + self.assertTrue(df_file.exists()) + self.assertTrue(df_file.suffix == '.csv') + + # Verify CSV content + df = pd.read_csv(df_file) + self.assertGreater(len(df), 0) + self.assertGreater(len(df.columns), 0) + print(f"✓ CSV file created: {df_file.name} " + f"({df.shape[0]} rows, {df.shape[1]} cols)") + + print("✓ Complete workflow test passed!") + + def test_no_save_returns_figure_and_data(self): + """Test that save_to_disk=False returns figure and dataframe.""" + print("\n=== Testing No-Save Mode ===") + + # Run without saving + result = run_from_json(self.params, save_to_disk=False) + + # Check result is a tuple + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + print("✓ Result is a tuple of length 2") + + # Unpack result + fig, df = result + + # Check figure is a Plotly figure + self.assertTrue(hasattr(fig, 'data')) + self.assertTrue(hasattr(fig, 'layout')) + print(f"✓ Returned Plotly figure: {type(fig).__name__}") + + # Check dataframe + self.assertIsInstance(df, pd.DataFrame) + self.assertGreater(len(df), 0) + print(f"✓ Returned DataFrame: {df.shape[0]} rows, " + f"{df.shape[1]} cols") + + print("✓ No-save mode test passed!") + + def test_json_file_input(self): + """Test reading parameters from JSON file.""" + print("\n=== Testing JSON File Input ===") + + # Write params to JSON file + json_file = self.tmp_path / "params.json" + with open(json_file, 'w') as f: + json.dump(self.params, f) + print(f"✓ Created JSON file: {json_file}") + + # Run with JSON file path + result = run_from_json(str(json_file), save_to_disk=True) + + # Verify outputs + self.assertIsInstance(result, dict) + self.assertIn("figures", result) + print("✓ Successfully loaded params from JSON file") + + print("✓ JSON file input test passed!") + + def test_png_image_quality(self): + """Test that PNG image is properly generated with correct size.""" + print("\n=== Testing PNG Image Quality ===") + + # Run with specific dimensions + self.params["Figure_DPI"] = 300 + self.params["Figure_Width_inch"] = 10 + self.params["Figure_Height_inch"] = 12 + + result = run_from_json(self.params, save_to_disk=True) + + # Get PNG file + png_file = Path(result["figures"][0]) + png_size = png_file.stat().st_size + + # Higher DPI should produce larger file + self.assertGreater(png_size, 10000) # At least 10KB + print(f"✓ High-quality PNG: {png_size} bytes") + + # Try to verify it's a valid PNG by checking magic bytes + with open(png_file, 'rb') as f: + header = f.read(8) + self.assertEqual(header[:4], b'\x89PNG') + print("✓ Valid PNG file header") + + print("✓ PNG quality test passed!") + + def test_different_colormaps(self): + """Test with different colormap options.""" + print("\n=== Testing Different Colormaps ===") + + for colormap in ['viridis', 'plasma', 'darkmint']: + print(f" Testing colormap: {colormap}") + self.params["Colormap"] = colormap + self.params["Output_Directory"] = str( + self.tmp_path / f"output_{colormap}" ) - # Check appropriate return type - should be (figure, dataframe) - self.assertIsInstance(result_no_save, tuple) - self.assertEqual(len(result_no_save), 2) - fig, df = result_no_save - self.assertIsInstance(df, pd.DataFrame) - - # Test 3: JSON file input - json_path = os.path.join(self.tmp_dir.name, "params.json") - with open(json_path, "w") as f: - json.dump(self.params, f) - - result_json = run_from_json(json_path) - self.assertIsInstance(result_json, dict) - - @patch('matplotlib.pyplot.show') # Mock plt.show() - def test_error_validation(self, mock_plt_show) -> None: - """Test exact error message for invalid parameters.""" - # Test with None annotations (should be handled by text_to_value) - params_none = self.params.copy() - params_none["Source_Annotation_Name"] = "None" - params_none["Target_Annotation_Name"] = "None" - - with patch('spac.templates.relational_heatmap_template.' - 'relational_heatmap') as mock_rel: - # The template should pass None values to the function - mock_rel.return_value = { - 'file_name': 'test.csv', - 'data': pd.DataFrame(), - 'figure': Mock(show=Mock()) # Mock fig.show() - } - # Mock write_image to create a dummy file - def create_dummy_image(fig, path, **kwargs): - # Create a minimal PNG file - # Ensure file path exists (for NamedTemporaryFile) - if not os.path.exists(path): - # Create parent directory if needed - os.makedirs(os.path.dirname(path), exist_ok=True) - fig_dummy, ax = plt.subplots(figsize=(1, 1)) - ax.text(0.5, 0.5, 'test', ha='center', va='center') - plt.savefig(path, dpi=72) - plt.close(fig_dummy) + result = run_from_json(self.params, save_to_disk=True) - with patch('plotly.io.write_image', - side_effect=create_dummy_image): - run_from_json(params_none) - - # Verify None was passed - call_args = mock_rel.call_args - self.assertIsNone(call_args[1]['source_annotation']) - self.assertIsNone(call_args[1]['target_annotation']) - - @patch('spac.templates.relational_heatmap_template.relational_heatmap') - @patch('plotly.io.write_image') - @patch('matplotlib.pyplot.show') # Mock plt.show() - def test_function_calls( - self, mock_plt_show, mock_write_image, mock_relational - ) -> None: - """Test that main function is called with correct parameters.""" - # Mock the main function - mock_relational.return_value = { - 'file_name': 'test.csv', - 'data': pd.DataFrame({'a': [1, 2]}), - 'figure': Mock(show=Mock()) # Mock fig.show() - } + # Verify outputs created + self.assertIsInstance(result, dict) + png_file = Path(result["figures"][0]) + self.assertTrue(png_file.exists()) + print(f" ✓ {colormap}: {png_file.stat().st_size} bytes") - # Mock write_image to create a dummy file - def create_dummy_image(fig, path, **kwargs): - # Create a minimal PNG file - # Ensure file path exists (for NamedTemporaryFile) - if not os.path.exists(path): - # Create parent directory if needed - os.makedirs(os.path.dirname(path), exist_ok=True) - fig_dummy, ax = plt.subplots(figsize=(1, 1)) - ax.text(0.5, 0.5, 'test', ha='center', va='center') - plt.savefig(path, dpi=72) - plt.close(fig_dummy) - - mock_write_image.side_effect = create_dummy_image - - run_from_json(self.params, save_results=False) - - # Verify function was called correctly - mock_relational.assert_called_once() - call_args = mock_relational.call_args - - # Check specific parameters - self.assertEqual( - call_args[1]['source_annotation'], 'phenograph_k60_r1' + print("✓ Colormap test passed!") + + +def run_specific_test(test_name: str = None): + """Run a specific test or all tests.""" + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + if test_name: + # Run specific test + suite.addTest( + TestRelationalHeatmapTemplateRefactored(test_name) ) - self.assertEqual( - call_args[1]['target_annotation'], 'renamed_phenotypes' + else: + # Run all tests + suite.addTests( + loader.loadTestsFromTestCase( + TestRelationalHeatmapTemplateRefactored + ) ) - self.assertEqual(call_args[1]['color_map'], 'darkmint') - self.assertEqual(call_args[1]['font_size'], 8) + + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + return result if __name__ == "__main__": - unittest.main() \ No newline at end of file + print("=" * 70) + print("Relational Heatmap Template - Refactored Tests") + print("Testing new implementation with pio.to_image()") + print("=" * 70) + + # Run all tests + result = run_specific_test() + + print("\n" + "=" * 70) + if result.wasSuccessful(): + print("✓ ALL TESTS PASSED!") + else: + print("✗ SOME TESTS FAILED") + print(f" Failures: {len(result.failures)}") + print(f" Errors: {len(result.errors)}") + print("=" * 70) + + sys.exit(0 if result.wasSuccessful() else 1)