@@ -83,6 +83,50 @@ def create_transforms(traced_points, mesh_weight_map=None, xp=np):
8383 t = xp .cumsum (t , axis = 0 )
8484
8585 if xp .__name__ .startswith ("jax" ):
86+ # --------------------------------------------------------------
87+ # Gradient stabilisation for `jnp.interp`.
88+ #
89+ # Ray-traced source grids commonly contain near-duplicate or
90+ # exactly-duplicate coordinates — e.g. an Isothermal lens over
91+ # a circular mask produces ~50% gaps that are exactly zero
92+ # after sorting. This breaks `jnp.interp` for autodiff in two
93+ # distinct ways, both of which have to be patched here:
94+ #
95+ # (1) Knot-gradient term. The vjp of `jnp.interp` w.r.t. its
96+ # knot array `xp` divides by `xp[i+1] - xp[i]`, which is
97+ # 0/0 at duplicate knots and emits O(1e24) cotangents.
98+ # We freeze this path with `stop_gradient`. This is
99+ # semantically correct: the only downstream consumer of
100+ # the transformed grid is `adaptive_rectangular_mappings_
101+ # weights_..._from`, which uses `floor`/`ceil` to select
102+ # the 4 corner pixels. That bin assignment already has
103+ # zero gradient, so the knot-gradient term has no
104+ # downstream consumer anyway.
105+ #
106+ # (2) Query-gradient term. The vjp w.r.t. the query `x` is
107+ # the local slope `(yp[i+1] - yp[i]) / (xp[i+1] - xp[i])`.
108+ # Even with frozen knots, this blows up when the knot gap
109+ # is near zero. We prevent that by adding a strictly-
110+ # monotonic offset `arange(N) * JITTER` to `sort_points`.
111+ # For the default `mesh_weight_map=None` path, `t` moves
112+ # in steps of `1/(N+1)`; with `JITTER = 1e-7` and
113+ # `N ~ 1.5e4`, the worst-case slope is bounded by
114+ # `(1/(N+1)) / JITTER ~ 650`, which is harmless, and the
115+ # forward interpolation value is perturbed by at most
116+ # `N * JITTER ~ 1.5e-3` in the source-plane scaled units
117+ # — well below the `(source_grid_size - 3)` downstream
118+ # multiplier's sensitivity to sub-pixel placement.
119+ #
120+ # Together these two patches make the rectangular interpolator
121+ # differentiable end-to-end and bring the mapping-matrix
122+ # gradient into agreement with finite differences.
123+ import jax
124+
125+ JITTER = 1e-7
126+ jitter = xp .arange (sort_points .shape [0 ], dtype = sort_points .dtype ) * JITTER
127+ jitter = xp .stack ([jitter , jitter ], axis = 1 )
128+ sort_points = jax .lax .stop_gradient (sort_points + jitter )
129+
86130 transform = partial (forward_interp , sort_points , t )
87131 inv_transform = partial (reverse_interp , t , sort_points )
88132 return transform , inv_transform
0 commit comments