diff --git a/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py new file mode 100644 index 0000000000..626c954ffa --- /dev/null +++ b/pycode/memilio-plot/memilio/plot/plotAbmInfectionStates.py @@ -0,0 +1,413 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Sascha Korf +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +import sys +import argparse +import os +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import matplotlib +import h5py +from datetime import datetime +from scipy.ndimage import gaussian_filter1d + + +# Module for plotting number of agents per infection state and number of infected agents per location type from ABM results. +# This module provides functions to load and visualize infection states and +# location types from simulation results of the agent-based model (ABM) stored in HDF5 format. + +# The used Loggers are: +# struct LogInfectionStatePerAgeGroup : mio::LogAlways { +# using Type = std::pair; +# /** +# * @brief Log the TimeSeries of the number of Person%s in an #InfectionState for every age group. +# * @param[in] sim The simulation of the abm. +# * @return A pair of the TimePoint and the TimeSeries of the number of Person%s in an #InfectionState for every age group. +# */ +# static Type log(const mio::abm::Simulation& sim) +# { +# +# Eigen::VectorXd sum = Eigen::VectorXd::Zero( +# Eigen::Index((size_t)mio::abm::InfectionState::Count * sim.get_world().parameters.get_num_groups())); +# const auto curr_time = sim.get_time(); +# const auto persons = sim.get_world().get_persons(); +# +# // PRAGMA_OMP(parallel for) +# for (auto i = size_t(0); i < persons.size(); ++i) { +# auto& p = persons[i]; +# auto index = (((size_t)(mio::abm::InfectionState::Count)) * ((uint32_t)p.get_age().get())) + +# ((uint32_t)p.get_infection_state(curr_time)); +# // PRAGMA_OMP(atomic) +# sum[index] += 1; +# } +# return std::make_pair(curr_time, sum); +# } +# }; +# +# struct LogInfectionPerLocationTypePerAgeGroup : mio::LogAlways { +# using Type = std::pair; +# /** +# * @brief Log the TimeSeries of the number of newly infected Person%s for each Location Type and each age. +# * @param[in] sim The simulation of the abm. +# * @return A pair of the TimePoint and the TimeSeries of newly infected Person%s for each Location Type and each age. +# */ +# static Type log(const mio::abm::Simulation& sim) +# { +# +# Eigen::VectorXd sum = Eigen::VectorXd::Zero( +# Eigen::Index((size_t)mio::abm::LocationType::Count * sim.get_world().parameters.get_num_groups())); +# auto curr_time = sim.get_time(); +# auto prev_time = sim.get_prev_time(); +# const auto persons = sim.get_world().get_persons(); +# +# // PRAGMA_OMP(parallel for) +# for (auto i = size_t(0); i < persons.size(); ++i) { +# auto& p = persons[i]; +# // PRAGMA_OMP(atomic) +# if ((p.get_infection_state(prev_time) != mio::abm::InfectionState::Exposed) && +# (p.get_infection_state(curr_time) == mio::abm::InfectionState::Exposed)) { +# auto index = (((size_t)(mio::abm::LocationType::Count)) * ((uint32_t)p.get_age().get())) + +# ((uint32_t)p.get_location().get_type()); +# sum[index] += 1; +# } +# } +# return std::make_pair(curr_time, sum); +# } +# }; +# +# The output of the loggers of several runs is stored in HDF5 files using mio::save_results in mio/io/result_io.h. + +# Adjust these as needed. +state_labels = { + 1: 'Exposed', + 2: 'I_Asymp', + 3: 'I_Symp', + 4: 'I_Severe', + 5: 'I_Critical', + 7: 'Dead' +} + +age_groups = ['Group1', 'Group2', 'Group3', 'Group4', + 'Group5', 'Group6', 'Total'] + +age_groups_dict = { + 'Group1': 'Ages 0-4', + 'Group2': 'Ages 5-14', + 'Group3': 'Ages 15-34', + 'Group4': 'Ages 35-59', + 'Group5': 'Ages 60-79', + 'Group6': 'Ages 80+', + 'Total': 'All Ages' +} + +location_type_labels = { + 0: 'Home', + 1: 'School', + 2: 'Work', + 3: 'SocialEvent', + 4: 'BasicsShop', + 5: 'Hospital', + 6: 'ICU' +} + + +def load_h5_results(base_path, percentile): + """ Reads HDF5 results for a given group and percentile. + + @param[in] base_path Path to results directory. + @param[in] percentile Subdirectory for percentile (e.g. 'p50'). + @return Dictionary with data arrays. Keys are dataset names from the HDF5 file + (e.g., 'Time', 'Total', age group names like 'Group1', 'Group2', etc.). + Values are numpy arrays containing the corresponding time series data. + """ + file_path = os.path.join(base_path, percentile, "Results.h5") + with h5py.File(file_path, 'r') as f: + data = {k: v[()] for k, v in f['0'].items()} + return data + + +def plot_infections_loc_types_average( + path_to_loc_types, + start_date='2021-03-01', + colormap='Set1', + smooth_sigma=1, + rolling_window=24, + xtick_step=150): + """ Plots rolling sum of new infections per 24 hours location type for the median run. + + @param[in] base_path Path to results directory. + @param[in] start_date Start date as string. + @param[in] colormap Matplotlib colormap. + @param[in] smooth_sigma Sigma for Gaussian smoothing. + @param[in] rolling_window Window size for rolling sum. + @param[in] xtick_step Step size for x-axis ticks. + """ + # Load data + p50 = load_h5_results(path_to_loc_types, "p50") + time = p50['Time'] + total_50 = p50['Total'] + + plt.figure('Infection_location_types') + plt.title( + 'Number of new infections per location type for the median run, rolling sum over 24 hours') + color_plot = matplotlib.colormaps.get_cmap(colormap).colors + + for idx, i in enumerate(location_type_labels.keys()): + color = color_plot[i % len(color_plot)] if i < len( + color_plot) else "black" + # Sum up every 24 hours, then smooth + indexer = pd.api.indexers.FixedForwardWindowIndexer( + window_size=rolling_window) + y = pd.DataFrame(total_50[:, i]).rolling( + window=indexer, min_periods=1).sum().to_numpy() + y = y[0::rolling_window].flatten() + y = gaussian_filter1d(y, sigma=smooth_sigma, mode='nearest') + plt.plot(time[0::rolling_window], y, color=color, linewidth=2.5) + + plt.legend(list(location_type_labels.values())) + _format_x_axis(time, start_date, xtick_step) + plt.xlabel('Date') + plt.ylabel('Number of individuals') + plt.show() + + +def plot_infection_states_results( + path_to_infection_states, + start_date='2021-03-01', + colormap='Set1', + xtick_step=150, + show90=False +): + """ Loads and plots infection state results. + + @param[in] path_to_infection_states Path to results directory containing infection state data. + @param[in] start_date Start date as string (YYYY-MM-DD format). + @param[in] colormap Matplotlib colormap name. + @param[in] xtick_step Step size for x-axis ticks. + @param[in] show90 If True, plot 90% percentile (5% and 95%) in addition to 50% percentile. + """ + + # Load data + p50 = load_h5_results(path_to_infection_states, "p50") + p25 = load_h5_results(path_to_infection_states, "p25") + p75 = load_h5_results(path_to_infection_states, "p75") + time = p50['Time'] + total_50 = p50['Total'] + total_25 = p25['Total'] + total_75 = p75['Total'] + p05 = p95 = None + total_05 = total_95 = None + if show90: + total_95 = load_h5_results(path_to_infection_states, "p95") + total_05 = load_h5_results(path_to_infection_states, "p05") + p95 = total_95['Total'] + p05 = total_05['Total'] + + plot_infection_states_by_age_group( + time, p50, p25, p75, colormap, + p05_bs=total_05 if show90 else None, + p95_bs=total_95 if show90 else None, + show90=show90 + ) + plot_infection_states(time, total_50, total_25, + total_75, start_date, colormap, xtick_step, + y05=p05, y95=p95, show_90=show90) + + +def plot_infection_states( + x, y50, y25, y75, + start_date='2021-03-01', + colormap='Set1', + xtick_step=150, + y05=None, y95=None, show_90=False): + """ Plots infection states with percentile bands. + + @param[in] x Time array for x-axis. + @param[in] y50 50th percentile data array. + @param[in] y25 25th percentile data array. + @param[in] y75 75th percentile data array. + @param[in] start_date Start date as string (YYYY-MM-DD format). + @param[in] colormap Matplotlib colormap name. + @param[in] xtick_step Step size for x-axis ticks. + @param[in] y05 5th percentile data array (optional). + @param[in] y95 95th percentile data array (optional). + @param[in] show_90 If True, plot 90% percentile bands in addition to 50% percentile. + """ + + plt.figure('Infection_states') + + plt.title('Infection states with 50% percentile') + if show_90: + plt.title('Infection states with 50% and 90% percentiles') + + color_plot = matplotlib.colormaps.get_cmap(colormap).colors + + states_plot = list(state_labels.keys()) + + for i in states_plot: + plt.plot(x, y50[:, i], color=color_plot[i], + linewidth=2.5, label=state_labels[i]) + # needs to be after the plot calls + plt.legend([state_labels[i] for i in states_plot]) + for i in states_plot: + plt.plot(x, y25[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.2, alpha=0.7) + plt.plot(x, y75[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.2, alpha=0.7) + plt.fill_between(x, y25[:, i], y75[:, i], + alpha=0.2, color=color_plot[i]) + # Optional: 90% percentile + if show_90 and y05 is not None and y95 is not None: + plt.plot(x, y05[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.0, alpha=0.4) + plt.plot(x, y95[:, i], color=color_plot[i], + linestyle='dashdot', linewidth=1.0, alpha=0.4) + plt.fill_between(x, y05[:, i], y95[:, i], + # More transparent + alpha=0.25, color=color_plot[i]) + + _format_x_axis(x, start_date, xtick_step) + plt.xlabel('Date') + plt.ylabel('Number of individuals') + plt.show() + + +def plot_infection_states_by_age_group( + x, p50_bs, p25_bs, p75_bs, colormap='Set1', + p05_bs=None, p95_bs=None, show90=False +): + """ Plots infection states for each age group, with optional 90% percentile. + + @param[in] x Time array for x-axis. + @param[in] p50_bs Dictionary containing 50th percentile data for all age groups. + @param[in] p25_bs Dictionary containing 25th percentile data for all age groups. + @param[in] p75_bs Dictionary containing 75th percentile data for all age groups. + @param[in] colormap Matplotlib colormap name. + @param[in] p05_bs Dictionary containing 5th percentile data for all age groups (optional). + @param[in] p95_bs Dictionary containing 95th percentile data for all age groups (optional). + @param[in] show90 If True, plot 90% percentile bands in addition to 50% percentile. + """ + + color_plot = matplotlib.colormaps.get_cmap(colormap).colors + n_states = len(state_labels) + fig, ax = plt.subplots( + n_states, len(age_groups), constrained_layout=True, figsize=(20, 3 * n_states)) + + for col_idx, group in enumerate(age_groups): + y50 = p50_bs[group] + y25 = p25_bs[group] + y75 = p75_bs[group] + y05 = p05_bs[group] if (show90 and p05_bs is not None) else None + y95 = p95_bs[group] if (show90 and p95_bs is not None) else None + for row_idx, (state_idx, label) in enumerate(state_labels.items()): + _plot_state( + ax[row_idx, col_idx], x, y50[:, state_idx], y25[:, + state_idx], y75[:, state_idx], + color_plot[col_idx], f'#{label}, {age_groups_dict[group]}', + y05=y05[:, state_idx] if y05 is not None else None, + y95=y95[:, state_idx] if y95 is not None else None, + show90=show90 + ) + # The legend should say: solid line = median, dashed line = 25% and 75% perc. and if show90 is True, dotted line = 5%, 25%, 75%, 95% perc. + perc_string = '25/75%' if not show90 else '5/25/75/95%' + ax[row_idx, col_idx].legend( + ['Median', f'{perc_string} perc.'], + loc='upper left', fontsize=8) + + # Add y label for leftmost column + if col_idx == 0: + ax[row_idx, col_idx].set_ylabel('Number of individuals') + + # Add x label for bottom row + if row_idx == n_states - 1: + ax[row_idx, col_idx].set_xlabel('Time (days)') + + string_short = ' and 90%' if show90 else '' + fig.suptitle( + 'Infection states per age group with 50%' + string_short + ' percentile', + fontsize=16) + + plt.show() + + +def _plot_state(ax, x, y50, y25, y75, color, title, y05=None, y95=None, show90=False): + """ Helper to plot a single state with fill_between and optional 90% percentile. """ + ax.plot(x, y50, color=color, label='Median') + ax.fill_between(x, y25, y75, alpha=0.5, color=color) + if show90 and y05 is not None and y95 is not None: + ax.plot(x, y05, color=color, linestyle='dotted', + linewidth=1.0, alpha=0.4) + ax.plot(x, y95, color=color, linestyle='dotted', + linewidth=1.0, alpha=0.4) + ax.fill_between(x, y05, y95, alpha=0.15, color=color) + ax.tick_params(axis='y') + ax.set_title(title) + + +def _format_x_axis(x, start_date, xtick_step): + """ Helper to format x-axis as dates. """ + start = datetime.strptime(start_date, '%Y-%m-%d') + xx = [start + pd.Timedelta(days=int(i)) for i in x] + xx_str = [dt.strftime('%Y-%m-%d') for dt in xx] + plt.gca().set_xticks(x[::xtick_step]) + plt.gca().set_xticklabels(xx_str[::xtick_step]) + plt.gcf().autofmt_xdate() + + +def main(): + """ Main function for CLI usage. """ + parser = argparse.ArgumentParser( + description="Plot infection state and location type results.") + parser.add_argument("--path-to-infection-states", + help="Path to infection states results") + parser.add_argument("--path-to-loc-types", + help="Path to location types results") + parser.add_argument("--start-date", type=str, default='2021-03-01', + help="Simulation start date (YYYY-MM-DD)") + parser.add_argument("--colormap", type=str, + default='Set1', help="Matplotlib colormap") + parser.add_argument("--xtick-step", type=int, + default=150, help="Step for x-axis ticks (usually hours)") + parser.add_argument("--90percentile", action="store_true", + help="If set, plot 90% percentile as well") + args = parser.parse_args() + + plot_infection_states_results( + path_to_infection_states=args.path_to_infection_states, + start_date=args.start_date, + colormap=args.colormap, + xtick_step=args.xtick_step, + show90=True + ) + plot_infections_loc_types_average( + path_to_loc_types=args.path_to_loc_types, + start_date=args.start_date, + colormap=args.colormap, + xtick_step=args.xtick_step) + + if not args.path_to_infection_states and not args.path_to_loc_types: + print("Please provide a path to infection states or location types results.") + + plt.show() + + +if __name__ == "__main__": + main() diff --git a/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py b/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py new file mode 100644 index 0000000000..4bc5dde5e1 --- /dev/null +++ b/pycode/memilio-plot/memilio/plot_test/test_plot_plotAbmInfectionStates.py @@ -0,0 +1,304 @@ +############################################################################# +# Copyright (C) 2020-2025 MEmilio +# +# Authors: Sascha Korf +# +# Contact: Martin J. Kuehn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +############################################################################# + +import unittest +from unittest.mock import patch, MagicMock +import numpy as np +import pandas as pd + +import memilio.plot.plotAbmInfectionStates as abm + + +class TestPlotAbmInfectionStates(unittest.TestCase): + + @patch('memilio.plot.plotAbmInfectionStates.h5py.File') + def test_load_h5_results(self, mock_h5file): + mock_group = {'Time': np.arange(10), 'Total': np.ones((10, 8))} + mock_h5file().__enter__().get.return_value = {'0': mock_group} + mock_h5file().__enter__().items.return_value = [('0', mock_group)] + mock_h5file().__enter__().__getitem__.return_value = mock_group + with patch('memilio.plot.plotAbmInfectionStates.h5py.File', mock_h5file): + result = abm.load_h5_results('dummy_path', 'p50') + assert 'Time' in result + assert 'Total' in result + np.testing.assert_array_equal(result['Time'], np.arange(10)) + np.testing.assert_array_equal(result['Total'], np.ones((10, 8))) + + @patch('memilio.plot.plotAbmInfectionStates.load_h5_results') + @patch('memilio.plot.plotAbmInfectionStates.matplotlib') + @patch('memilio.plot.plotAbmInfectionStates.gaussian_filter1d', side_effect=lambda x, sigma, mode: x) + @patch('memilio.plot.plotAbmInfectionStates.pd.DataFrame') + def test_plot_infections_loc_types_average(self, mock_df, mock_gauss, mock_matplotlib, mock_load): + mock_load.return_value = { + 'Time': np.arange(48), 'Total': np.ones((48, 7))} + mock_df.return_value.rolling.return_value.sum.return_value.to_numpy.return_value = np.ones( + (48, 1)) + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*7 + + # Patch plt methods + with patch.object(abm.plt, 'gca') as mock_gca, \ + patch.object(abm.plt, 'figure') as mock_figure, \ + patch.object(abm.plt, 'title') as mock_title, \ + patch.object(abm.plt, 'legend') as mock_legend, \ + patch.object(abm.plt, 'xlabel') as mock_xlabel, \ + patch.object(abm.plt, 'ylabel') as mock_ylabel, \ + patch.object(abm.plt, 'show') as mock_show: + + mock_ax = MagicMock() + mock_gca.return_value = mock_ax + + abm.plot_infections_loc_types_average('dummy_path') + + # Test basic plotting functionality + assert mock_ax.plot.called + assert mock_ax.set_xticks.called + assert mock_ax.set_xticklabels.called + + # Test figure settings + self.assertEqual(mock_figure.call_count, 2, + "figure should be called twice") + # Verify first call is with the figure name + mock_figure.assert_any_call('Infection_location_types') + # Verify second call is without arguments (for autofmt_xdate) + mock_figure.assert_any_call() + mock_title.assert_called_once_with( + 'Number of new infections per location type for the median run, rolling sum over 24 hours') + mock_legend.assert_called_once() + mock_xlabel.assert_called_once_with('Date') + mock_ylabel.assert_called_once_with('Number of individuals') + mock_show.assert_called_once() + + # Verify legend was called with location type labels + legend_call_args = mock_legend.call_args + if legend_call_args and legend_call_args[0]: + legend_labels = legend_call_args[0][0] + # Should contain the location type labels from the function + expected_labels = list(abm.location_type_labels.values()) + self.assertEqual(legend_labels, expected_labels) + + # Verify that plot was called for each location type + plot_calls = mock_ax.plot.call_args_list + # Should plot 7 location types + self.assertEqual(len(plot_calls), 7, + "Should plot all 7 location types") + + @patch('memilio.plot.plotAbmInfectionStates.load_h5_results') + @patch('memilio.plot.plotAbmInfectionStates.plot_infection_states') + @patch('memilio.plot.plotAbmInfectionStates.plot_infection_states_by_age_group') + def test_plot_infection_states_results(self, mock_indiv, mock_states, mock_load): + test_data = { + 'Time': np.arange(10), + 'Total': np.ones((10, 8)), + 'Group1': np.ones((10, 8)), + 'Group2': np.ones((10, 8)), + 'Group3': np.ones((10, 8)), + 'Group4': np.ones((10, 8)), + 'Group5': np.ones((10, 8)), + 'Group6': np.ones((10, 8)) + } + mock_load.side_effect = [test_data, test_data, test_data] + + abm.plot_infection_states_results('dummy_path') + + # Verify functions are called with correct arguments + self.assertEqual(mock_load.call_count, 3, + "load_h5_results should be called 3 times (p25, p50, p75)") + + # Check that load_h5_results was called with correct percentiles + expected_calls = [ + unittest.mock.call('dummy_path', 'p25'), + unittest.mock.call('dummy_path', 'p50'), + unittest.mock.call('dummy_path', 'p75') + ] + mock_load.assert_has_calls(expected_calls, any_order=True) + + # Verify plotting functions were called with the loaded data + mock_states.assert_called_once() + mock_indiv.assert_called_once() + + # Check that plot_infection_states was called with correct data structure + states_call_args = mock_states.call_args + self.assertIsNotNone(states_call_args) + # x, y50, y25, y75, y05, y95, show_90 + self.assertEqual(len(states_call_args[0]), 7) + + # Check that plot_infection_states_by_age_group was called with correct data structure + indiv_call_args = mock_indiv.call_args + self.assertIsNotNone(indiv_call_args) + # x, p50_bs, p25_bs, p75_bs, p05_bs + self.assertEqual(len(indiv_call_args[0]), 5) + + @patch('memilio.plot.plotAbmInfectionStates.matplotlib') + def test_plot_infection_states(self, mock_matplotlib): + x = np.arange(10) + y50 = np.ones((10, 8)) + y25 = np.zeros((10, 8)) + y75 = np.ones((10, 8))*2 + y05 = np.ones((10, 8))*-1 + y95 = np.ones((10, 8))*3 + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*8 + + # Patch plt.gca().plot and fill_between + with patch.object(abm.plt, 'gca') as mock_gca: + mock_ax = MagicMock() + mock_gca.return_value = mock_ax + + abm.plot_infection_states( + x, y50, y25, y75, + start_date='2021-03-01', + colormap='Set1', + xtick_step=2, + y05=y05, + y95=y95, + show_90=True + ) + + # Verify plot was called with correct data + self.assertTrue(mock_ax.plot.called, "Plot should be called") + plot_calls = mock_ax.plot.call_args_list + # From the actual function: 6 infection states * 5 plot calls each (median + 4 percentile lines) = 30 + self.assertEqual(len(plot_calls), 30, + "Should plot 30 lines total (6 states * 5 lines each)") + + # Verify fill_between was called for confidence intervals + self.assertTrue(mock_ax.fill_between.called, + "fill_between should be called for confidence intervals") + fill_calls = mock_ax.fill_between.call_args_list + # Should have calls for both 50% and 90% confidence intervals if show_90=True + self.assertEqual(len(fill_calls), 12, + "Should have 12 fill_between calls (6 states * 2 confidence intervals each)") + + # Verify axis formatting with correct parameters + mock_ax.set_xticks.assert_called_once() + xticks_call = mock_ax.set_xticks.call_args[0][0] + expected_ticks = np.arange(0, len(x), 2) # xtick_step=2 + np.testing.assert_array_equal(xticks_call, expected_ticks) + + # Verify xticklabels are set correctly + mock_ax.set_xticklabels.assert_called_once() + xticklabels_call = mock_ax.set_xticklabels.call_args[0][0] + self.assertEqual(len(xticklabels_call), len(expected_ticks)) + + # Verify that the xticklabels contain the correct date formatting + expected_dates = ['2021-03-01', '2021-03-03', + '2021-03-05', '2021-03-07', '2021-03-09'] + for i, label in enumerate(xticklabels_call): + self.assertEqual(str(label), expected_dates[i], + f"Label at position {i} should be {expected_dates[i]}") + + # Verify that start_date is used in label formatting + if len(xticklabels_call) > 0: + # All labels should contain date strings when start_date is provided + self.assertTrue(all('2021' in str(label) + for label in xticklabels_call)) + # First label should match the start_date + self.assertEqual(str(xticklabels_call[0]), '2021-03-01') + + @patch('memilio.plot.plotAbmInfectionStates.matplotlib') + def test_plot_infection_states_by_age_group(self, mock_matplotlib): + x = np.arange(10) + group_data = np.ones((10, 8)) + groups = ['Group1', 'Group2', 'Group3', + 'Group4', 'Group5', 'Group6', 'Total'] + p50_bs = {g: group_data for g in groups} + p25_bs = {g: group_data for g in groups} + p75_bs = {g: group_data for g in groups} + p05_bs = {g: group_data*-1 for g in groups} + p95_bs = {g: group_data*3 for g in groups} + mock_matplotlib.colormaps.get_cmap.return_value.colors = [(1, 0, 0)]*8 + + # Patch plt.subplots to return a grid of MagicMock axes + with patch.object(abm.plt, 'subplots') as mock_subplots: + fig_mock = MagicMock() + # From the actual function: n_states (6) rows, len(age_groups) (7) columns + ax_mock = np.empty((6, 7), dtype=object) + for i in range(6): + for j in range(7): + ax_mock[i, j] = MagicMock() + mock_subplots.return_value = (fig_mock, ax_mock) + + abm.plot_infection_states_by_age_group( + x, p50_bs, p25_bs, p75_bs, + colormap='Set1', + p05_bs=p05_bs, + p95_bs=p95_bs, + show90=True + ) + + # Verify that subplots was called to create a grid of axes + mock_subplots.assert_called_once() + subplot_call = mock_subplots.call_args + + # Verify that subplots was called with reasonable dimensions + if subplot_call and len(subplot_call[0]) >= 2: + rows, cols = subplot_call[0][:2] + self.assertEqual( + rows, 6, "Should have 6 rows (number of infection states)") + self.assertEqual( + cols, 7, "Should have 7 columns (number of age groups)") + + # Verify figure title is set + fig_mock.suptitle.assert_called_once() + + def test__format_x_axis(self): + test_x = np.arange(10) + test_start_date = '2021-03-01' + test_xtick_step = 2 + + with patch('memilio.plot.plotAbmInfectionStates.plt') as mock_plt: + mock_ax = MagicMock() + mock_plt.gca.return_value = mock_ax + + abm._format_x_axis(test_x, test_start_date, test_xtick_step) + + # Verify that gca was called to get current axis (it's called twice in the function) + self.assertEqual(mock_plt.gca.call_count, 2, + "gca should be called twice") + + # Verify that gcf was called to get current figure + mock_plt.gcf.assert_called_once() + + # Verify axis formatting methods were called + self.assertTrue(mock_ax.set_xticks.called, + "set_xticks should be called") + self.assertTrue(mock_ax.set_xticklabels.called, + "set_xticklabels should be called") + + # Verify correct tick positions + xticks_call = mock_ax.set_xticks.call_args + if xticks_call and xticks_call[0]: + tick_positions = xticks_call[0][0] + expected_positions = np.arange(0, len(test_x), test_xtick_step) + np.testing.assert_array_equal( + tick_positions, expected_positions) + + # Verify tick labels are date strings + xticklabels_call = mock_ax.set_xticklabels.call_args + if xticklabels_call and xticklabels_call[0]: + tick_labels = xticklabels_call[0][0] + self.assertIsInstance(tick_labels, (list, np.ndarray)) + if len(tick_labels) > 0: + # Should contain date information when start_date is provided + self.assertTrue(any('2021' in str(label) + for label in tick_labels)) + + +if __name__ == '__main__': + unittest.main() diff --git a/pycode/memilio-plot/setup.py b/pycode/memilio-plot/setup.py index 9c6a85908a..130996ac6a 100644 --- a/pycode/memilio-plot/setup.py +++ b/pycode/memilio-plot/setup.py @@ -8,12 +8,13 @@ class PylintCommand(Command): - """Custom command to run pylint and get a report as html.""" + """ + Custom command to run pylint and get a report as html. + """ description = "Runs pylint and outputs the report as html." user_options = [] def initialize_options(self): - """ """ from pylint.reporters.json_reporter import JSONReporter from pylint.reporters.text import ParseableTextReporter, TextReporter from pylint_json2html import JsonExtendedReporter @@ -29,12 +30,10 @@ def initialize_options(self): } def finalize_options(self): - """ """ self.reporter, self.out_file = self.REPORTERS.get( self.out_format) # , self.REPORTERS.get("parseable")) def run(self): - """ """ os.makedirs("build_pylint", exist_ok=True) # Run pylint @@ -74,6 +73,7 @@ def run(self): 'pyxlsb', 'wget', 'folium', + 'scipy.ndimage', 'matplotlib', 'mapclassify', 'geopandas',