@@ -907,7 +907,7 @@ mean_wage = 20.0
907907# Create a range of volatility values
908908σ_vals = jnp.linspace(0.1, 1.0, 25)
909909
910- # For each σ, compute μ to maintain constant mean
910+ # Given σ, compute μ to maintain constant mean
911911def compute_μ_for_mean(σ, mean_wage):
912912 return jnp.log(mean_wage) - (σ**2) / 2
913913
@@ -927,7 +927,6 @@ Now let's plot the reservation wage as a function of volatility:
927927
928928``` {code-cell} ipython3
929929fig, ax = plt.subplots()
930-
931930ax.plot(σ_vals, res_wages_volatility, linewidth=2)
932931ax.set_xlabel(r'volatility ($\sigma$)', fontsize=12)
933932ax.set_ylabel('reservation wage', fontsize=12)
@@ -1015,66 +1014,37 @@ def compute_mean_lifetime_value(model, w_bar, num_reps=10000, seed=1234):
10151014 """
10161015 Compute mean lifetime value across many simulations.
10171016
1018- Parameters:
1019- -----------
1020- model : McCallModelContinuous
1021- The model containing parameters
1022- w_bar : float
1023- The reservation wage
1024- num_reps : int
1025- Number of simulation replications
1026- seed : int
1027- Random seed
1028-
1029- Returns:
1030- --------
1031- mean_value : float
1032- Average lifetime value across all replications
10331017 """
10341018 key = jax.random.PRNGKey(seed)
10351019 keys = jax.random.split(key, num_reps)
10361020
10371021 # Vectorize the simulation across all replications
10381022 simulate_fn = jax.vmap(simulate_lifetime_value, in_axes=(0, None, None))
10391023 lifetime_values = simulate_fn(keys, model, w_bar)
1040-
10411024 return jnp.mean(lifetime_values)
10421025```
10431026
1044- Now let's compute both the reservation wage and the expected lifetime value
1045- for each volatility level:
1027+ Now let's compute the expected lifetime value for each volatility level:
10461028
10471029``` {code-cell} ipython3
10481030# Use the same volatility range and mean wage
10491031σ_vals = jnp.linspace(0.1, 1.0, 25)
10501032mean_wage = 20.0
10511033
1052- # Storage for results
1053- res_wages_vol = []
1054- lifetime_values_vol = []
1055-
1034+ lifetime_vals = []
10561035for σ in σ_vals:
10571036 μ = compute_μ_for_mean(σ, mean_wage)
1058- model = create_mccall_continuous(σ=float(σ), μ=float(μ))
1059-
1060- # Compute reservation wage
1061- w_bar = compute_reservation_wage_continuous(model)
1062- res_wages_vol.append(w_bar)
1063-
1064- # Compute expected lifetime value
1037+ model = create_mccall_continuous(σ=σ, μ=μ)
10651038 lv = compute_mean_lifetime_value(model, w_bar)
1066- lifetime_values_vol .append(lv)
1039+ lifetime_vals .append(lv)
10671040
1068- res_wages_vol = jnp.array(res_wages_vol)
1069- lifetime_values_vol = jnp.array(lifetime_values_vol)
10701041```
10711042
10721043Let's visualize the expected lifetime value as a function of volatility:
10731044
10741045``` {code-cell} ipython3
10751046fig, ax = plt.subplots()
1076-
1077- ax.plot(σ_vals, lifetime_values_vol, linewidth=2, color='green')
1047+ ax.plot(σ_vals, lifetime_vals, linewidth=2, color='green')
10781048ax.set_xlabel(r'volatility ($\sigma$)', fontsize=12)
10791049ax.set_ylabel('expected lifetime value', fontsize=12)
10801050plt.show()
0 commit comments