Skip to content

Commit ae37759

Browse files
ENH 2nd-order coefficient visualization
* Add second-order coefficients visualisation * Format fixes
1 parent 476d77f commit ae37759

File tree

1 file changed

+138
-49
lines changed

1 file changed

+138
-49
lines changed

examples/2d/plot_scattering_disk.py

+138-49
Original file line numberDiff line numberDiff line change
@@ -1,102 +1,191 @@
11
"""
22
Scattering disk display
33
=======================
4-
This script reproduces concentric circles that encode Scattering coefficient's
5-
energy as described in "Invariant Scattering Convolution Networks" by Bruna and Mallat.
6-
Here, for the sake of simplicity, we only consider first order scattering.
4+
This script reproduces the display of scattering coefficients amplitude within a disk as described in
5+
"Invariant Scattering Convolution Networks" by J. Bruna and S. Mallat (2012) (https://arxiv.org/pdf/1203.1513.pdf).
76
87
Author: https://github.com/Jonas1312
9-
Edited by: Edouard Oyallon
10-
"""
118
9+
Edited by: Edouard Oyallon and anakin-datawalker
10+
"""
1211

1312

1413
import matplotlib as mpl
1514
import matplotlib.cm as cm
1615
import matplotlib.pyplot as plt
16+
from matplotlib import gridspec
1717
import numpy as np
1818
from kymatio import Scattering2D
1919
from PIL import Image
2020
import os
2121

22-
23-
img_name = os.path.join(os.getcwd(),"images/digit.png")
22+
img_name = os.path.join(os.getcwd(), "images/digit.png")
2423

2524
####################################################################
2625
# Scattering computations
2726
#-------------------------------------------------------------------
2827
# First, we read the input digit:
29-
src_img = Image.open(img_name).convert('L').resize((32,32))
28+
src_img = Image.open(img_name).convert('L').resize((32, 32))
3029
src_img = np.array(src_img)
3130
print("img shape: ", src_img.shape)
3231

3332
####################################################################
34-
# We compute a Scattering Transform with L=6 angles and J=3 scales.
35-
# Rotating a wavelet $\psi$ by $\pi$ is equivalent to consider its
36-
# conjugate in fourier: $\hat\psi_{\pi}(\omega)=\hat\psi(r_{-\pi}\omega)^*$.
33+
# We compute a Scattering Transform with $L=6$ angles and $J=3$ scales.
34+
#
35+
# Morlet wavelets $\psi_{\theta}$ are Hermitian, i.e: $\psi_{\theta}^*(u) = \psi_{\theta}(-u) = \psi_{\theta+\pi}(u)$.
36+
#
37+
# As a consequence, the modulus wavelet transform of a real signal $x$ computed with a Morlet wavelet $\psi_{\theta}$
38+
# is invariant by a rotation of $\pi$ of the wavelet. Indeed, since $(x*\psi_{\theta}(u))^* = x*(\psi_{\theta}^*)(u) =
39+
# x*\psi_{\theta+\pi}(u)$, we have $\lvert x*\psi_{\theta}(u)\rvert = \lvert x*\psi_{\theta+\pi}(u)\rvert$.
3740
#
38-
# Combining this and the fact that a real signal has a Hermitian symmetry
39-
# implies that it is usually sufficient to use the angles $\{\frac{\pi l}{L}\}_{l\leq L}$ at computation time.
40-
# For consistency, we will however display $\{\frac{2\pi l}{L}\}_{l\leq 2L}$,
41-
# which implies that our visualization will be redundant and have a symmetry by rotation of $\pi$.
41+
# Scattering coefficients of order $n$:
42+
# $\lvert \lvert \lvert x * \psi_{\theta_1, j_1} \rvert * \psi_{\theta_2, j_2} \rvert \cdots * \psi_{\theta_n, j_n}
43+
# \rvert * \phi_J$ are thus invariant to a rotation of $\pi$ of any wavelet $\psi_{\theta_i, j_i}$. As a consequence,
44+
# Kymatio computes scattering coefficients with $L$ wavelets whose orientation is uniformly sampled in
45+
# an interval of length $\pi$.
4246

4347
L = 6
4448
J = 3
45-
scattering = Scattering2D(J=J, shape=src_img.shape, L=L, max_order=1, frontend='numpy')
49+
scattering = Scattering2D(J=J, shape=src_img.shape, L=L, max_order=2, frontend='numpy')
4650

4751
####################################################################
4852
# We now compute the scattering coefficients:
4953
src_img_tensor = src_img.astype(np.float32) / 255.
5054

51-
scattering_coefficients = scattering(src_img_tensor)
52-
print("coeffs shape: ", scattering_coefficients.shape)
55+
scat_coeffs = scattering(src_img_tensor)
56+
print("coeffs shape: ", scat_coeffs.shape)
5357
# Invert colors
54-
scattering_coefficients = -scattering_coefficients
58+
scat_coeffs= -scat_coeffs
5559

5660
####################################################################
57-
# We skip the low pass filter...
58-
scattering_coefficients = scattering_coefficients[1:, :, :]
59-
norm = mpl.colors.Normalize(scattering_coefficients.min(), scattering_coefficients.max(), clip=True)
60-
mapper = cm.ScalarMappable(norm=norm, cmap="gray")
61-
nb_coeffs, window_rows, window_columns = scattering_coefficients.shape
61+
# There are 127 scattering coefficients, among which 1 is low-pass, $JL=18$ are of first-order and $L^2(J(J-1)/2)=108$
62+
# are of second-order. Due to the subsampling by $2^J=8$, the final spatial grid is of size $4\times4$.
63+
# We now retrieve first-order and second-order coefficients for the display.
64+
len_order_1 = J*L
65+
scat_coeffs_order_1 = scat_coeffs[1:1+len_order_1, :, :]
66+
norm_order_1 = mpl.colors.Normalize(scat_coeffs_order_1.min(), scat_coeffs_order_1.max(), clip=True)
67+
mapper_order_1 = cm.ScalarMappable(norm=norm_order_1, cmap="gray")
68+
# Mapper of coefficient amplitude to a grayscale color for visualisation.
69+
70+
len_order_2 = (J*(J-1)//2)*(L**2)
71+
scat_coeffs_order_2 = scat_coeffs[1+len_order_1:, :, :]
72+
norm_order_2 = mpl.colors.Normalize(scat_coeffs_order_2.min(), scat_coeffs_order_2.max(), clip=True)
73+
mapper_order_2 = cm.ScalarMappable(norm=norm_order_2, cmap="gray")
74+
# Mapper of coefficient amplitude to a grayscale color for visualisation.
75+
76+
# Retrieve spatial size
77+
window_rows, window_columns = scat_coeffs.shape[1:]
78+
print("nb of (order 1, order 2) coefficients: ", (len_order_1, len_order_2))
6279

6380
####################################################################
6481
# Figure reproduction
6582
#-------------------------------------------------------------------
6683

6784
####################################################################
68-
# Now we can reproduce a figure that displays the energy of the first
69-
# order Scattering coefficient, which are given by $\{\mid x\star\psi_{j,\theta}\mid\star\phi_J|\}_{j,\theta}$ .
70-
# Here, each scattering coefficient is represented on the polar plane. The polar radius and angle correspond
71-
# respectively to the scale $j$ and the rotation $\theta$ applied to the mother wavelet.
85+
# Now we can reproduce a figure that displays the amplitude of first-order and second-order scattering coefficients
86+
# within a disk like in Bruna and Mallat's paper.
87+
#
88+
# For visualisation purposes, we display first-order scattering coefficients
89+
# $\lvert x * \psi_{\theta_1, j_1} \rvert * \phi_J$ with $2L$ angles $\theta_1$ spanning $[0,2\pi]$ using the
90+
# central symmetry of those coefficients explained above. We similarly display second-order scattering coefficients
91+
# $\lvert \lvert x * \psi_{\theta_1, j_1} \rvert * \psi_{\theta_2, j_2} \rvert * \phi_J$ with $2L$ angles
92+
# $\theta_1$ spanning $[0,2\pi]$ but keep only $L$ orientations for $\theta_2$ (and thus an interval of $\pi$),
93+
# so as not to overload the display.
94+
#
95+
# Here, each scattering coefficient is represented on the polar plane within a quadrant indexed by a radius
96+
# and an angle.
97+
#
98+
# For first-order coefficients, the polar radius is inversely proportional to the scale $2^{j_1}$ of the wavelet
99+
# $\psi_{\theta_1, j_1}$ while the angle corresponds to the orientation $\theta_1$. The surface of each quadrant
100+
# is also inversely proportional to the scale $2^{j_1}$, which corresponds to the frequency bandwidth of the Fourier
101+
# transform $\hat{\psi}_{\theta_1, j_1}$. First-order scattering quadrants can thus be indexed by $(\theta_1,j_1)$.
102+
#
103+
# For second-order coefficients, each first-order quadrant is equally divided along the radius axis by the number
104+
# of increasing scales $j_1 < j_2 < J$ and by $L$ along the angle axis to produce a quadrant indexed by
105+
# $(\theta_1,\theta_2, j_1, j_2)$. It simply means in our case where $J=3$ that the first-order quadrant corresponding
106+
# to $j_1=0$ is subdivided along its radius in 2 equal quadrants corresponding to $j_2 \in \{1,2\}$, which are each
107+
# further divided by the $L$ possible $\theta_2$ angles, and that $j_1=1$ quadrants are only divided by $L$,
108+
# corresponding to $j_2=2$ and the $L$ possible $\theta_2$. Note that no second-order coefficients are thus associated
109+
# to $j_1=2$ in this case whose quadrants are just left blank.
72110
#
73-
# Observe that as predicted, the visualization exhibit a redundancy and a symmetry.
111+
# Observe how the amplitude of first-order coefficients is strongest along the direction of edges, and that they
112+
# exhibit by construction a central symmetry.
74113

75-
fig,ax = plt.subplots()
114+
# Define figure size and grid on which to plot input digit image, first-order and second-order scattering coefficients
115+
fig = plt.figure(figsize=(47, 15))
116+
spec = fig.add_gridspec(ncols=3, nrows=1)
117+
118+
gs = gridspec.GridSpec(1, 3, wspace=0.1)
119+
gs_order_1 = gridspec.GridSpecFromSubplotSpec(window_rows, window_columns, subplot_spec=gs[1])
120+
gs_order_2 = gridspec.GridSpecFromSubplotSpec(window_rows, window_columns, subplot_spec=gs[2])
121+
122+
# Start by plotting input digit image and invert colors
123+
ax = plt.subplot(gs[0])
124+
ax.set_xticks([])
125+
ax.set_yticks([])
126+
ax.imshow(255 - src_img, cmap='gray', interpolation='nearest', aspect='auto')
127+
128+
# Plot first-order scattering coefficients
129+
ax = plt.subplot(gs[1])
130+
ax.set_xticks([])
131+
ax.set_yticks([])
132+
133+
l_offset = int(L - L / 2 - 1) # follow same ordering as Kymatio for angles
76134

77-
plt.imshow(1-src_img,cmap='gray',interpolation='nearest', aspect='auto')
78-
ax.axis('off')
79-
offset = 0.1
80135
for row in range(window_rows):
81136
for column in range(window_columns):
82-
ax=fig.add_subplot(window_rows, window_columns, 1 + column + row * window_rows, projection='polar')
83-
ax.set_ylim(0, 1)
137+
ax = fig.add_subplot(gs_order_1[row, column], projection='polar')
84138
ax.axis('off')
85-
ax.set_yticklabels([]) # turn off radial tick labels (yticks)
86-
ax.set_xticklabels([]) # turn off degrees
87-
# ax.set_theta_zero_location('N') # 0° to North
88-
coefficients = scattering_coefficients[:, row, column]
139+
coefficients = scat_coeffs_order_1[:, row, column]
89140
for j in range(J):
90141
for l in range(L):
91-
coeff = coefficients[l + (J - 1 - j) * L]
92-
color = mpl.colors.to_hex(mapper.to_rgba(coeff))
93-
ax.bar(x=(4.5+l) * np.pi / L,
94-
height=2*(2**(j-1) / 2**J),
95-
width=2 * np.pi / L,
96-
bottom=offset + (2**j / 2**J) ,
142+
coeff = coefficients[l + j * L]
143+
color = mapper_order_1.to_rgba(coeff)
144+
angle = (l_offset - l) * np.pi / L
145+
radius = 2 ** (-j - 1)
146+
ax.bar(x=angle,
147+
height=radius,
148+
width=np.pi / L,
149+
bottom=radius,
97150
color=color)
98-
ax.bar(x=(4.5+l+L) * np.pi / L,
99-
height=2*(2**(j-1) / 2**J),
100-
width=2 * np.pi / L,
101-
bottom=offset + (2**j / 2**J) ,
151+
ax.bar(x=angle + np.pi,
152+
height=radius,
153+
width=np.pi / L,
154+
bottom=radius,
102155
color=color)
156+
157+
# Plot second-order scattering coefficients
158+
ax = plt.subplot(gs[2])
159+
ax.set_xticks([])
160+
ax.set_yticks([])
161+
162+
for row in range(window_rows):
163+
for column in range(window_columns):
164+
ax = fig.add_subplot(gs_order_2[row, column], projection='polar')
165+
ax.axis('off')
166+
coefficients = scat_coeffs_order_2[:, row, column]
167+
for j1 in range(J - 1):
168+
for j2 in range(j1 + 1, J):
169+
for l1 in range(L):
170+
for l2 in range(L):
171+
coeff_index = l1 * L * (J - j1 - 1) + l2 + L * (j2 - j1 - 1) + (L ** 2) * \
172+
(j1 * (J - 1) - j1 * (j1 - 1) // 2)
173+
# indexing a bit complex which follows the order used by Kymatio to compute
174+
# scattering coefficients
175+
coeff = coefficients[coeff_index]
176+
color = mapper_order_2.to_rgba(coeff)
177+
# split along angles first-order quadrants in L quadrants, using same ordering
178+
# as Kymatio (clockwise) and center (with the 0.5 offset)
179+
angle = (l_offset - l1) * np.pi / L + (L // 2 - l2 - 0.5) * np.pi / (L ** 2)
180+
radius = 2 ** (-j1 - 1)
181+
# equal split along radius is performed through height variable
182+
ax.bar(x=angle,
183+
height=radius / 2 ** (J - 2 - j1),
184+
width=np.pi / L ** 2,
185+
bottom=radius + (radius / 2 ** (J - 2 - j1)) * (J - j2 - 1),
186+
color=color)
187+
ax.bar(x=angle + np.pi,
188+
height=radius / 2 ** (J - 2 - j1),
189+
width=np.pi / L ** 2,
190+
bottom=radius + (radius / 2 ** (J - 2 - j1)) * (J - j2 - 1),
191+
color=color)

0 commit comments

Comments
 (0)