# Active Learning for Probability Simplex with Neural Network Ensembles
This package implements an active learning approach using BALD (Bayesian Active Learning by Disagreement) with neural network ensembles for learning functions on the probability simplex.
## Problem Setting

We want to learn an unknown function $f: \Delta_d \to [0, 1]$ defined on the probability simplex. The challenge is that we can only observe $f$ through a noisy oracle that is:

- **Probabilistic**: returns binary outcomes, yielding 1 with probability $f(x)$
- **Expensive**: we want to minimize the number of queries
- **Independent**: each call is an independent sample
## Approach

- **Ensemble of neural networks**: we approximate $f$ using multiple neural networks to estimate uncertainty
- **BALD acquisition function**: we use Bayesian Active Learning by Disagreement to select informative query points
- **Active learning loop**: we iteratively
  1. select the most informative point to query,
  2. query the oracle (possibly multiple times for better estimates),
  3. retrain the ensemble with the new data.
## Installation

```bash
git clone https://github.com/pik-gane/aspai_active.git
cd aspai_active
pip install -e .
```

### Requirements

- Python >= 3.8
- PyTorch >= 2.0.0
- NumPy >= 1.21.0
- SciPy >= 1.7.0
- Matplotlib >= 3.4.0
## Quick Start

```python
import torch
import numpy as np

from aspai_active import ActiveLearner

# Define your oracle function (returns 0 or 1)
def my_oracle(x):
    # x is a point on the simplex (sums to 1)
    # Return 1 with probability f(x), 0 otherwise
    prob = some_function(x)  # Your unknown function
    return int(np.random.random() < prob)

# Create the active learner
learner = ActiveLearner(
    d=10,              # Dimension of simplex
    oracle=my_oracle,
    n_models=5,        # Number of models in ensemble
    device="cpu",
)

# Run active learning
results = learner.run(
    n_iterations=50,
    n_candidates=1000,
    n_initial=20,
    n_oracle_queries=3,  # Query oracle 3 times per point
    verbose=True,
)
```
For high dimensions, enable gradient descent optimization, which improves candidate selection by moving points toward higher acquisition values:

```python
results = learner.run(
    n_iterations=50,
    n_candidates=1000,
    n_initial=20,
    n_oracle_queries=3,
    optimize_candidates=True,  # Enable gradient descent optimization
    gd_steps=20,               # Number of optimization steps
    gd_lr=0.05,                # Learning rate
    gd_top_k_fraction=0.2,     # Optimize top 20% of candidates
    verbose=True,
)
```
Make predictions on new points:

```python
from aspai_active import sample_simplex

test_points = sample_simplex(100, d=10)
predictions = learner.predict(test_points)
```

## Examples

The package includes two complete examples.
### 3D Example

An example with d=3 for visualization, where the true function is a sum of 5 smooth step functions along random hyperplanes.

```bash
cd examples
python example_3d.py
```

This will:

- Create a synthetic function as a sum of 5 smooth step functions
- Run active learning with BALD acquisition
- Generate an MP4 video (`example_progress.mp4`) showing the learning progress, with one frame per query
- Generate a final visualization image (`example_results.png`) showing:
  - True function values
  - Estimated function values
  - Classification accuracy for $A = \{x : f(x) > 0.5\}$
  - Query points selected by the algorithm
The example produces:

- **Video** (`example_progress.mp4`): an animated visualization showing learning progress
  - One frame per query iteration
  - Three panels showing the true function, the estimated function, and the classification
  - Fitness metrics displayed on each frame: TP (True Positives), TN (True Negatives), FP (False Positives), FN (False Negatives), and Accuracy
- **Image** (`example_results.png`): a final visualization with three panels
  - Left: true function $f(x)$ on the simplex
  - Middle: learned function estimate
  - Right: classification correctness (TP/TN/FP/FN)

All query points are shown as black dots.
### High-Dimensional Example

An example with d=20 demonstrating gradient descent optimization for candidate points.

```bash
cd examples
python example_highdim.py
```

This will:

- Run multiple trials comparing runs with and without gradient optimization
- Show accuracy improvements from gradient-based candidate optimization
- Generate comparison visualizations

The gradient descent optimization is particularly beneficial in high dimensions, where random sampling becomes less effective.
## API Reference

### ActiveLearner

The main class for active learning.

```python
learner = ActiveLearner(
    d,                     # Dimension of simplex
    oracle,                # Oracle function
    n_models=5,            # Number of ensemble models
    hidden_dims=[64, 64],  # Hidden layer sizes
    device="cpu",          # Torch device
    seed=None,             # Random seed
)
```

Methods:

- `run(n_iterations, ...)`: run the active learning loop. Parameters for gradient optimization:
  - `optimize_candidates` (bool): enable gradient descent optimization (default: `False`)
  - `gd_steps` (int): number of gradient descent steps (default: 10)
  - `gd_lr` (float): learning rate for optimization (default: 0.1)
  - `gd_top_k_fraction` (float): fraction of top candidates to optimize (default: 0.1)
- `predict(X)`: get predictions for new points
- `query_oracle(x, n_queries)`: query the oracle at a specific point
### EnsembleModel

An ensemble of neural networks for uncertainty estimation.

```python
ensemble = EnsembleModel(
    n_models,     # Number of models
    input_dim,    # Input dimension
    hidden_dims,  # Hidden layer sizes
    device="cpu",
)
```

Methods:

- `train_step(X, y, n_epochs)`: train on data
- `predict_proba(X)`: get probability predictions
- `predict_proba_with_grad(X)`: get predictions with gradient support (for optimization)
- `predict_mean(X)`: get the mean prediction
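As a quick illustration, here is a hedged usage sketch of `EnsembleModel` on its own, outside of `ActiveLearner`. It assumes `EnsembleModel` is importable from the top-level package and that `predict_proba` returns a tensor of shape `(n_models, n_points)`, matching the shapes described for the BALD function below; check the package source for the exact signatures.

```python
import torch

from aspai_active import EnsembleModel, sample_simplex  # import path is an assumption

d = 5
# Toy training data: points on the simplex, labeled by a simple rule
X = torch.as_tensor(sample_simplex(200, d=d), dtype=torch.float32)
y = (X[:, 0] > 0.3).float()

ensemble = EnsembleModel(n_models=5, input_dim=d, hidden_dims=[64, 64], device="cpu")
ensemble.train_step(X, y, n_epochs=100)

probs = ensemble.predict_proba(X)  # assumed shape: (n_models, n_points)
mean = ensemble.predict_mean(X)    # assumed shape: (n_points,)
print(probs.shape, mean.shape)
```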
### BALD acquisition function

Computes BALD (Bayesian Active Learning by Disagreement) scores.

- Input: tensor of shape `(n_models, n_points)` with predictions
- Output: tensor of shape `(n_points,)` with acquisition scores
- Higher scores indicate more informative points
### Candidate optimization

Optimizes candidate points using gradient descent on the acquisition function (see the sketch in the Gradient Descent Optimization section below).

- Input:
  - `candidates`: tensor of initial candidate points
  - `ensemble`: trained `EnsembleModel`
  - `acquisition_fn`: acquisition function to maximize
  - `n_steps`: number of gradient descent steps (default: 10)
  - `learning_rate`: learning rate (default: 0.1)
  - `top_k_fraction`: fraction of candidates to optimize (default: 0.1)
- Output: tensor of optimized candidates
- Use case: improves performance in high dimensions
### Other utilities

- Simple uncertainty sampling (distance from 0.5)
- Variance-based acquisition function
- `sample_simplex`: sample uniformly from the probability simplex
- Grid creation on the simplex (for d=3, creates a triangular grid)
- Projection of points onto the probability simplex
- Conversion of 3D simplex points to 2D for visualization
## How BALD Works

BALD measures the mutual information between predictions and model parameters. For an ensemble of $M$ models with predicted probabilities $p_m(x)$,

$$\alpha_{\mathrm{BALD}}(x) \;=\; H\!\left[\frac{1}{M}\sum_{m=1}^{M} p_m(x)\right] \;-\; \frac{1}{M}\sum_{m=1}^{M} H\!\left[p_m(x)\right],$$

where $H$ denotes the binary entropy. Here:

- First term: entropy of the mean prediction (predictive uncertainty)
- Second term: expected entropy over the models (aleatoric uncertainty)
- Difference: epistemic uncertainty (what we can reduce by querying)
High BALD scores indicate points where the ensemble as a whole is uncertain (high epistemic uncertainty) but individual models are confident (low aleatoric uncertainty). These are the most informative points to query.
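To make the decomposition concrete, here is a minimal sketch of the BALD score for a binary-output ensemble. The function name `compute_bald` and numerical details (e.g., the epsilon for log stability) are illustrative; the package's implementation may differ.

```python
import torch

def compute_bald(probs: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """probs: predicted probabilities of shape (n_models, n_points)."""
    def binary_entropy(p):
        return -(p * torch.log(p + eps) + (1 - p) * torch.log(1 - p + eps))

    entropy_of_mean = binary_entropy(probs.mean(dim=0))  # predictive uncertainty
    mean_of_entropy = binary_entropy(probs).mean(dim=0)  # aleatoric uncertainty
    return entropy_of_mean - mean_of_entropy             # epistemic uncertainty

# Point 0: models disagree (high BALD); point 1: all models unsure (BALD ~ 0)
probs = torch.tensor([[0.9, 0.5], [0.1, 0.5], [0.9, 0.5]])
print(compute_bald(probs))
```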
## Gradient Descent Optimization

For high-dimensional problems (e.g., d > 10), random sampling of candidate points becomes less effective. Gradient descent optimization improves candidate selection as follows (see the sketch after this list):

1. **Start with random candidates**: sample points uniformly from the simplex
2. **Select top candidates**: choose the top-k candidates with the highest initial acquisition scores
3. **Gradient optimization**: take gradient steps on these candidates to maximize the acquisition function
4. **Simplex projection**: after each gradient step, project points back onto the simplex to maintain the constraints
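Below is a hedged sketch of steps 3-4, assuming a differentiable acquisition function (e.g., one built on `predict_proba_with_grad`). Both `project_to_simplex` (a standard Euclidean projection onto the simplex) and `ascend_candidates` are illustrative names, not the package's internal API.

```python
import torch

def project_to_simplex(x: torch.Tensor) -> torch.Tensor:
    """Euclidean projection of each row of x onto {p : p >= 0, sum(p) = 1}."""
    u, _ = torch.sort(x, dim=-1, descending=True)
    css = torch.cumsum(u, dim=-1) - 1.0
    k = torch.arange(1, x.shape[-1] + 1, device=x.device, dtype=x.dtype)
    rho = (u - css / k > 0).sum(dim=-1, keepdim=True)  # number of active coords
    theta = css.gather(-1, rho - 1) / rho.to(x.dtype)
    return torch.clamp(x - theta, min=0.0)

def ascend_candidates(candidates, acquisition_fn, n_steps=10, lr=0.1):
    x = candidates.clone().requires_grad_(True)
    for _ in range(n_steps):
        score = acquisition_fn(x).sum()     # total acquisition to maximize
        (grad,) = torch.autograd.grad(score, x)
        with torch.no_grad():
            x += lr * grad                  # gradient *ascent* step
            x.copy_(project_to_simplex(x))  # step 4: back onto the simplex
    return x.detach()
```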
Benefits:

- **High dimensions**: most effective when d > 10, where random sampling struggles
- **Better exploration**: finds regions of higher uncertainty more efficiently
- **Configurable**: the number of optimization steps, the learning rate, and the fraction of candidates to optimize can all be adjusted
Usage:

```python
learner.run(
    n_iterations=50,
    optimize_candidates=True,  # Enable optimization
    gd_steps=20,               # More steps for higher dimensions
    gd_lr=0.05,                # Lower learning rate for stability
    gd_top_k_fraction=0.2,     # Optimize top 20% of candidates
)
```

Recommendations:

- **Enable for d ≥ 10**: particularly beneficial in high dimensions
- **Disable for d < 5**: little benefit in low dimensions; adds computation time
- **Moderate d (5-10)**: optional; test whether it helps your specific problem
## Technical Details

### Model architecture

Each model in the ensemble has:

- Input layer: dimension `d` (the simplex dimension)
- Hidden layers: configurable (default: `[64, 64]`)
- Dropout: 0.1 (for uncertainty estimation)
- Output: a single logit for binary classification
- Activation: ReLU for the hidden layers
- Loss: binary cross-entropy with logits
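A minimal sketch of a single ensemble member matching this description (illustrative; the package's actual module may be organized differently):

```python
import torch.nn as nn

def make_member(d: int, hidden_dims=(64, 64), p_dropout: float = 0.1) -> nn.Module:
    layers, in_dim = [], d
    for h in hidden_dims:
        layers += [nn.Linear(in_dim, h), nn.ReLU(), nn.Dropout(p_dropout)]
        in_dim = h
    layers.append(nn.Linear(in_dim, 1))  # single logit for binary classification
    return nn.Sequential(*layers)

model = make_member(10)
loss_fn = nn.BCEWithLogitsLoss()  # binary cross-entropy with logits
```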
- Initial phase: Train with random points
- Active learning: Iteratively select points with BALD
- Multiple queries: Query oracle multiple times per point for better estimates
- Incremental training: Add new points and retrain ensemble
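The following is a runnable sketch of these four phases in plain PyTorch, with a toy oracle ($f(x) = x_1$). It re-implements the pieces inline for clarity and is not the package's internal code.

```python
import torch
import torch.nn as nn

d, n_models = 3, 5
oracle = lambda x: int(torch.rand(()) < x[0])  # toy oracle: f(x) = x_0

def sample_simplex(n):
    # Uniform on the simplex via Dirichlet(1, ..., 1)
    return torch.distributions.Dirichlet(torch.ones(d)).sample((n,))

def make_member():
    return nn.Sequential(nn.Linear(d, 64), nn.ReLU(), nn.Linear(64, 1))

models = [make_member() for _ in range(n_models)]

def predict_proba(X):
    # Stacked member probabilities, shape (n_models, n_points)
    return torch.stack([torch.sigmoid(m(X).squeeze(-1)) for m in models])

def bald(p, eps=1e-9):
    H = lambda q: -(q * (q + eps).log() + (1 - q) * (1 - q + eps).log())
    return H(p.mean(0)) - H(p).mean(0)

def retrain(X, y, n_epochs=50):
    for m in models:  # each member is fit on the full data set
        opt = torch.optim.Adam(m.parameters(), lr=1e-2)
        for _ in range(n_epochs):
            opt.zero_grad()
            loss = nn.functional.binary_cross_entropy_with_logits(
                m(X).squeeze(-1), y)
            loss.backward()
            opt.step()

# 1. Initial phase: random points
X = sample_simplex(20)
y = torch.tensor([float(oracle(x)) for x in X])

for _ in range(10):
    retrain(X, y)                                         # 4. incremental retraining
    cand = sample_simplex(500)
    with torch.no_grad():
        x_new = cand[bald(predict_proba(cand)).argmax()]  # 2. select with BALD
    y_new = sum(oracle(x_new) for _ in range(3)) / 3.0    # 3. multiple oracle queries
    X = torch.cat([X, x_new[None]])
    y = torch.cat([y, torch.tensor([y_new])])
```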
### Sampling the simplex

The probability simplex is $\Delta_d = \{x \in \mathbb{R}^d : x_i \ge 0,\ \sum_{i=1}^d x_i = 1\}$. We sample uniformly from it using the Dirichlet distribution with all concentration parameters equal to 1.
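As a minimal sketch of what `sample_simplex` does (the package's actual implementation may differ):

```python
import numpy as np

def sample_simplex_uniform(n: int, d: int, rng=None):
    """Uniform samples on the simplex via Dirichlet(1, ..., 1); shape (n, d)."""
    rng = rng or np.random.default_rng()
    return rng.dirichlet(np.ones(d), size=n)

pts = sample_simplex_uniform(5, d=3)
assert np.allclose(pts.sum(axis=1), 1.0) and (pts >= 0).all()
```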
## Citation

If you use this package in your research, please cite:

```bibtex
@software{aspai_active,
  title = {aspai_active: Active Learning for Probability Simplex},
  author = {aspai_active contributors},
  year = {2024},
  url = {https://github.com/pik-gane/aspai_active}
}
```

## License

MIT License - see the LICENSE file for details.
## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.