diff --git a/src/spac/templates/relational_heatmap_template.py b/src/spac/templates/relational_heatmap_template.py index fbf57218..bad826a0 100644 --- a/src/spac/templates/relational_heatmap_template.py +++ b/src/spac/templates/relational_heatmap_template.py @@ -1,19 +1,12 @@ """ -Relational Heatmap with Plotly-matplotlib color synchronization. -Extracts actual colors from Plotly and uses them in matplotlib. +Relational Heatmap template with Plotly figure export. +Generates both static PNG snapshots and interactive HTML outputs. """ import json import sys from pathlib import Path from typing import Any, Dict, Union, Tuple -import pandas as pd -import numpy as np -import matplotlib -matplotlib.use('Agg') -import matplotlib.pyplot as plt -import matplotlib.colors as mcolors import plotly.io as pio -import plotly.express as px sys.path.append(str(Path(__file__).parent.parent.parent)) @@ -26,112 +19,35 @@ ) -def get_plotly_colorscale_as_matplotlib(plotly_colormap: str) -> mcolors.LinearSegmentedColormap: - """ - Extract actual colors from Plotly colorscale and create matplotlib colormap. - This ensures exact color matching between Plotly and matplotlib. - """ - # Get Plotly's colorscale - try: - # Use plotly express to get the actual color sequence - colorscale = getattr(px.colors.sequential, plotly_colormap, None) - if colorscale is None: - colorscale = getattr(px.colors.diverging, plotly_colormap, None) - if colorscale is None: - colorscale = getattr(px.colors.cyclical, plotly_colormap, None) - - if colorscale is None: - # Fallback to a default - print(f"Warning: Could not find Plotly colorscale '{plotly_colormap}', using default") - colorscale = px.colors.sequential.Viridis - - # Convert to matplotlib colormap - if isinstance(colorscale, list): - # Create custom colormap from color list - cmap = mcolors.LinearSegmentedColormap.from_list( - f"plotly_{plotly_colormap}", - colorscale - ) - return cmap - except Exception as e: - print(f"Error extracting Plotly colors: {e}") - - # Fallback to matplotlib's viridis - return plt.cm.viridis - - -def create_matplotlib_heatmap_matching_plotly( - data: pd.DataFrame, - plotly_fig: Any, - source_annotation: str, - target_annotation: str, - colormap_name: str, - figsize: tuple, - dpi: int, - font_size: int -) -> plt.Figure: - """ - Create matplotlib heatmap that matches Plotly's appearance. - Extracts color information from the Plotly figure. - """ - fig, ax = plt.subplots(figsize=figsize, dpi=dpi) - - # Get the actual colormap from Plotly - cmap = get_plotly_colorscale_as_matplotlib(colormap_name) - - # Extract data range from Plotly figure if possible - try: - zmin = plotly_fig.data[0].zmin if hasattr(plotly_fig.data[0], 'zmin') else data.min().min() - zmax = plotly_fig.data[0].zmax if hasattr(plotly_fig.data[0], 'zmax') else data.max().max() - except: - zmin, zmax = data.min().min(), data.max().max() - - # Create heatmap matching Plotly's style - im = ax.imshow( - data.values, - aspect='auto', - cmap=cmap, - interpolation='nearest', - vmin=zmin, - vmax=zmax - ) - - # Match Plotly's tick placement - ax.set_xticks(np.arange(len(data.columns))) - ax.set_yticks(np.arange(len(data.index))) - ax.set_xticklabels(data.columns, rotation=45, ha='right', fontsize=font_size) - ax.set_yticklabels(data.index, fontsize=font_size) - - # Add colorbar - cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) - cbar.set_label('Count', fontsize=font_size) - cbar.ax.tick_params(labelsize=font_size) - - # Title matching Plotly - ax.set_title( - f'Relational Heatmap: {source_annotation} vs {target_annotation}', - fontsize=font_size + 2, - pad=20 - ) - ax.set_xlabel(target_annotation, fontsize=font_size) - ax.set_ylabel(source_annotation, fontsize=font_size) - - # Add grid for clarity (like Plotly) - ax.set_xticks(np.arange(len(data.columns) + 1) - 0.5, minor=True) - ax.set_yticks(np.arange(len(data.index) + 1) - 0.5, minor=True) - ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.3, alpha=0.3) - ax.tick_params(which='both', length=0) - - plt.tight_layout() - return fig - - def run_from_json( json_path: Union[str, Path, Dict[str, Any]], save_to_disk: bool = True, output_dir: str = None ) -> Union[Dict, Tuple]: - """Execute Relational Heatmap with color-matched outputs.""" + """ + Execute Relational Heatmap analysis. + + Generates a relational heatmap showing relationships between + two annotations, with outputs in three formats: + - PNG snapshot of the Plotly figure + - Interactive HTML version + - CSV data matrix + + Parameters + ---------- + json_path : str, Path, or dict + Path to JSON file, JSON string, or parameter dictionary + save_to_disk : bool, optional + Whether to save results to disk. Default is True. + output_dir : str, optional + Override output directory from params + + Returns + ------- + dict or tuple + If save_to_disk=True: Dictionary of saved file paths + If save_to_disk=False: Tuple of (plotly_fig, dataframe) + """ params = parse_params(json_path) @@ -150,8 +66,12 @@ def run_from_json( print(f"Data loaded: {adata.shape[0]} cells, {adata.shape[1]} genes") # Parameters - source_annotation = text_to_value(params.get("Source_Annotation_Name", "None")) - target_annotation = text_to_value(params.get("Target_Annotation_Name", "None")) + source_annotation = text_to_value( + params.get("Source_Annotation_Name", "None") + ) + target_annotation = text_to_value( + params.get("Target_Annotation_Name", "None") + ) dpi = float(params.get("Figure_DPI", 300)) width_in = float(params.get("Figure_Width_inch", 8)) @@ -173,7 +93,11 @@ def run_from_json( rhmap_data = result_dict['data'] plotly_fig = result_dict['figure'] - # Update Plotly figure + # Calculate scale factor for high-DPI export + # Plotly's default is 96 DPI, so scale relative to that + scale_factor = dpi / 96.0 + + # Update Plotly figure dimensions and styling for HTML display if plotly_fig: plotly_fig.update_layout( width=width_in * 96, @@ -181,28 +105,35 @@ def run_from_json( font=dict(size=font_size) ) - # Create matplotlib figure that matches Plotly's colors - print("Creating color-matched matplotlib figure...") - static_fig = create_matplotlib_heatmap_matching_plotly( - rhmap_data, - plotly_fig, - source_annotation, - target_annotation, - colormap, - (width_in, height_in), - int(dpi), - int(font_size) - ) - if save_to_disk: + # Generate PNG snapshot directly from Plotly figure + print("Generating PNG snapshot from Plotly figure...") + img_bytes = pio.to_image( + plotly_fig, + format='png', + width=int(width_in * 96), # Use base dimensions + height=int(height_in * 96), + scale=scale_factor # Scale up for higher DPI + ) + + # Prepare outputs results_dict = { - "figures": {"relational_heatmap": static_fig}, - "html": {"relational_heatmap": pio.to_html(plotly_fig, full_html=True, include_plotlyjs='cdn')}, + "figures": {"relational_heatmap": img_bytes}, + "html": { + "relational_heatmap": pio.to_html( + plotly_fig, + full_html=True, + include_plotlyjs='cdn' + ) + }, "dataframe": rhmap_data } - saved_files = save_results(results_dict, params, output_base_dir=output_dir) - plt.close(static_fig) + saved_files = save_results( + results_dict, + params, + output_base_dir=output_dir + ) print("✓ Relational Heatmap completed") return saved_files @@ -212,7 +143,10 @@ def run_from_json( if __name__ == "__main__": if len(sys.argv) < 2: - print("Usage: python relational_heatmap_template.py ", file=sys.stderr) + print( + "Usage: python relational_heatmap_template.py ", + file=sys.stderr + ) sys.exit(1) try: diff --git a/src/spac/templates/template_utils.py b/src/spac/templates/template_utils.py index 80c81025..afa066aa 100644 --- a/src/spac/templates/template_utils.py +++ b/src/spac/templates/template_utils.py @@ -272,7 +272,17 @@ def _save_single_object(obj: Any, name: str, output_dir: Path) -> Path: Path to saved file """ # Determine file format based on object type - if isinstance(obj, pd.DataFrame): + if isinstance(obj, bytes): + # Raw bytes (e.g., PNG image data from Plotly) + # Determine extension from name or default to .png + image_exts = ['.png', '.jpg', '.jpeg', '.pdf', '.svg'] + if not any(name.endswith(ext) for ext in image_exts): + name = f"{name}.png" + filepath = output_dir / name + with open(filepath, 'wb') as f: + f.write(obj) + + elif isinstance(obj, pd.DataFrame): # DataFrames -> CSV if not name.endswith('.csv'): name = f"{name}.csv" @@ -287,7 +297,9 @@ def _save_single_object(obj: Any, name: str, output_dir: Path) -> Path: obj.savefig(filepath, dpi=300, bbox_inches='tight') plt.close(obj) # Close figure to free memory - elif isinstance(obj, str) and (' ad.AnnData: - """Return a minimal synthetic AnnData for fast tests.""" - rng = np.random.default_rng(0) +def create_test_adata(n_cells: int = 100) -> ad.AnnData: + """Create a realistic test AnnData object.""" + rng = np.random.default_rng(42) + + # Create observations with two categorical annotations obs = pd.DataFrame({ - "phenograph_k60_r1": (["cluster1", "cluster2", "cluster3"] * - ((n_cells + 2) // 3))[:n_cells], - "renamed_phenotypes": (["phenotype_A", "phenotype_B"] * - ((n_cells + 1) // 2))[:n_cells] + "phenograph_k60_r1": rng.choice( + ["cluster1", "cluster2", "cluster3"], + n_cells + ), + "renamed_phenotypes": rng.choice( + ["phenotype_A", "phenotype_B", "phenotype_C"], + n_cells + ) }) - x_mat = rng.normal(size=(n_cells, 3)) + + # Create expression matrix + x_mat = rng.normal(size=(n_cells, 20)) + + # Create AnnData object adata = ad.AnnData(X=x_mat, obs=obs) - adata.var_names = ["Gene1", "Gene2", "Gene3"] + adata.var_names = [f"Gene{i}" for i in range(20)] + return adata -class TestRelationalHeatmapTemplate(unittest.TestCase): - """Unit tests for the Relational Heatmap template.""" +class TestRelationalHeatmapTemplateRefactored(unittest.TestCase): + """Test suite for the refactored relational heatmap template.""" def setUp(self) -> None: + """Set up test fixtures.""" self.tmp_dir = tempfile.TemporaryDirectory() - self.in_file = os.path.join( - self.tmp_dir.name, "input.pickle" - ) - self.out_file = "relational_heatmap" - - # Save minimal mock data - with open(self.in_file, 'wb') as f: - pickle.dump(mock_adata(), f) - - # Minimal parameters from NIDAP template + self.tmp_path = Path(self.tmp_dir.name) + + # Create test data file + self.input_file = self.tmp_path / "input.pickle" + test_adata = create_test_adata(n_cells=100) + with open(self.input_file, 'wb') as f: + pickle.dump(test_adata, f) + + # Define test parameters self.params = { - "Upstream_Analysis": self.in_file, + "Upstream_Analysis": str(self.input_file), "Source_Annotation_Name": "phenograph_k60_r1", "Target_Annotation_Name": "renamed_phenotypes", "Colormap": "darkmint", "Figure_Width_inch": 8, "Figure_Height_inch": 10, - "Figure_DPI": 300, + "Figure_DPI": 150, # Lower DPI for faster tests "Font_Size": 8, - "Output_File": self.out_file, + "Output_File": "relational_heatmap", + "Output_Directory": str(self.tmp_path), + "outputs": { + "figures": {"type": "directory", "name": "figures_dir"}, + "html": {"type": "directory", "name": "html_dir"}, + "dataframe": {"type": "file", "name": "dataframe.csv"} + } } def tearDown(self) -> None: + """Clean up test fixtures.""" self.tmp_dir.cleanup() - @patch('spac.templates.relational_heatmap_template.relational_heatmap') - @patch('plotly.io.write_image') - @patch('matplotlib.pyplot.show') # Mock plt.show() - def test_complete_io_workflow( - self, mock_plt_show, mock_write_image, mock_relational - ) -> None: - """Single I/O test covering input/output scenarios.""" - # Mock the relational_heatmap function - mock_fig = Mock() - mock_fig.show = Mock() # Mock the fig.show() method - - mock_df = pd.DataFrame({ - 'source': ['cluster1', 'cluster2'], - 'target': ['phenotype_A', 'phenotype_B'], - 'value': [5, 3] - }) - - mock_relational.return_value = { - 'file_name': 'relational_heatmap.csv', - 'data': mock_df, - 'figure': mock_fig - } + def test_complete_workflow_with_file_outputs(self): + """Test complete workflow with actual file outputs.""" + print("\n=== Testing Complete Workflow ===") - # Mock the plotly write_image to create a dummy image - def create_dummy_image(fig, path, **kwargs): - # Create a minimal PNG file - # Ensure file path exists (for NamedTemporaryFile) - if not os.path.exists(path): - # Create parent directory if needed - os.makedirs(os.path.dirname(path), exist_ok=True) - fig_dummy, ax = plt.subplots(figsize=(1, 1)) - ax.text(0.5, 0.5, 'test', ha='center', va='center') - plt.savefig(path, dpi=72) - plt.close(fig_dummy) - - mock_write_image.side_effect = create_dummy_image - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - # Test 1: Run with default parameters - result = run_from_json(self.params) - self.assertIsInstance(result, dict) - # Should have both CSV and PNG files - self.assertEqual(len(result), 2) - csv_files = [f for f in result.keys() if f.endswith('.csv')] - png_files = [f for f in result.keys() if f.endswith('.png')] - self.assertEqual(len(csv_files), 1) - self.assertEqual(len(png_files), 1) - - # Test 2: Run without saving - result_no_save = run_from_json( - self.params, save_results=False + # Run the template + result = run_from_json(self.params, save_to_disk=True) + + # Check that result is a dictionary + self.assertIsInstance(result, dict) + print(f"✓ Result is a dictionary with keys: {list(result.keys())}") + + # Check that all expected outputs are present + self.assertIn("figures", result) + self.assertIn("html", result) + self.assertIn("dataframe", result) + print("✓ All expected output types present") + + # Check figures directory + figures_list = result["figures"] + self.assertIsInstance(figures_list, list) + self.assertGreater(len(figures_list), 0) + + # Verify PNG file exists and has content + png_file = Path(figures_list[0]) + self.assertTrue(png_file.exists()) + self.assertTrue(png_file.suffix == '.png') + png_size = png_file.stat().st_size + self.assertGreater(png_size, 1000) # Should be at least 1KB + print(f"✓ PNG file created: {png_file.name} ({png_size} bytes)") + + # Check HTML directory + html_list = result["html"] + self.assertIsInstance(html_list, list) + self.assertGreater(len(html_list), 0) + + # Verify HTML file exists and has content + html_file = Path(html_list[0]) + self.assertTrue(html_file.exists()) + self.assertTrue(html_file.suffix == '.html') + + # Check HTML content + with open(html_file, 'r') as f: + html_content = f.read() + self.assertIn('plotly', html_content.lower()) + self.assertGreater(len(html_content), 1000) + print(f"✓ HTML file created: {html_file.name} " + f"({len(html_content)} chars)") + + # Check dataframe file + df_file = Path(result["dataframe"]) + self.assertTrue(df_file.exists()) + self.assertTrue(df_file.suffix == '.csv') + + # Verify CSV content + df = pd.read_csv(df_file) + self.assertGreater(len(df), 0) + self.assertGreater(len(df.columns), 0) + print(f"✓ CSV file created: {df_file.name} " + f"({df.shape[0]} rows, {df.shape[1]} cols)") + + print("✓ Complete workflow test passed!") + + def test_no_save_returns_figure_and_data(self): + """Test that save_to_disk=False returns figure and dataframe.""" + print("\n=== Testing No-Save Mode ===") + + # Run without saving + result = run_from_json(self.params, save_to_disk=False) + + # Check result is a tuple + self.assertIsInstance(result, tuple) + self.assertEqual(len(result), 2) + print("✓ Result is a tuple of length 2") + + # Unpack result + fig, df = result + + # Check figure is a Plotly figure + self.assertTrue(hasattr(fig, 'data')) + self.assertTrue(hasattr(fig, 'layout')) + print(f"✓ Returned Plotly figure: {type(fig).__name__}") + + # Check dataframe + self.assertIsInstance(df, pd.DataFrame) + self.assertGreater(len(df), 0) + print(f"✓ Returned DataFrame: {df.shape[0]} rows, " + f"{df.shape[1]} cols") + + print("✓ No-save mode test passed!") + + def test_json_file_input(self): + """Test reading parameters from JSON file.""" + print("\n=== Testing JSON File Input ===") + + # Write params to JSON file + json_file = self.tmp_path / "params.json" + with open(json_file, 'w') as f: + json.dump(self.params, f) + print(f"✓ Created JSON file: {json_file}") + + # Run with JSON file path + result = run_from_json(str(json_file), save_to_disk=True) + + # Verify outputs + self.assertIsInstance(result, dict) + self.assertIn("figures", result) + print("✓ Successfully loaded params from JSON file") + + print("✓ JSON file input test passed!") + + def test_png_image_quality(self): + """Test that PNG image is properly generated with correct size.""" + print("\n=== Testing PNG Image Quality ===") + + # Run with specific dimensions + self.params["Figure_DPI"] = 300 + self.params["Figure_Width_inch"] = 10 + self.params["Figure_Height_inch"] = 12 + + result = run_from_json(self.params, save_to_disk=True) + + # Get PNG file + png_file = Path(result["figures"][0]) + png_size = png_file.stat().st_size + + # Higher DPI should produce larger file + self.assertGreater(png_size, 10000) # At least 10KB + print(f"✓ High-quality PNG: {png_size} bytes") + + # Try to verify it's a valid PNG by checking magic bytes + with open(png_file, 'rb') as f: + header = f.read(8) + self.assertEqual(header[:4], b'\x89PNG') + print("✓ Valid PNG file header") + + print("✓ PNG quality test passed!") + + def test_different_colormaps(self): + """Test with different colormap options.""" + print("\n=== Testing Different Colormaps ===") + + for colormap in ['viridis', 'plasma', 'darkmint']: + print(f" Testing colormap: {colormap}") + self.params["Colormap"] = colormap + self.params["Output_Directory"] = str( + self.tmp_path / f"output_{colormap}" ) - # Check appropriate return type - should be (figure, dataframe) - self.assertIsInstance(result_no_save, tuple) - self.assertEqual(len(result_no_save), 2) - fig, df = result_no_save - self.assertIsInstance(df, pd.DataFrame) - - # Test 3: JSON file input - json_path = os.path.join(self.tmp_dir.name, "params.json") - with open(json_path, "w") as f: - json.dump(self.params, f) - - result_json = run_from_json(json_path) - self.assertIsInstance(result_json, dict) - - @patch('matplotlib.pyplot.show') # Mock plt.show() - def test_error_validation(self, mock_plt_show) -> None: - """Test exact error message for invalid parameters.""" - # Test with None annotations (should be handled by text_to_value) - params_none = self.params.copy() - params_none["Source_Annotation_Name"] = "None" - params_none["Target_Annotation_Name"] = "None" - - with patch('spac.templates.relational_heatmap_template.' - 'relational_heatmap') as mock_rel: - # The template should pass None values to the function - mock_rel.return_value = { - 'file_name': 'test.csv', - 'data': pd.DataFrame(), - 'figure': Mock(show=Mock()) # Mock fig.show() - } - # Mock write_image to create a dummy file - def create_dummy_image(fig, path, **kwargs): - # Create a minimal PNG file - # Ensure file path exists (for NamedTemporaryFile) - if not os.path.exists(path): - # Create parent directory if needed - os.makedirs(os.path.dirname(path), exist_ok=True) - fig_dummy, ax = plt.subplots(figsize=(1, 1)) - ax.text(0.5, 0.5, 'test', ha='center', va='center') - plt.savefig(path, dpi=72) - plt.close(fig_dummy) + result = run_from_json(self.params, save_to_disk=True) - with patch('plotly.io.write_image', - side_effect=create_dummy_image): - run_from_json(params_none) - - # Verify None was passed - call_args = mock_rel.call_args - self.assertIsNone(call_args[1]['source_annotation']) - self.assertIsNone(call_args[1]['target_annotation']) - - @patch('spac.templates.relational_heatmap_template.relational_heatmap') - @patch('plotly.io.write_image') - @patch('matplotlib.pyplot.show') # Mock plt.show() - def test_function_calls( - self, mock_plt_show, mock_write_image, mock_relational - ) -> None: - """Test that main function is called with correct parameters.""" - # Mock the main function - mock_relational.return_value = { - 'file_name': 'test.csv', - 'data': pd.DataFrame({'a': [1, 2]}), - 'figure': Mock(show=Mock()) # Mock fig.show() - } + # Verify outputs created + self.assertIsInstance(result, dict) + png_file = Path(result["figures"][0]) + self.assertTrue(png_file.exists()) + print(f" ✓ {colormap}: {png_file.stat().st_size} bytes") - # Mock write_image to create a dummy file - def create_dummy_image(fig, path, **kwargs): - # Create a minimal PNG file - # Ensure file path exists (for NamedTemporaryFile) - if not os.path.exists(path): - # Create parent directory if needed - os.makedirs(os.path.dirname(path), exist_ok=True) - fig_dummy, ax = plt.subplots(figsize=(1, 1)) - ax.text(0.5, 0.5, 'test', ha='center', va='center') - plt.savefig(path, dpi=72) - plt.close(fig_dummy) - - mock_write_image.side_effect = create_dummy_image - - run_from_json(self.params, save_results=False) - - # Verify function was called correctly - mock_relational.assert_called_once() - call_args = mock_relational.call_args - - # Check specific parameters - self.assertEqual( - call_args[1]['source_annotation'], 'phenograph_k60_r1' + print("✓ Colormap test passed!") + + +def run_specific_test(test_name: str = None): + """Run a specific test or all tests.""" + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + if test_name: + # Run specific test + suite.addTest( + TestRelationalHeatmapTemplateRefactored(test_name) ) - self.assertEqual( - call_args[1]['target_annotation'], 'renamed_phenotypes' + else: + # Run all tests + suite.addTests( + loader.loadTestsFromTestCase( + TestRelationalHeatmapTemplateRefactored + ) ) - self.assertEqual(call_args[1]['color_map'], 'darkmint') - self.assertEqual(call_args[1]['font_size'], 8) + + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + return result if __name__ == "__main__": - unittest.main() \ No newline at end of file + print("=" * 70) + print("Relational Heatmap Template - Refactored Tests") + print("Testing new implementation with pio.to_image()") + print("=" * 70) + + # Run all tests + result = run_specific_test() + + print("\n" + "=" * 70) + if result.wasSuccessful(): + print("✓ ALL TESTS PASSED!") + else: + print("✗ SOME TESTS FAILED") + print(f" Failures: {len(result.failures)}") + print(f" Errors: {len(result.errors)}") + print("=" * 70) + + sys.exit(0 if result.wasSuccessful() else 1)