diff --git a/slots/slots.py b/slots/slots.py index d5786d4..9d9930d 100755 --- a/slots/slots.py +++ b/slots/slots.py @@ -39,9 +39,9 @@ def __init__(self, num_bandits=3, probs=None, payouts=None, live=True, stop_criterion : dict Stopping criterion (str) and threshold value (float). ''' - + self.num_bandits = num_bandits self.choices = [] - + self.payout_values = [] if not probs: if not payouts: if live: @@ -135,6 +135,7 @@ def _run(self, strategy, parameters=None): choice = self.run_strategy(strategy, parameters) self.choices.append(choice) payout = self.bandits.pull(choice) + self.payout_values.append(payout) if payout is None: print('Trials exhausted. No more values for bandit', choice) return None @@ -327,20 +328,47 @@ def best(self): else: return np.argmax(self.wins/(self.pulls+0.1)) - def est_payouts(self): + def current(self): ''' - Calculate current estimate of average payout for each bandit. + Return last choice of bandit. Returns ------- - array of floats or None + int + Index of bandit ''' if len(self.choices) < 1: print('slots: No trials run so far.') return None else: - return self.wins/(self.pulls+0.1) + return self.choices[-1] + + def est_payouts(self, bandit=None): + ''' + Calculate current estimate of average payout for each bandit. + + Parameters + ---------- + bandit : None + If a bandit is selected, return the payout for that bandit, otherwise return all payouts. + + Returns + ------- + array of floats or None + ''' + if not bandit: + if len(self.choices) < 1: + print('slots: No trials run so far.') + return None + else: + return self.wins/(self.pulls+0.1) + else: + if len(self.choices) < 1: + print('slots: No trials run so far.') + return None + else: + return (self.wins/(self.pulls+0.1))[bandit] def regret(self): ''' @@ -430,6 +458,67 @@ def online_trial(self, bandit=None, payout=None, strategy='eps_greedy', return {'new_trial': True, 'choice': self.run_strategy(strategy, parameters), 'best': self.best()} + + + def multiple_trials(self, bandits=None, payouts=None, method = 'hard', strategy='eps_greedy', + parameters=None): + ''' + Feeds two arrays in and based on those results returns the next trial. + This really isn't optimized, there's a much better way of doing this if we don't + care about maintaining the workflow. + + Parameters + ---------- + bandit : array of ints + Bandit index + payout : array of floats + Payout value + method : string + Name of summing strategy + If 'hard' then it manually iterates over each row + If 'lazy' it attempts to sum it as an array and only add final product + strategy : string + Name of update strategy + parameters : dict + Parameters for update strategy function + + Returns + ------- + dict + Format: {'new_trial': boolean, 'choice': int, 'best': int} + ''' + bandits = bandits.values + payouts = payouts.values + if len(payouts) != len(bandits): + raise Exception('slots.online_trials: number of bandits is different from number of payouts') + else: + if method == 'hard': + for x in range(0,len(payouts)): + if bandits[x] is not None and payouts[x] is not None: + self.update(bandit=bandits[x], payout=payouts[x]) + else: + raise Exception('slots.online_trial: bandit and/or payout value' + ' missing.') + + else if method = 'lazy': + banditos = np.array(bandits) + self.choices.extend(list(bandits)) + self.payout_values.extend(list(payouts)) + for y in list(set(bandits)): + indices = np.where(banditos == y)[0] + payola = [payouts[i] for i in indices] + self.pulls[y] += len(payola) + self.wins[y] += sum(payola) + self.bandits.payouts[y] += sum(payola) + + if self.crit_met(): + return {'new_trial': False, 'choice': self.best(), + 'best': self.best()} + else: + return {'new_trial': True, + 'choice': self.run_strategy(strategy, parameters), + 'best': self.best()} + def update(self, bandit, payout): ''' @@ -445,12 +534,17 @@ def update(self, bandit, payout): ------- None ''' - + self.payout_values.append(payout) self.choices.append(bandit) self.pulls[bandit] += 1 self.wins[bandit] += payout self.bandits.payouts[bandit] += payout + def info(self): + ''' + Default: display number of bandits, wins, and estimated probabilities + ''' + return('number of bandits:',self.num_bandits, 'number of wins:', self.wins, 'estimated payouts:', self.est_payouts()) class Bandits(): ''' @@ -511,5 +605,3 @@ def pull(self, i): else: return 0.0 - def info(self): - pass