Add loss factor to CostControlBandit and update logic for dynamic subsidy factor

### Changes:
* Added an update of the subsidy factor on the MAB update step.
* Added an option to allow a Beta regime for the subsidy factor, and added a convex combination loss function for determining the best action over the feasible set (an illustrative sketch follows this list).
* Edited MultiObjectiveCostControlBandit in strategy.py to generalize CostControlBandit.
* Edited the base MAB, sMAB, and cMAB classes in mab.py, smab.py, and cmab.py, respectively, to reduce duplicated code.
* Added a test suite to support the new dynamic cost-control functionality.
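
To make the convex combination loss concrete, here is a minimal, self-contained sketch of the selection step it describes. This is not the committed implementation; the names (reward_probs, costs, subsidy_factor, loss_factor) and the exact feasibility rule are illustrative assumptions.

from typing import Dict

def select_action(
    reward_probs: Dict[str, float],  # sampled probability of a positive reward per action
    costs: Dict[str, float],         # normalized cost per action, in [0, 1]
    subsidy_factor: float,           # tolerated fraction of reward loss vs. the best action
    loss_factor: float,              # convex-combination weight between reward loss and cost
) -> str:
    # Feasible set: actions whose sampled reward is within the subsidy of the best one.
    best_reward = max(reward_probs.values())
    feasible = [a for a, p in reward_probs.items() if p >= (1 - subsidy_factor) * best_reward]

    # Convex combination loss: trade off missed reward against cost, then pick the minimizer.
    def loss(a: str) -> float:
        return loss_factor * (1 - reward_probs[a]) + (1 - loss_factor) * costs[a]

    return min(feasible, key=loss)

# Example: with a 10% subsidy, the cheaper feasible action wins.
print(select_action({"a1": 0.80, "a2": 0.75}, {"a1": 0.90, "a2": 0.20},
                    subsidy_factor=0.1, loss_factor=0.5))  # -> "a2"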
Shahar-Bar committed Oct 14, 2024
1 parent cddda5c commit ce6749a
Showing 10 changed files with 937 additions and 395 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/continuous_delivery.yml
@@ -35,7 +35,7 @@ jobs:
poetry run pre-commit run --all-files
- name: Run tests
run: |
poetry run pytest -vv -k 'not time and not update_parallel'
poetry run pytest -vv -k 'not time and not update_parallel' --cov=pybandits
- name: Extract version from pyproject.toml
id: extract_version
run: |
2 changes: 1 addition & 1 deletion .github/workflows/continuous_integration.yml
@@ -43,4 +43,4 @@ jobs:
poetry run pre-commit run --all-files
- name: Run tests
run: |
poetry run pytest -vv -k 'not time and not update_parallel'
poetry run pytest -vv -k 'not time and not update_parallel' --cov=pybandits
17 changes: 7 additions & 10 deletions pybandits/cmab.py
@@ -20,7 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Dict, List, Optional, Set, Union
from typing import Dict, List, Set, Union

from numpy import array
from numpy.random import choice
@@ -76,22 +76,20 @@ def check_bayesian_logistic_regression_models(cls, v):
return v

@validate_call(config=dict(arbitrary_types_allowed=True))
def predict(
def _predict(
self,
context: ArrayLike,
forbidden_actions: Optional[Set[ActionId]] = None,
valid_actions: Set[ActionId],
) -> CmabPredictions:
"""
Predict actions.
Parameters
----------
context: ArrayLike of shape (n_samples, n_features)
context : ArrayLike of shape (n_samples, n_features)
Matrix of contextual features.
forbidden_actions : Optional[Set[ActionId]], default=None
Set of forbidden actions. If specified, the model will discard the forbidden_actions and it will only
consider the remaining allowed_actions. By default, the model considers all actions as allowed_actions.
Note that: actions = allowed_actions U forbidden_actions.
valid_actions : Set[ActionId]
The set of valid actions to consider.
Returns
-------
@@ -102,7 +100,6 @@ def predict(
ws : List[Dict[ActionId, float]]
The weighted sum of logistic regression logits.
"""
valid_actions = self._get_valid_actions(forbidden_actions)

# cast inputs to numpy arrays to facilitate their manipulation
context = array(context)
@@ -149,7 +146,7 @@ def predict(
return selected_actions, probs, weighted_sums

@validate_call(config=dict(arbitrary_types_allowed=True))
def update(
def _update(
self, context: ArrayLike, actions: List[ActionId], rewards: List[Union[BinaryReward, List[BinaryReward]]]
):
"""
47 changes: 44 additions & 3 deletions pybandits/mab.py
@@ -169,7 +169,6 @@ def _validate_update_params(

####################################################################################################################

@abstractmethod
@validate_call
def update(
self, actions: List[ActionId], rewards: Union[List[BinaryReward], List[List[BinaryReward]]], *args, **kwargs
@@ -182,10 +181,27 @@ def update(
rewards: List[Union[BinaryReward, List[BinaryReward]]]
The reward for each sample.
"""
self._validate_update_params(actions=actions, rewards=rewards)
self._update(actions=actions, rewards=rewards, *args, **kwargs)
if hasattr(self.strategy, "update"):
self.strategy.update(rewards=rewards)

@abstractmethod
def _update(
self, actions: List[ActionId], rewards: Union[List[BinaryReward], List[List[BinaryReward]]], *args, **kwargs
):
"""
Update the multi-armed bandit model.
actions : List[ActionId]
The selected action for each sample.
rewards : List[Union[BinaryReward, List[BinaryReward]]]
The reward for each sample.
"""
pass

@validate_call
def predict(self, forbidden_actions: Optional[Set[ActionId]] = None) -> Predictions:
def predict(self, forbidden_actions: Optional[Set[ActionId]] = None, **kwargs) -> Predictions:
"""
Predict actions.
@@ -196,15 +212,40 @@ def predict(self, forbidden_actions: Optional[Set[ActionId]] = None) -> Predicti
consider the remaining allowed_actions. By default, the model considers all actions as allowed_actions.
Note that: actions = allowed_actions U forbidden_actions.
Returns
-------
actions : List[ActionId] of shape (n_samples,)
The actions selected by the multi-armed bandit model.
probs : List[Dict[ActionId, Probability]] of shape (n_samples,)
The probabilities of getting a positive reward for each action
ws : List[Dict[ActionId, float]], only relevant for some of the MABs
The weighted sum of logistic regression logits.
"""
if hasattr(self.strategy, "reset"):
self.strategy.reset()
valid_actions = self._get_valid_actions(forbidden_actions)
return self._predict(valid_actions=valid_actions, **kwargs)

@abstractmethod
def _predict(self, valid_actions: Set[ActionId], **kwargs) -> Predictions:
"""
Predict actions.
Parameters
----------
valid_actions : Set[ActionId]
The set of valid actions.
Returns
-------
actions: List[ActionId] of shape (n_samples,)
The actions selected by the multi-armed bandit model.
probs: List[Dict[ActionId, Probability]] of shape (n_samples,)
The probabilities of getting a positive reward for each action
ws : List[Dict[ActionId, float]], only relevant for some of the MABs
The weighted sum of logistic regression logits..
The weighted sum of logistic regression logits.
"""
pass

def get_state(self) -> (str, dict):
"""
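
The mab.py changes above introduce a template-method split: the public update/predict on the base class handle validation, strategy hooks, and action filtering, and delegate to the abstract _update/_predict implemented by the sMAB and cMAB subclasses. Below is a self-contained toy sketch of that delegation flow with deliberately simplified signatures; it is not pybandits code, and the helper names are illustrative.

from abc import ABC, abstractmethod
from typing import List, Optional, Set

class ToyBaseMab(ABC):
    def __init__(self, actions: Set[str], strategy) -> None:
        self.actions = actions
        self.strategy = strategy

    def update(self, actions: List[str], rewards: List[int]) -> None:
        # Stand-in for _validate_update_params.
        if len(actions) != len(rewards):
            raise ValueError("actions and rewards must have the same length")
        self._update(actions=actions, rewards=rewards)
        # Strategy-specific bookkeeping, e.g. a dynamic subsidy factor.
        if hasattr(self.strategy, "update"):
            self.strategy.update(rewards=rewards)

    def predict(self, forbidden_actions: Optional[Set[str]] = None, **kwargs):
        if hasattr(self.strategy, "reset"):
            self.strategy.reset()
        # Filter forbidden actions once, in the base class.
        valid_actions = self.actions - (forbidden_actions or set())
        return self._predict(valid_actions=valid_actions, **kwargs)

    @abstractmethod
    def _update(self, actions: List[str], rewards: List[int]) -> None: ...

    @abstractmethod
    def _predict(self, valid_actions: Set[str], **kwargs): ...

Subclasses then only implement _update and _predict against the pre-filtered valid_actions, which is what removes the duplicated validation and filtering code from smab.py and cmab.py.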
21 changes: 6 additions & 15 deletions pybandits/smab.py
@@ -22,7 +22,7 @@


from collections import defaultdict
from typing import Dict, List, Optional, Set, Union
from typing import Dict, List, Set, Union

from pydantic import PositiveInt, field_validator, validate_call

@@ -59,22 +59,16 @@ class BaseSmabBernoulli(BaseMab):
actions: Dict[ActionId, BaseBeta]

@validate_call
def predict(
self,
n_samples: PositiveInt = 1,
forbidden_actions: Optional[Set[ActionId]] = None,
) -> SmabPredictions:
def _predict(self, n_samples: PositiveInt, valid_actions: Set[ActionId]) -> SmabPredictions:
"""
Predict actions.
Parameters
----------
n_samples : int > 0, default=1
n_samples : PositiveInt
Number of samples to predict.
forbidden_actions : Optional[Set[ActionId]], default=None
Set of forbidden actions. If specified, the model will discard the forbidden_actions and it will only
consider the remaining allowed_actions. By default, the model considers all actions as allowed_actions.
Note that: actions = allowed_actions U forbidden_actions.
valid_actions : Set[ActionId]
The set of valid actions.
Returns
-------
@@ -83,7 +77,6 @@ def predict(
probs: List[Dict[ActionId, Probability]] of shape (n_samples,)
The probabilities of getting a positive reward for each action.
"""
valid_actions = self._get_valid_actions(forbidden_actions)

selected_actions: List[ActionId] = []
probs: List[Dict[ActionId, Probability]] = []
@@ -96,7 +89,7 @@ def predict(
return selected_actions, probs

@validate_call
def update(self, actions: List[ActionId], rewards: Union[List[BinaryReward], List[List[BinaryReward]]]):
def _update(self, actions: List[ActionId], rewards: Union[List[BinaryReward], List[List[BinaryReward]]]):
"""
Update the stochastic Bernoulli bandit given the list of selected actions and their corresponding binary
rewards.
@@ -113,8 +106,6 @@ def update(self, actions: List[ActionId], rewards: Union[List[BinaryReward], Lis
rewards = [[1, 1], [1, 0], [1, 1], [1, 0], [1, 1], ...]
"""

self._validate_update_params(actions=actions, rewards=rewards)

rewards_dict = defaultdict(list)

for a, r in zip(actions, rewards):
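
For context on the defaultdict grouping kept in _update above, here is a small stand-alone example (illustrative data, not the committed code) of collapsing per-sample rewards into one list per action, so each action's Beta posterior can be updated once with its own rewards.

from collections import defaultdict

actions = ["a1", "a2", "a1", "a3"]
rewards = [1, 0, 0, 1]

rewards_dict = defaultdict(list)
for a, r in zip(actions, rewards):
    rewards_dict[a].append(r)

print(dict(rewards_dict))  # {'a1': [1, 0], 'a2': [0], 'a3': [1]}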
