Skip to content

Commit 0b8723d

Browse files
jstacclaude
andcommitted
Refine IFP lecture: Improve mathematical clarity and plotting
Key improvements: - Added clarification that Euler equation (eqeul1) holds only for interior solutions - Specified savings grid is strictly increasing - Fixed EGM boundary condition: anchor interpolation at (0,0) instead of c=a - Renamed asset_grid to savings_grid throughout for conceptual accuracy - Updated default parameters to match main branch (β=0.96, grid_max=16, y≈(0,2)) - Created get_endogenous_grid() helper function - Updated all plots to use endogenous grid (a = c + s) rather than exogenous savings grid - Removed intermediate ifp.py file (can be regenerated from ifp.md using jupytext) Mathematical corrections ensure the lecture accurately represents the EGM algorithm where the Euler equation is solved on the savings grid and policies are plotted on the resulting endogenous asset grid. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 05a665a commit 0b8723d

File tree

2 files changed

+85
-847
lines changed

2 files changed

+85
-847
lines changed

lectures/ifp.md

Lines changed: 85 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -256,11 +256,13 @@ Here
256256
* primes indicate next period states (as well as derivatives), and
257257
* $\sigma$ is the unknown function.
258258

259+
(We emphasize that {eq}`eqeul1` only holds when we have an interior solution, meaning $\sigma(a, z) < a$.)
260+
259261
We aim to find a fixed point $\sigma$ of {eq}`eqeul1`.
260262

261263
To do so we use the EGM.
262264

263-
We begin with an exogenous grid $G = \{s_0, \ldots, s_{m-1}\}$ with $s_0 > 0$, where each $s_i$ represents savings.
265+
We begin with a strictly increasing exogenous grid $G = \{s_0, \ldots, s_{m-1}\}$ with $s_0 > 0$, where each $s_i$ represents savings.
264266

265267
The relationship between current assets $a$, consumption $c$, and savings $s$ is
266268

@@ -292,12 +294,12 @@ $$
292294
a^e_{ij} = c_{ij} + s_i.
293295
$$
294296

295-
Our next guess policy function, which we write as $K\sigma$, is the linear interpolation of
296-
$(a^e_{ij}, c_{ij})$ over $i$, for each $j$.
297+
To anchor the interpolation at the origin, we add the point $(0, 0)$ to the start of the endogenous grid for each $j$.
297298

298-
(The number of one dimensional linear interpolations is equal to `len(z_grid)`.)
299+
Our next guess policy function, which we write as $K\sigma$, is then the linear interpolation of
300+
the points $\{(0, 0), (a^e_{0j}, c_{0j}), \ldots, (a^e_{(m-1)j}, c_{(m-1)j})\}$ for each $j$.
299301

300-
For $a < a^e_{i0}$ (i.e., below the minimum endogenous grid point), the household consumes everything, so we set $(K \sigma)(a, z_j) = a$.
302+
(The number of one dimensional linear interpolations is equal to `len(z_grid)`.)
301303

302304

303305

@@ -324,23 +326,23 @@ class IFP(NamedTuple):
324326
γ: float # Preference parameter
325327
Π: jnp.ndarray # Markov matrix for exogenous shock
326328
z_grid: jnp.ndarray # Markov state values for Z_t
327-
asset_grid: jnp.ndarray # Exogenous asset grid
329+
savings_grid: jnp.ndarray # Exogenous savings grid
328330
329331
330332
def create_ifp(r=0.01,
331-
β=0.98,
333+
β=0.96,
332334
γ=1.5,
333335
Π=((0.6, 0.4),
334336
(0.05, 0.95)),
335-
z_grid=(0.0, 0.2),
336-
asset_grid_max=40,
337-
asset_grid_size=50):
337+
z_grid=(-10.0, jnp.log(2.0)),
338+
savings_grid_max=16,
339+
savings_grid_size=50):
338340
339-
asset_grid = jnp.linspace(0, asset_grid_max, asset_grid_size)
341+
savings_grid = jnp.linspace(0, savings_grid_max, savings_grid_size)
340342
Π, z_grid = jnp.array(Π), jnp.array(z_grid)
341343
R = 1 + r
342344
assert R * β < 1, "Stability condition violated."
343-
return IFP(R=R, β=β, γ=γ, Π=Π, z_grid=z_grid, asset_grid=asset_grid)
345+
return IFP(R=R, β=β, γ=γ, Π=Π, z_grid=z_grid, savings_grid=savings_grid)
344346
345347
# Set y(z) = exp(z)
346348
y = jnp.exp
@@ -384,17 +386,13 @@ def K(σ: jnp.ndarray, ifp: IFP) -> jnp.ndarray:
384386
3. Given σ(a', z'), compute current consumption c that
385387
satisfies Euler equation
386388
4. Compute the endogenous current asset level a^e = c + s
387-
5. Interpolate back to asset grid to get σ_new(a, z)
389+
5. Interpolate back to savings grid to get σ_new(a, z)
388390
389391
"""
390-
R, β, γ, Π, z_grid, asset_grid = ifp
391-
n_a = len(asset_grid)
392+
R, β, γ, Π, z_grid, savings_grid = ifp
393+
n_a = len(savings_grid)
392394
n_z = len(z_grid)
393395
394-
# Create savings grid (exogenous grid for EGM)
395-
# We use the asset grid as the savings grid
396-
savings_grid = asset_grid
397-
398396
def compute_c_for_fixed_income_state(j):
399397
"""
400398
Compute updated consumption policy for income state z_j.
@@ -413,7 +411,7 @@ def K(σ: jnp.ndarray, ifp: IFP) -> jnp.ndarray:
413411
# Interpolate to get consumption at each (a', z')
414412
# For each z', interpolate over the a' values
415413
def interp_for_z(z_idx):
416-
return jnp.interp(a_next_grid[:, z_idx], asset_grid, σ[:, z_idx])
414+
return jnp.interp(a_next_grid[:, z_idx], savings_grid, σ[:, z_idx])
417415
418416
c_next_grid = jax.vmap(interp_for_z)(jnp.arange(n_z)) # Shape: (n_z, n_a)
419417
c_next_grid = c_next_grid.T # Shape: (n_a, n_z)
@@ -430,17 +428,14 @@ def K(σ: jnp.ndarray, ifp: IFP) -> jnp.ndarray:
430428
# Compute endogenous grid of current assets: a = c + s
431429
a_endogenous = c_vals + savings_grid
432430
433-
# Interpolate back to exogenous asset grid
434-
σ_new = jnp.interp(asset_grid, a_endogenous, c_vals)
435-
436-
# For asset levels below the minimum endogenous grid point,
437-
# the household is constrained and consumes everything: c = a
431+
# Add (0, 0) to anchor the interpolation at the origin
432+
a_endogenous = jnp.concatenate([jnp.array([0.0]), a_endogenous])
433+
c_vals = jnp.concatenate([jnp.array([0.0]), c_vals])
438434
439-
σ_new = jnp.where(asset_grid < a_endogenous[0],
440-
asset_grid,
441-
σ_new)
435+
# Interpolate back to exogenous savings grid
436+
σ_new = jnp.interp(savings_grid, a_endogenous, c_vals)
442437
443-
return σ_new # Consumption over the asset grid given z[j]
438+
return σ_new # Consumption over the savings grid given z[j]
444439
445440
# Compute consumption over all income states using vmap
446441
c_vmap = jax.vmap(compute_c_for_fixed_income_state)
@@ -477,23 +472,54 @@ def solve_model(ifp: IFP,
477472
return σ
478473
```
479474

475+
Here's a helper function to compute the endogenous grid from the policy:
476+
477+
```{code-cell} ipython3
478+
def get_endogenous_grid(σ: jnp.ndarray, ifp: IFP):
479+
"""
480+
Compute endogenous asset grid from consumption policy.
481+
482+
For each state j, the endogenous grid is a[i,j] = c[i,j] + s[i].
483+
We also add the point (0, 0) at the start for each state.
484+
485+
Returns:
486+
a_endogenous: array of shape (n_a+1, n_z) with asset values
487+
c_endogenous: array of shape (n_a+1, n_z) with consumption values
488+
"""
489+
savings_grid = ifp.savings_grid
490+
n_z = σ.shape[1]
491+
492+
# Compute a = c + s for each state
493+
a_vals = σ + savings_grid[:, None]
494+
495+
# Add (0, 0) at the start for each state
496+
a_endogenous = jnp.vstack([jnp.zeros(n_z), a_vals])
497+
c_endogenous = jnp.vstack([jnp.zeros(n_z), σ])
498+
499+
return a_endogenous, c_endogenous
500+
```
501+
480502
### Test run
481503

482504
Let's road test the EGM code.
483505

484506
```{code-cell} ipython3
485507
ifp = create_ifp()
486-
R, β, γ, Π, z_grid, asset_grid = ifp
487-
σ_init = asset_grid[:, None] * jnp.ones(len(z_grid))
508+
R, β, γ, Π, z_grid, savings_grid = ifp
509+
σ_init = savings_grid[:, None] * jnp.ones(len(z_grid))
488510
σ_star = solve_model(ifp, σ_init)
489511
```
490512

491513
Here's a plot of the optimal policy for each $z$ state
492514

493515
```{code-cell} ipython3
494516
fig, ax = plt.subplots()
495-
ax.plot(asset_grid, σ_star[:, 0], label='bad state')
496-
ax.plot(asset_grid, σ_star[:, 1], label='good state')
517+
518+
# Get endogenous grid points
519+
a_endogenous, c_endogenous = get_endogenous_grid(σ_star, ifp)
520+
521+
ax.plot(a_endogenous[:, 0], c_endogenous[:, 0], label='bad state')
522+
ax.plot(a_endogenous[:, 1], c_endogenous[:, 1], label='good state')
497523
ax.set(xlabel='assets', ylabel='consumption')
498524
ax.legend()
499525
plt.show()
@@ -504,10 +530,10 @@ To begin to understand the long run asset levels held by households under the de
504530

505531
```{code-cell} ipython3
506532
ifp = create_ifp()
507-
R, β, γ, Π, z_grid, asset_grid = ifp
508-
σ_init = asset_grid[:, None] * jnp.ones(len(z_grid))
533+
R, β, γ, Π, z_grid, savings_grid = ifp
534+
σ_init = savings_grid[:, None] * jnp.ones(len(z_grid))
509535
σ_star = solve_model(ifp, σ_init)
510-
a = asset_grid
536+
a = savings_grid
511537
512538
fig, ax = plt.subplots()
513539
for z, lb in zip((0, 1), ('low income', 'high income')):
@@ -564,14 +590,17 @@ Let's see if we match up:
564590

565591
```{code-cell} ipython3
566592
ifp_cake_eating = create_ifp(r=0.0, z_grid=(-jnp.inf, -jnp.inf))
567-
R, β, γ, Π, z_grid, asset_grid = ifp_cake_eating
568-
σ_init = asset_grid[:, None] * jnp.ones(len(z_grid))
593+
R, β, γ, Π, z_grid, savings_grid = ifp_cake_eating
594+
σ_init = savings_grid[:, None] * jnp.ones(len(z_grid))
569595
σ_star = solve_model(ifp_cake_eating, σ_init)
570596
597+
# Get endogenous grid
598+
a_endogenous, c_endogenous = get_endogenous_grid(σ_star, ifp_cake_eating)
599+
571600
fig, ax = plt.subplots()
572-
ax.plot(asset_grid, σ_star[:, 0], label='numerical')
573-
ax.plot(asset_grid,
574-
c_star(asset_grid, ifp_cake_eating.β, ifp_cake_eating.γ),
601+
ax.plot(a_endogenous[:, 0], c_endogenous[:, 0], label='numerical')
602+
ax.plot(a_endogenous[:, 0],
603+
c_star(a_endogenous[:, 0], ifp_cake_eating.β, ifp_cake_eating.γ),
575604
'--', label='analytical')
576605
ax.set(xlabel='assets', ylabel='consumption')
577606
ax.legend()
@@ -606,16 +635,18 @@ suppress consumption (because they encourage more savings).
606635
Here's one solution:
607636

608637
```{code-cell} ipython3
609-
# With β=0.98, we need R*β < 1, so r < 0.0204
610-
r_vals = np.linspace(0, 0.016, 4)
638+
# With β=0.96, we need R*β < 1, so r < 0.0416
639+
r_vals = np.linspace(0, 0.04, 4)
611640
612641
fig, ax = plt.subplots()
613642
for r_val in r_vals:
614643
ifp = create_ifp(r=r_val)
615-
R, β, γ, Π, z_grid, asset_grid = ifp
616-
σ_init = asset_grid[:, None] * jnp.ones(len(z_grid))
644+
R, β, γ, Π, z_grid, savings_grid = ifp
645+
σ_init = savings_grid[:, None] * jnp.ones(len(z_grid))
617646
σ_star = solve_model(ifp, σ_init)
618-
ax.plot(asset_grid, σ_star[:, 0], label=f'$r = {r_val:.3f}$')
647+
# Get endogenous grid
648+
a_endogenous, c_endogenous = get_endogenous_grid(σ_star, ifp)
649+
ax.plot(a_endogenous[:, 0], c_endogenous[:, 0], label=f'$r = {r_val:.3f}$')
619650
620651
ax.set(xlabel='asset level', ylabel='consumption (low income)')
621652
ax.legend()
@@ -658,19 +689,19 @@ def compute_asset_stationary(ifp, σ_star, num_households=50_000, T=500, seed=12
658689
ifp is an instance of IFP
659690
σ_star is the optimal consumption policy
660691
"""
661-
R, β, γ, Π, z_grid, asset_grid = ifp
692+
R, β, γ, Π, z_grid, savings_grid = ifp
662693
n_z = len(z_grid)
663694
664695
# Create interpolation function for consumption policy
665-
σ_interp = lambda a, z_idx: jnp.interp(a, asset_grid, σ_star[:, z_idx])
696+
σ_interp = lambda a, z_idx: jnp.interp(a, savings_grid, σ_star[:, z_idx])
666697
667698
# Simulate one household forward
668699
def simulate_one_household(key):
669700
# Random initial state (both z and a)
670701
key1, key2, key3 = jax.random.split(key, 3)
671702
z_idx = jax.random.choice(key1, n_z)
672-
# Start with random assets drawn uniformly from [0, asset_grid_max/2]
673-
a = jax.random.uniform(key3, minval=0.0, maxval=asset_grid[-1]/2)
703+
# Start with random assets drawn uniformly from [0, savings_grid_max/2]
704+
a = jax.random.uniform(key3, minval=0.0, maxval=savings_grid[-1]/2)
674705
675706
# Simulate forward T periods
676707
def step(state, key_t):
@@ -700,8 +731,8 @@ Now we call the function, generate the asset distribution and histogram it:
700731

701732
```{code-cell} ipython3
702733
ifp = create_ifp()
703-
R, β, γ, Π, z_grid, asset_grid = ifp
704-
σ_init = asset_grid[:, None] * jnp.ones(len(z_grid))
734+
R, β, γ, Π, z_grid, savings_grid = ifp
735+
σ_init = savings_grid[:, None] * jnp.ones(len(z_grid))
705736
σ_star = solve_model(ifp, σ_init)
706737
assets = compute_asset_stationary(ifp, σ_star)
707738
@@ -765,8 +796,8 @@ asset_mean = []
765796
for r in r_vals:
766797
print(f'Solving model at r = {r}')
767798
ifp = create_ifp(r=r)
768-
R, β, γ, Π, z_grid, asset_grid = ifp
769-
σ_init = asset_grid[:, None] * jnp.ones(len(z_grid))
799+
R, β, γ, Π, z_grid, savings_grid = ifp
800+
σ_init = savings_grid[:, None] * jnp.ones(len(z_grid))
770801
σ_star = solve_model(ifp, σ_init)
771802
assets = compute_asset_stationary(ifp, σ_star, num_households=10_000, T=500)
772803
mean = np.mean(assets)

0 commit comments

Comments
 (0)