Skip to content

Commit 326b343

Browse files
jstacclaude
andcommitted
Simplify lifetime value simulation to fixed 100 periods
- Replace while_loop with fixed-period simulation (100 periods) - Draw all wage offers upfront using vectorized operations - Use cumsum to track employment status from first acceptance - Simpler logic that's easier to parallelize (same path length) - Cleaner code without nested loop functions - Update description to reflect simplified approach 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 905151e commit 326b343

File tree

1 file changed

+37
-41
lines changed

1 file changed

+37
-41
lines changed

lectures/mccall_model.md

Lines changed: 37 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -938,25 +938,28 @@ As expected, the reservation wage is increasing in $\sigma$.
938938

939939
### Lifetime Value and Volatility
940940

941-
We've seen that the reservation wage increases with volatility. Now let's verify that
942-
the lifetime value at the optimal policy also increases with volatility.
941+
We've seen that the reservation wage increases with volatility.
943942

944-
The intuition is that higher volatility provides more upside potential while the
945-
worker can protect against downside risk by rejecting low offers. This option value
946-
translates into higher expected lifetime utility.
943+
It's also the case that maximal lifetime value increases with volatility.
944+
945+
Higher volatility provides more upside potential, while at the same time
946+
workers can protect themselves against downside risk by rejecting low offers.
947+
948+
This option value translates into higher expected lifetime utility.
947949

948950
To demonstrate this, we'll:
949951
1. Compute the reservation wage for each volatility level
950952
2. Simulate the worker's job search process following the optimal policy
951953
3. Calculate the expected discounted lifetime income
952954

953-
The simulation works as follows: starting unemployed, the worker draws wage offers
954-
and accepts the first offer that exceeds their reservation wage. We then compute
955-
the present value of this income stream.
955+
The simulation works as follows: we draw 100 wage offers and track the worker's
956+
earnings at each date. The worker accepts the first offer that exceeds their
957+
reservation wage and earns that wage in all subsequent periods. We then compute
958+
the discounted sum of earnings over these 100 periods.
956959

957960
```{code-cell} ipython3
958961
@jax.jit
959-
def simulate_lifetime_value(key, model, w_bar, max_search_periods=1000):
962+
def simulate_lifetime_value(key, model, w_bar, n_periods=100):
960963
"""
961964
Simulate one realization of the job search and compute lifetime value.
962965
@@ -968,45 +971,38 @@ def simulate_lifetime_value(key, model, w_bar, max_search_periods=1000):
968971
The model containing parameters
969972
w_bar : float
970973
The reservation wage
971-
max_search_periods : int
972-
Maximum number of search periods before forcing acceptance
974+
n_periods : int
975+
Number of periods to simulate
973976
974977
Returns:
975978
--------
976979
lifetime_value : float
977-
Discounted sum of lifetime income
980+
Discounted sum of income over n_periods
978981
"""
979982
c, β, σ, μ, w_draws = model
980983
981-
def search_step(state):
982-
t, key, accepted, wage = state
983-
key, subkey = jax.random.split(key)
984-
# Draw wage offer
985-
s = jax.random.normal(subkey)
986-
w = jnp.exp(μ + σ * s)
987-
# Check if we accept
988-
accept_now = w >= w_bar
989-
# Update state: if we accept now, store the wage
990-
wage = jnp.where(accept_now, w, wage)
991-
accepted = jnp.logical_or(accepted, accept_now)
992-
t = t + 1
993-
return t, key, accepted, wage
994-
995-
def search_cond(state):
996-
t, _, accepted, _ = state
997-
# Continue searching if not accepted and haven't hit max periods
998-
return jnp.logical_and(jnp.logical_not(accepted), t < max_search_periods)
999-
1000-
# Initial state: period 0, not accepted, wage 0
1001-
initial_state = (0, key, False, 0.0)
1002-
t_final, _, _, final_wage = jax.lax.while_loop(search_cond, search_step, initial_state)
1003-
1004-
# Compute lifetime value
1005-
# During unemployment (periods 0 to t_final-1): receive c each period
1006-
# After employment (period t_final onwards): receive final_wage forever
1007-
unemployment_value = c * (1 - β**t_final) / (1 - β)
1008-
employment_value = (β**t_final) * final_wage / (1 - β)
1009-
lifetime_value = unemployment_value + employment_value
984+
# Draw all wage offers upfront
985+
key, subkey = jax.random.split(key)
986+
s_vals = jax.random.normal(subkey, (n_periods,))
987+
wage_offers = jnp.exp(μ + σ * s_vals)
988+
989+
# Determine which offers are acceptable
990+
accept = wage_offers >= w_bar
991+
992+
# Track employment status: employed from first acceptance onward
993+
employed = jnp.cumsum(accept) > 0
994+
995+
# Get the accepted wage (first wage where accept is True)
996+
first_accept_idx = jnp.argmax(accept)
997+
accepted_wage = wage_offers[first_accept_idx]
998+
999+
# Earnings at each period: accepted_wage if employed, c if unemployed
1000+
earnings = jnp.where(employed, accepted_wage, c)
1001+
1002+
# Compute discounted sum
1003+
periods = jnp.arange(n_periods)
1004+
discount_factors = β ** periods
1005+
lifetime_value = jnp.sum(discount_factors * earnings)
10101006
10111007
return lifetime_value
10121008

0 commit comments

Comments
 (0)