@@ -938,25 +938,28 @@ As expected, the reservation wage is increasing in $\sigma$.
938938
939939### Lifetime Value and Volatility
940940
941- We've seen that the reservation wage increases with volatility. Now let's verify that
942- the lifetime value at the optimal policy also increases with volatility.
941+ We've seen that the reservation wage increases with volatility.
943942
944- The intuition is that higher volatility provides more upside potential while the
945- worker can protect against downside risk by rejecting low offers. This option value
946- translates into higher expected lifetime utility.
943+ It's also the case that maximal lifetime value increases with volatility.
944+
945+ Higher volatility provides more upside potential, while at the same time
946+ workers can protect themselves against downside risk by rejecting low offers.
947+
948+ This option value translates into higher expected lifetime utility.
947949
948950To demonstrate this, we'll:
9499511 . Compute the reservation wage for each volatility level
9509522 . Simulate the worker's job search process following the optimal policy
9519533 . Calculate the expected discounted lifetime income
952954
953- The simulation works as follows: starting unemployed, the worker draws wage offers
954- and accepts the first offer that exceeds their reservation wage. We then compute
955- the present value of this income stream.
955+ The simulation works as follows: we draw 100 wage offers and track the worker's
956+ earnings at each date. The worker accepts the first offer that exceeds their
957+ reservation wage and earns that wage in all subsequent periods. We then compute
958+ the discounted sum of earnings over these 100 periods.
956959
957960``` {code-cell} ipython3
958961@jax.jit
959- def simulate_lifetime_value(key, model, w_bar, max_search_periods=1000 ):
962+ def simulate_lifetime_value(key, model, w_bar, n_periods=100 ):
960963 """
961964 Simulate one realization of the job search and compute lifetime value.
962965
@@ -968,45 +971,38 @@ def simulate_lifetime_value(key, model, w_bar, max_search_periods=1000):
968971 The model containing parameters
969972 w_bar : float
970973 The reservation wage
971- max_search_periods : int
972- Maximum number of search periods before forcing acceptance
974+ n_periods : int
975+ Number of periods to simulate
973976
974977 Returns:
975978 --------
976979 lifetime_value : float
977- Discounted sum of lifetime income
980+ Discounted sum of income over n_periods
978981 """
979982 c, β, σ, μ, w_draws = model
980983
981- def search_step(state):
982- t, key, accepted, wage = state
983- key, subkey = jax.random.split(key)
984- # Draw wage offer
985- s = jax.random.normal(subkey)
986- w = jnp.exp(μ + σ * s)
987- # Check if we accept
988- accept_now = w >= w_bar
989- # Update state: if we accept now, store the wage
990- wage = jnp.where(accept_now, w, wage)
991- accepted = jnp.logical_or(accepted, accept_now)
992- t = t + 1
993- return t, key, accepted, wage
994-
995- def search_cond(state):
996- t, _, accepted, _ = state
997- # Continue searching if not accepted and haven't hit max periods
998- return jnp.logical_and(jnp.logical_not(accepted), t < max_search_periods)
999-
1000- # Initial state: period 0, not accepted, wage 0
1001- initial_state = (0, key, False, 0.0)
1002- t_final, _, _, final_wage = jax.lax.while_loop(search_cond, search_step, initial_state)
1003-
1004- # Compute lifetime value
1005- # During unemployment (periods 0 to t_final-1): receive c each period
1006- # After employment (period t_final onwards): receive final_wage forever
1007- unemployment_value = c * (1 - β**t_final) / (1 - β)
1008- employment_value = (β**t_final) * final_wage / (1 - β)
1009- lifetime_value = unemployment_value + employment_value
984+ # Draw all wage offers upfront
985+ key, subkey = jax.random.split(key)
986+ s_vals = jax.random.normal(subkey, (n_periods,))
987+ wage_offers = jnp.exp(μ + σ * s_vals)
988+
989+ # Determine which offers are acceptable
990+ accept = wage_offers >= w_bar
991+
992+ # Track employment status: employed from first acceptance onward
993+ employed = jnp.cumsum(accept) > 0
994+
995+ # Get the accepted wage (first wage where accept is True)
996+ first_accept_idx = jnp.argmax(accept)
997+ accepted_wage = wage_offers[first_accept_idx]
998+
999+ # Earnings at each period: accepted_wage if employed, c if unemployed
1000+ earnings = jnp.where(employed, accepted_wage, c)
1001+
1002+ # Compute discounted sum
1003+ periods = jnp.arange(n_periods)
1004+ discount_factors = β ** periods
1005+ lifetime_value = jnp.sum(discount_factors * earnings)
10101006
10111007 return lifetime_value
10121008
0 commit comments