Skip to content

Commit d3329b9

Browse files
committed
misc
1 parent e609076 commit d3329b9

File tree

1 file changed

+6
-36
lines changed

1 file changed

+6
-36
lines changed

lectures/mccall_model.md

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,7 @@ mean_wage = 20.0
907907
# Create a range of volatility values
908908
σ_vals = jnp.linspace(0.1, 1.0, 25)
909909
910-
# For each σ, compute μ to maintain constant mean
910+
# Given σ, compute μ to maintain constant mean
911911
def compute_μ_for_mean(σ, mean_wage):
912912
return jnp.log(mean_wage) - (σ**2) / 2
913913
@@ -927,7 +927,6 @@ Now let's plot the reservation wage as a function of volatility:
927927

928928
```{code-cell} ipython3
929929
fig, ax = plt.subplots()
930-
931930
ax.plot(σ_vals, res_wages_volatility, linewidth=2)
932931
ax.set_xlabel(r'volatility ($\sigma$)', fontsize=12)
933932
ax.set_ylabel('reservation wage', fontsize=12)
@@ -1015,66 +1014,37 @@ def compute_mean_lifetime_value(model, w_bar, num_reps=10000, seed=1234):
10151014
"""
10161015
Compute mean lifetime value across many simulations.
10171016
1018-
Parameters:
1019-
-----------
1020-
model : McCallModelContinuous
1021-
The model containing parameters
1022-
w_bar : float
1023-
The reservation wage
1024-
num_reps : int
1025-
Number of simulation replications
1026-
seed : int
1027-
Random seed
1028-
1029-
Returns:
1030-
--------
1031-
mean_value : float
1032-
Average lifetime value across all replications
10331017
"""
10341018
key = jax.random.PRNGKey(seed)
10351019
keys = jax.random.split(key, num_reps)
10361020
10371021
# Vectorize the simulation across all replications
10381022
simulate_fn = jax.vmap(simulate_lifetime_value, in_axes=(0, None, None))
10391023
lifetime_values = simulate_fn(keys, model, w_bar)
1040-
10411024
return jnp.mean(lifetime_values)
10421025
```
10431026

1044-
Now let's compute both the reservation wage and the expected lifetime value
1045-
for each volatility level:
1027+
Now let's compute the expected lifetime value for each volatility level:
10461028

10471029
```{code-cell} ipython3
10481030
# Use the same volatility range and mean wage
10491031
σ_vals = jnp.linspace(0.1, 1.0, 25)
10501032
mean_wage = 20.0
10511033
1052-
# Storage for results
1053-
res_wages_vol = []
1054-
lifetime_values_vol = []
1055-
1034+
lifetime_vals = []
10561035
for σ in σ_vals:
10571036
μ = compute_μ_for_mean(σ, mean_wage)
1058-
model = create_mccall_continuous(σ=float(σ), μ=float(μ))
1059-
1060-
# Compute reservation wage
1061-
w_bar = compute_reservation_wage_continuous(model)
1062-
res_wages_vol.append(w_bar)
1063-
1064-
# Compute expected lifetime value
1037+
model = create_mccall_continuous(σ=σ, μ=μ)
10651038
lv = compute_mean_lifetime_value(model, w_bar)
1066-
lifetime_values_vol.append(lv)
1039+
lifetime_vals.append(lv)
10671040
1068-
res_wages_vol = jnp.array(res_wages_vol)
1069-
lifetime_values_vol = jnp.array(lifetime_values_vol)
10701041
```
10711042

10721043
Let's visualize the expected lifetime value as a function of volatility:
10731044

10741045
```{code-cell} ipython3
10751046
fig, ax = plt.subplots()
1076-
1077-
ax.plot(σ_vals, lifetime_values_vol, linewidth=2, color='green')
1047+
ax.plot(σ_vals, lifetime_vals, linewidth=2, color='green')
10781048
ax.set_xlabel(r'volatility ($\sigma$)', fontsize=12)
10791049
ax.set_ylabel('expected lifetime value', fontsize=12)
10801050
plt.show()

0 commit comments

Comments
 (0)