Add FedHT baseline: Federated Nonconvex Sparse Learning #7076

hrshl4codes wants to merge 1 commit into flwrlabs:main from
Conversation
Implements Fed-HT and FedIter-HT from Tong et al. (2021), arXiv:2101.00052, as a Flower baseline. Covers Simulation I (sparse linear regression), Simulation II (sparse logistic regression), and MNIST (sparse softmax regression). Distributed-IHT (K=1) is included as the comparison baseline. Simulation I grid search results confirm the paper's core claim: Fed-HT (K=3, lr=0.003) matches Distributed-IHT's final objective in ~28 communication rounds vs. 100, a 3.5x reduction in communication cost.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 737229a747
```python
        return

    abs_flat = flat.abs()
    cutoff = torch.kthvalue(abs_flat, flat.numel() - tau).values
```
**Use correct kth index for local hard threshold**

`torch.kthvalue` is 1-indexed, so using `flat.numel() - tau` selects the (tau+1)-th largest magnitude instead of the tau-th. In FedIter-HT this means each local thresholding step can keep the wrong support, and the follow-up pruning can discard truly larger coefficients, so the client update no longer matches the intended hard-threshold operator.
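A minimal sketch of the corrected indexing (the wrapper `keep_top_tau` is illustrative, not the PR's API):

```python
import torch

def keep_top_tau(flat: torch.Tensor, tau: int) -> torch.Tensor:
    """Zero all but the tau largest-magnitude entries of a 1-D tensor."""
    abs_flat = flat.abs()
    # torch.kthvalue is 1-indexed over ascending order, so the tau-th
    # largest magnitude sits at position numel - tau + 1, not numel - tau.
    cutoff = torch.kthvalue(abs_flat, flat.numel() - tau + 1).values
    # >= keeps at least tau entries; ties at the cutoff still need the
    # pruning discussed in the next comment.
    return flat * (abs_flat >= cutoff)
```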
```python
    nonzero_indices = np.where(mask)[0]
    if nonzero_indices.size > tau:
        excess = nonzero_indices.size - tau
        mask[nonzero_indices[:excess]] = False
```
**Remove only cutoff ties when enforcing exact tau sparsity**

After `mask = abs_flat >= cutoff`, the overflow handling removes the first masked indices regardless of their magnitude. When there are ties at the cutoff, this can drop entries strictly larger than the cutoff (for example, magnitudes [10, 9, 9] with tau=2 can remove 10), so the function does not reliably keep the tau largest-magnitude coordinates as documented.
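One way to prune only the tied entries (a sketch; `mask`, `abs_flat`, `cutoff`, and `tau` as in the diff above):

```python
import numpy as np

# Entries strictly above the cutoff are always among the tau largest,
# so only entries exactly equal to the cutoff are candidates for removal.
excess = int(mask.sum()) - tau
if excess > 0:
    tie_indices = np.where(mask & (abs_flat == cutoff))[0]
    mask[tie_indices[:excess]] = False
```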
```python
    X_all = np.concatenate([X for X, _ in all_data], axis=0)
    y_all = np.concatenate([y for _, y in all_data], axis=0)
    split = max(1, int(len(X_all) * 0.1))
    _, val_loader = make_loaders(X_all[-split:], y_all[-split:], batch_size=batch_size)
```
**Avoid double-splitting the synthetic validation set**

This code first takes a 10% holdout (`X_all[-split:]`) and then passes that holdout into `make_loaders`, which applies another 10% split and returns only the second split's validation loader. As a result, server-side evaluation runs on ~1% of the synthetic data instead of the intended 10%, making objective tracking much noisier and inconsistent with the described experimental setup.
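A sketch of one fix, building the validation loader directly from the 10% holdout rather than re-splitting it (assumes NumPy inputs and plain PyTorch utilities; `make_loaders` internals are not shown in the diff):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

split = max(1, int(len(X_all) * 0.1))
val_dataset = TensorDataset(
    torch.as_tensor(X_all[-split:], dtype=torch.float32),
    torch.as_tensor(y_all[-split:]),
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
```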
Pull request overview
This PR introduces a new Flower baseline (baselines/fedht) implementing the Fed-HT and FedIter-HT strategies from Federated Nonconvex Sparse Learning (Tong et al., 2021), along with synthetic data generators, MNIST support (via flwr-datasets), and helper scripts/docs to reproduce and visualize results.
Changes:
- Added Fed-HT/FedIter-HT strategies including a global hard-thresholding operator.
- Added client/server Flower Apps supporting Simulation I (linear), Simulation II (logistic), and MNIST (softmax), plus synthetic dataset generation.
- Added reproducibility tooling and documentation (grid search, plotting, smoke test, extended results).
Reviewed changes
Copilot reviewed 13 out of 17 changed files in this pull request and generated 14 comments.
Summary per file:
| File | Description |
|---|---|
| baselines/fedht/fedht/strategy.py | Implements FedHT/FedIterHT strategies and global hard-threshold operator |
| baselines/fedht/fedht/model.py | Defines sparse linear/logistic/softmax models, local training loops, local hard-thresholding, eval helpers |
| baselines/fedht/fedht/dataset.py | Generates synthetic Simulation I/II data and loads MNIST partitions via flwr-datasets |
| baselines/fedht/fedht/client_app.py | ClientApp/NumPyClient implementing local training for all tasks |
| baselines/fedht/fedht/server_app.py | ServerApp wiring strategies, initialization, centralized evaluation |
| baselines/fedht/pyproject.toml | Baseline package/app config and dependency specification |
| baselines/fedht/grid_search.py | Runs flwr configs across lr/K sweeps and saves per-round objective CSVs |
| baselines/fedht/plot_results.py | Generates comparison + lr-sweep plots from saved CSVs |
| baselines/fedht/smoke_test.py | Scripted smoke tests for core components (no federation required) |
| baselines/fedht/README.md | Baseline README with experimental setup and usage instructions |
| baselines/fedht/docs/EXTENDED_README.md | Detailed per-experiment results and reproduction notes |
| baselines/fedht/fedht/__init__.py | Package marker/docstring |
| baselines/fedht/.gitignore | Ignores build artifacts/results/cache files |
Additional diff excerpts flagged by the review:

```python
    cutoff = np.partition(abs_flat, flat.size - tau)[flat.size - tau]
    mask = abs_flat >= cutoff

    # Resolve ties: keep exactly tau entries when multiple values equal the cutoff.
    nonzero_indices = np.where(mask)[0]
    if nonzero_indices.size > tau:
        excess = nonzero_indices.size - tau
        mask[nonzero_indices[:excess]] = False
```
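Note that `np.partition` takes a 0-indexed position in ascending order, so `flat.size - tau` here does select the tau-th largest magnitude; the off-by-one flagged earlier applies only to the 1-indexed `torch.kthvalue` call.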
```python
    if tau >= flat.numel():
        return

    abs_flat = flat.abs()
    cutoff = torch.kthvalue(abs_flat, flat.numel() - tau).values
    mask = abs_flat >= cutoff

    nonzero = mask.sum().item()
    if nonzero > tau:
        excess = int(nonzero - tau)
        indices = mask.nonzero(as_tuple=True)[0]
        mask[indices[:excess]] = False

    flat = flat * mask
```
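One caveat worth noting on the excerpt above: `flat = flat * mask` rebinds the local name to a new tensor, so if `flat` is meant to alias model parameter storage, the thresholded values never reach the model; an in-place `flat.mul_(mask)` would mutate the underlying tensor. Whether that matters depends on surrounding code not shown in the diff.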
```python
    _, val_loader = make_loaders(X_all[-split:], y_all[-split:], batch_size=batch_size)
    return val_loader
```
```python
    if task == TASK_LINEAR:
        all_data = generate_simulation_I(num_clients=num_clients, seed=42)
    elif task == TASK_LOGISTIC:
        all_data = generate_simulation_II(num_clients=num_clients, seed=42)
```
```yaml
title: Federated Nonconvex Sparse Learning
url: https://arxiv.org/abs/2101.00052
labels: [sparse learning, hard thresholding, non-IID, linear regression, logistic regression]
dataset: [Simulation I, Simulation II, MNIST, E2006-tfidf, RCV1]
```
4. Weight serialization round-trip

Run with:

```shell
cd ~/Projects/fedht
```
```toml
num-server-rounds = 100
num-clients = 100          # total number of simulated clients
fraction-fit = 1.0         # use all clients each round
min-available-clients = 2
```
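These defaults live in the baseline's `pyproject.toml`. Assuming a recent `flwr` CLI, individual values can also be overridden per run without editing the file, for example:

```shell
# Override selected run-config values for a quick sweep
flwr run . --run-config "num-server-rounds=10 fraction-fit=0.5"
```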
```python
    return rounds, objectives


def find_best_run(experiment: str, algorithm: str) -> tuple[Path | None, float, int]:
```
```python
def _train(
    net: nn.Module,
    criterion: nn.Module,
    loader: DataLoader,
    device: torch.device,
    lr: float,
    local_steps: int,
    use_local_ht: bool,
    tau: int,
    y_dtype: torch.dtype | None = None,
) -> None:
```
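To make the flags concrete, here is a hypothetical sketch (not the PR's implementation) of the loop this signature suggests: plain SGD over at most `local_steps` mini-batches, with a hard-threshold projection after each step when `use_local_ht` is set, i.e. the FedIter-HT client-side behavior:

```python
import torch
from itertools import islice

def _train_sketch(net, criterion, loader, device, lr, local_steps, use_local_ht, tau):
    net.train()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    for X, y in islice(loader, local_steps):  # at most local_steps mini-batches
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(net(X), y)
        loss.backward()
        optimizer.step()
        if use_local_ht:
            # FedIter-HT: project the model onto the tau-sparse set after each step.
            with torch.no_grad():
                flat = torch.cat([p.reshape(-1) for p in net.parameters()])
                if tau < flat.numel():
                    cutoff = torch.kthvalue(flat.abs(), flat.numel() - tau + 1).values
                    for p in net.parameters():
                        p.mul_((p.abs() >= cutoff).to(p.dtype))
```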
Description
This PR adds a Flower baseline for Federated Nonconvex Sparse Learning (Tong et al., 2021), implementing the Fed-HT and FedIter-HT algorithms proposed in arXiv:2101.00052.
Closes #3987.
Algorithms
Both algorithms solve a sparse parameter estimation problem under a cardinality constraint (hard-thresholding operator H_tau). They differ in where thresholding is applied:

- Fed-HT: clients run K local update steps per round; the server applies H_tau once per round to the aggregated model.
- FedIter-HT: in addition to server-side thresholding, each client applies H_tau after every local update step.
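For reference, H_tau keeps the tau largest-magnitude coordinates of a vector and zeros the rest:

$$
[H_\tau(w)]_i =
\begin{cases}
w_i & \text{if } |w_i| \text{ is among the } \tau \text{ largest magnitudes of } w,\\
0 & \text{otherwise.}
\end{cases}
$$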
Experiments
Three tasks are supported out of the box:

- Simulation I: sparse linear regression on synthetic data
- Simulation II: sparse logistic regression on synthetic data
- MNIST: sparse softmax regression (loaded via flwr-datasets)
Key Results (Simulation I)
Grid search over K ∈ {1, 3, 5} and lr ∈ {0.001–0.03} on Simulation I confirms the paper's main claim:
Fed-HT (K=3) reaches Distributed-IHT's final objective in approximately 28 communication rounds — a 3.5x reduction in communication cost.
What's Included
- `fedht/strategy.py` — `FedHT` and `FedIterHT` strategy classes, `hard_threshold` operator
- `fedht/client_app.py` — `FedHTClient` handling all three task types
- `fedht/server_app.py` — server setup with centralized evaluation
- `fedht/model.py` — sparse linear, logistic, and softmax regression models
- `fedht/dataset.py` — synthetic data generators and MNIST loader (via flwr-datasets)
- `grid_search.py` — hyperparameter search script
- `plot_results.py` — result plotting script
- `smoke_test.py` — 9 unit tests covering core components
- `_static/` — Simulation I comparison and lr-sweep plots
- `docs/EXTENDED_README.md` — detailed per-experiment results

Testing
All 9 smoke tests pass. Linting: `ruff`, `black`, and `mypy` all clean.