@@ -256,11 +256,13 @@ Here
256256* primes indicate next period states (as well as derivatives), and
257257* $\sigma$ is the unknown function.
258258
259+ (We emphasize that {eq}` eqeul1 ` only holds when we have an interior solution, meaning $\sigma(a, z) < a$.)
260+
259261We aim to find a fixed point $\sigma$ of {eq}` eqeul1 ` .
260262
261263To do so we use the EGM.
262264
263- We begin with an exogenous grid $G = \{ s_0, \ldots, s_ {m-1}\} $ with $s_0 > 0$, where each $s_i$ represents savings.
265+ We begin with a strictly increasing exogenous grid $G = \{ s_0, \ldots, s_ {m-1}\} $ with $s_0 > 0$, where each $s_i$ represents savings.
264266
265267The relationship between current assets $a$, consumption $c$, and savings $s$ is
266268
292294 a^e_{ij} = c_{ij} + s_i.
293295$$
294296
295- Our next guess policy function, which we write as $K\sigma$, is the linear interpolation of
296- $(a^e_ {ij}, c_ {ij})$ over $i$, for each $j$.
297+ To anchor the interpolation at the origin, we add the point $(0, 0)$ to the start of the endogenous grid for each $j$.
297298
298- (The number of one dimensional linear interpolations is equal to ` len(z_grid) ` .)
299+ Our next guess policy function, which we write as $K\sigma$, is then the linear interpolation of
300+ the points $\{ (0, 0), (a^e_ {0j}, c_ {0j}), \ldots, (a^e_ {(m-1)j}, c_ {(m-1)j})\} $ for each $j$.
299301
300- For $a < a^e _ {i0}$ (i.e., below the minimum endogenous grid point), the household consumes everything, so we set $(K \sigma)(a, z_j) = a$.
302+ (The number of one dimensional linear interpolations is equal to ` len(z_grid) ` .)
301303
302304
303305
@@ -324,23 +326,23 @@ class IFP(NamedTuple):
324326 γ: float # Preference parameter
325327 Π: jnp.ndarray # Markov matrix for exogenous shock
326328 z_grid: jnp.ndarray # Markov state values for Z_t
327- asset_grid : jnp.ndarray # Exogenous asset grid
329+ savings_grid : jnp.ndarray # Exogenous savings grid
328330
329331
330332def create_ifp(r=0.01,
331- β=0.98 ,
333+ β=0.96 ,
332334 γ=1.5,
333335 Π=((0.6, 0.4),
334336 (0.05, 0.95)),
335- z_grid=(0 .0, 0.2 ),
336- asset_grid_max=40 ,
337- asset_grid_size =50):
337+ z_grid=(-10 .0, jnp.log(2.0) ),
338+ savings_grid_max=16 ,
339+ savings_grid_size =50):
338340
339- asset_grid = jnp.linspace(0, asset_grid_max, asset_grid_size )
341+ savings_grid = jnp.linspace(0, savings_grid_max, savings_grid_size )
340342 Π, z_grid = jnp.array(Π), jnp.array(z_grid)
341343 R = 1 + r
342344 assert R * β < 1, "Stability condition violated."
343- return IFP(R=R, β=β, γ=γ, Π=Π, z_grid=z_grid, asset_grid=asset_grid )
345+ return IFP(R=R, β=β, γ=γ, Π=Π, z_grid=z_grid, savings_grid=savings_grid )
344346
345347# Set y(z) = exp(z)
346348y = jnp.exp
@@ -384,17 +386,13 @@ def K(σ: jnp.ndarray, ifp: IFP) -> jnp.ndarray:
384386 3. Given σ(a', z'), compute current consumption c that
385387 satisfies Euler equation
386388 4. Compute the endogenous current asset level a^e = c + s
387- 5. Interpolate back to asset grid to get σ_new(a, z)
389+ 5. Interpolate back to savings grid to get σ_new(a, z)
388390
389391 """
390- R, β, γ, Π, z_grid, asset_grid = ifp
391- n_a = len(asset_grid )
392+ R, β, γ, Π, z_grid, savings_grid = ifp
393+ n_a = len(savings_grid )
392394 n_z = len(z_grid)
393395
394- # Create savings grid (exogenous grid for EGM)
395- # We use the asset grid as the savings grid
396- savings_grid = asset_grid
397-
398396 def compute_c_for_fixed_income_state(j):
399397 """
400398 Compute updated consumption policy for income state z_j.
@@ -413,7 +411,7 @@ def K(σ: jnp.ndarray, ifp: IFP) -> jnp.ndarray:
413411 # Interpolate to get consumption at each (a', z')
414412 # For each z', interpolate over the a' values
415413 def interp_for_z(z_idx):
416- return jnp.interp(a_next_grid[:, z_idx], asset_grid , σ[:, z_idx])
414+ return jnp.interp(a_next_grid[:, z_idx], savings_grid , σ[:, z_idx])
417415
418416 c_next_grid = jax.vmap(interp_for_z)(jnp.arange(n_z)) # Shape: (n_z, n_a)
419417 c_next_grid = c_next_grid.T # Shape: (n_a, n_z)
@@ -430,17 +428,14 @@ def K(σ: jnp.ndarray, ifp: IFP) -> jnp.ndarray:
430428 # Compute endogenous grid of current assets: a = c + s
431429 a_endogenous = c_vals + savings_grid
432430
433- # Interpolate back to exogenous asset grid
434- σ_new = jnp.interp(asset_grid, a_endogenous, c_vals)
435-
436- # For asset levels below the minimum endogenous grid point,
437- # the household is constrained and consumes everything: c = a
431+ # Add (0, 0) to anchor the interpolation at the origin
432+ a_endogenous = jnp.concatenate([jnp.array([0.0]), a_endogenous])
433+ c_vals = jnp.concatenate([jnp.array([0.0]), c_vals])
438434
439- σ_new = jnp.where(asset_grid < a_endogenous[0],
440- asset_grid,
441- σ_new)
435+ # Interpolate back to exogenous savings grid
436+ σ_new = jnp.interp(savings_grid, a_endogenous, c_vals)
442437
443- return σ_new # Consumption over the asset grid given z[j]
438+ return σ_new # Consumption over the savings grid given z[j]
444439
445440 # Compute consumption over all income states using vmap
446441 c_vmap = jax.vmap(compute_c_for_fixed_income_state)
@@ -477,23 +472,54 @@ def solve_model(ifp: IFP,
477472 return σ
478473```
479474
475+ Here's a helper function to compute the endogenous grid from the policy:
476+
477+ ``` {code-cell} ipython3
478+ def get_endogenous_grid(σ: jnp.ndarray, ifp: IFP):
479+ """
480+ Compute endogenous asset grid from consumption policy.
481+
482+ For each state j, the endogenous grid is a[i,j] = c[i,j] + s[i].
483+ We also add the point (0, 0) at the start for each state.
484+
485+ Returns:
486+ a_endogenous: array of shape (n_a+1, n_z) with asset values
487+ c_endogenous: array of shape (n_a+1, n_z) with consumption values
488+ """
489+ savings_grid = ifp.savings_grid
490+ n_z = σ.shape[1]
491+
492+ # Compute a = c + s for each state
493+ a_vals = σ + savings_grid[:, None]
494+
495+ # Add (0, 0) at the start for each state
496+ a_endogenous = jnp.vstack([jnp.zeros(n_z), a_vals])
497+ c_endogenous = jnp.vstack([jnp.zeros(n_z), σ])
498+
499+ return a_endogenous, c_endogenous
500+ ```
501+
480502### Test run
481503
482504Let's road test the EGM code.
483505
484506``` {code-cell} ipython3
485507ifp = create_ifp()
486- R, β, γ, Π, z_grid, asset_grid = ifp
487- σ_init = asset_grid [:, None] * jnp.ones(len(z_grid))
508+ R, β, γ, Π, z_grid, savings_grid = ifp
509+ σ_init = savings_grid [:, None] * jnp.ones(len(z_grid))
488510σ_star = solve_model(ifp, σ_init)
489511```
490512
491513Here's a plot of the optimal policy for each $z$ state
492514
493515``` {code-cell} ipython3
494516fig, ax = plt.subplots()
495- ax.plot(asset_grid, σ_star[:, 0], label='bad state')
496- ax.plot(asset_grid, σ_star[:, 1], label='good state')
517+
518+ # Get endogenous grid points
519+ a_endogenous, c_endogenous = get_endogenous_grid(σ_star, ifp)
520+
521+ ax.plot(a_endogenous[:, 0], c_endogenous[:, 0], label='bad state')
522+ ax.plot(a_endogenous[:, 1], c_endogenous[:, 1], label='good state')
497523ax.set(xlabel='assets', ylabel='consumption')
498524ax.legend()
499525plt.show()
@@ -504,10 +530,10 @@ To begin to understand the long run asset levels held by households under the de
504530
505531``` {code-cell} ipython3
506532ifp = create_ifp()
507- R, β, γ, Π, z_grid, asset_grid = ifp
508- σ_init = asset_grid [:, None] * jnp.ones(len(z_grid))
533+ R, β, γ, Π, z_grid, savings_grid = ifp
534+ σ_init = savings_grid [:, None] * jnp.ones(len(z_grid))
509535σ_star = solve_model(ifp, σ_init)
510- a = asset_grid
536+ a = savings_grid
511537
512538fig, ax = plt.subplots()
513539for z, lb in zip((0, 1), ('low income', 'high income')):
@@ -564,14 +590,17 @@ Let's see if we match up:
564590
565591``` {code-cell} ipython3
566592ifp_cake_eating = create_ifp(r=0.0, z_grid=(-jnp.inf, -jnp.inf))
567- R, β, γ, Π, z_grid, asset_grid = ifp_cake_eating
568- σ_init = asset_grid [:, None] * jnp.ones(len(z_grid))
593+ R, β, γ, Π, z_grid, savings_grid = ifp_cake_eating
594+ σ_init = savings_grid [:, None] * jnp.ones(len(z_grid))
569595σ_star = solve_model(ifp_cake_eating, σ_init)
570596
597+ # Get endogenous grid
598+ a_endogenous, c_endogenous = get_endogenous_grid(σ_star, ifp_cake_eating)
599+
571600fig, ax = plt.subplots()
572- ax.plot(asset_grid, σ_star [:, 0], label='numerical')
573- ax.plot(asset_grid ,
574- c_star(asset_grid , ifp_cake_eating.β, ifp_cake_eating.γ),
601+ ax.plot(a_endogenous[:, 0], c_endogenous [:, 0], label='numerical')
602+ ax.plot(a_endogenous[:, 0] ,
603+ c_star(a_endogenous[:, 0] , ifp_cake_eating.β, ifp_cake_eating.γ),
575604 '--', label='analytical')
576605ax.set(xlabel='assets', ylabel='consumption')
577606ax.legend()
@@ -606,16 +635,18 @@ suppress consumption (because they encourage more savings).
606635Here's one solution:
607636
608637``` {code-cell} ipython3
609- # With β=0.98 , we need R*β < 1, so r < 0.0204
610- r_vals = np.linspace(0, 0.016 , 4)
638+ # With β=0.96 , we need R*β < 1, so r < 0.0416
639+ r_vals = np.linspace(0, 0.04 , 4)
611640
612641fig, ax = plt.subplots()
613642for r_val in r_vals:
614643 ifp = create_ifp(r=r_val)
615- R, β, γ, Π, z_grid, asset_grid = ifp
616- σ_init = asset_grid [:, None] * jnp.ones(len(z_grid))
644+ R, β, γ, Π, z_grid, savings_grid = ifp
645+ σ_init = savings_grid [:, None] * jnp.ones(len(z_grid))
617646 σ_star = solve_model(ifp, σ_init)
618- ax.plot(asset_grid, σ_star[:, 0], label=f'$r = {r_val:.3f}$')
647+ # Get endogenous grid
648+ a_endogenous, c_endogenous = get_endogenous_grid(σ_star, ifp)
649+ ax.plot(a_endogenous[:, 0], c_endogenous[:, 0], label=f'$r = {r_val:.3f}$')
619650
620651ax.set(xlabel='asset level', ylabel='consumption (low income)')
621652ax.legend()
@@ -658,19 +689,19 @@ def compute_asset_stationary(ifp, σ_star, num_households=50_000, T=500, seed=12
658689 ifp is an instance of IFP
659690 σ_star is the optimal consumption policy
660691 """
661- R, β, γ, Π, z_grid, asset_grid = ifp
692+ R, β, γ, Π, z_grid, savings_grid = ifp
662693 n_z = len(z_grid)
663694
664695 # Create interpolation function for consumption policy
665- σ_interp = lambda a, z_idx: jnp.interp(a, asset_grid , σ_star[:, z_idx])
696+ σ_interp = lambda a, z_idx: jnp.interp(a, savings_grid , σ_star[:, z_idx])
666697
667698 # Simulate one household forward
668699 def simulate_one_household(key):
669700 # Random initial state (both z and a)
670701 key1, key2, key3 = jax.random.split(key, 3)
671702 z_idx = jax.random.choice(key1, n_z)
672- # Start with random assets drawn uniformly from [0, asset_grid_max /2]
673- a = jax.random.uniform(key3, minval=0.0, maxval=asset_grid [-1]/2)
703+ # Start with random assets drawn uniformly from [0, savings_grid_max /2]
704+ a = jax.random.uniform(key3, minval=0.0, maxval=savings_grid [-1]/2)
674705
675706 # Simulate forward T periods
676707 def step(state, key_t):
@@ -700,8 +731,8 @@ Now we call the function, generate the asset distribution and histogram it:
700731
701732``` {code-cell} ipython3
702733ifp = create_ifp()
703- R, β, γ, Π, z_grid, asset_grid = ifp
704- σ_init = asset_grid [:, None] * jnp.ones(len(z_grid))
734+ R, β, γ, Π, z_grid, savings_grid = ifp
735+ σ_init = savings_grid [:, None] * jnp.ones(len(z_grid))
705736σ_star = solve_model(ifp, σ_init)
706737assets = compute_asset_stationary(ifp, σ_star)
707738
@@ -765,8 +796,8 @@ asset_mean = []
765796for r in r_vals:
766797 print(f'Solving model at r = {r}')
767798 ifp = create_ifp(r=r)
768- R, β, γ, Π, z_grid, asset_grid = ifp
769- σ_init = asset_grid [:, None] * jnp.ones(len(z_grid))
799+ R, β, γ, Π, z_grid, savings_grid = ifp
800+ σ_init = savings_grid [:, None] * jnp.ones(len(z_grid))
770801 σ_star = solve_model(ifp, σ_init)
771802 assets = compute_asset_stationary(ifp, σ_star, num_households=10_000, T=500)
772803 mean = np.mean(assets)
0 commit comments