Skip to content

Commit 5819e32

Browse files
committed
Fix bug when calling infidelity before infidelity_derivative
1 parent cde8258 commit 5819e32

File tree

3 files changed

+96
-29
lines changed

3 files changed

+96
-29
lines changed

filter_functions/gradient.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,8 @@ def calculate_derivative_of_control_matrix_from_scratch(
487487
basis_transformed = numeric._transform_by_unitary(eigvecs[:, None], basis[None],
488488
out=np.empty((n_dt, d**2, d, d), complex))
489489
c_opers_transformed = numeric._transform_hamiltonian(eigvecs, c_opers[idx]).swapaxes(0, 1)
490-
if intermediates is None:
490+
if not intermediates:
491+
# None or empty
491492
n_opers_transformed = numeric._transform_hamiltonian(eigvecs, n_opers,
492493
n_coeffs).swapaxes(0, 1)
493494
exp_buf, integral = np.empty((2, n_omega, d, d), dtype=complex)
@@ -507,7 +508,7 @@ def calculate_derivative_of_control_matrix_from_scratch(
507508
n_opers.shape, (len(omega), d, d),
508509
optimize=[(0, 3), (0, 1), (0, 1)])
509510
for g in range(n_dt):
510-
if intermediates is None:
511+
if not intermediates:
511512
integral = numeric._first_order_integral(omega, eigvals[g], dt[g], exp_buf, integral)
512513
else:
513514
integral = intermediates['first_order_integral'][g]

filter_functions/numeric.py

Lines changed: 61 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -475,25 +475,26 @@ def calculate_noise_operators_from_scratch(
475475
t: Optional[Coefficients] = None,
476476
show_progressbar: bool = False,
477477
cache_intermediates: bool = False
478-
) -> ndarray:
478+
) -> Union[ndarray, Tuple[ndarray, Dict[str, ndarray]]]:
479479
r"""
480480
Calculate the noise operators in interaction picture from scratch.
481481
482482
Parameters
483483
----------
484484
eigvals: array_like, shape (n_dt, d)
485-
Eigenvalue vectors for each time pulse segment *g* with the first
486-
axis counting the pulse segment, i.e.
485+
Eigenvalue vectors for each time pulse segment *g* with the
486+
first axis counting the pulse segment, i.e.
487487
``eigvals == array([D_0, D_1, ...])``.
488488
eigvecs: array_like, shape (n_dt, d, d)
489-
Eigenvector matrices for each time pulse segment *g* with the first
490-
axis counting the pulse segment, i.e.
489+
Eigenvector matrices for each time pulse segment *g* with the
490+
first axis counting the pulse segment, i.e.
491491
``eigvecs == array([V_0, V_1, ...])``.
492492
propagators: array_like, shape (n_dt+1, d, d)
493-
The propagators :math:`Q_g = P_g P_{g-1}\cdots P_0` as a (d, d) array
494-
with *d* the dimension of the Hilbert space.
493+
The propagators :math:`Q_g = P_g P_{g-1}\cdots P_0` as a (d, d)
494+
array with *d* the dimension of the Hilbert space.
495495
omega: array_like, shape (n_omega,)
496-
Frequencies at which the pulse control matrix is to be evaluated.
496+
Frequencies at which the pulse control matrix is to be
497+
evaluated.
497498
n_opers: array_like, shape (n_nops, d, d)
498499
Noise operators :math:`B_\alpha`.
499500
n_coeffs: array_like, shape (n_nops, n_dt)
@@ -507,12 +508,19 @@ def calculate_noise_operators_from_scratch(
507508
computed from *dt*.
508509
show_progressbar: bool, optional
509510
Show a progress bar for the calculation.
511+
cache_intermediates: bool, optional
512+
Keep and return intermediate terms of the calculation that can
513+
be reused in other computations (second order or gradients).
514+
Otherwise the sum is performed in-place. The default is False.
510515
511516
Returns
512517
-------
513518
noise_operators: ndarray, shape (n_omega, n_nops, d, d)
514519
The interaction picture noise operators
515520
:math:`\tilde{B}_\alpha(\omega)`.
521+
intermediates: dict[str, ndarray]
522+
Intermediate results of the calculation. Only if
523+
cache_intermediates is True.
516524
517525
Notes
518526
-----
@@ -700,7 +708,7 @@ def calculate_control_matrix_from_scratch(
700708
show_progressbar: bool = False,
701709
cache_intermediates: bool = False,
702710
out: Optional[ndarray] = None
703-
) -> Union[ndarray, Dict[str, ndarray]]:
711+
) -> Union[ndarray, Tuple[ndarray, Dict[str, ndarray]]]:
704712
r"""
705713
Calculate the control matrix from scratch, i.e. without knowledge of
706714
the control matrices of more atomic pulse sequences.
@@ -738,10 +746,9 @@ def calculate_control_matrix_from_scratch(
738746
show_progressbar: bool, optional
739747
Show a progress bar for the calculation.
740748
cache_intermediates: bool, optional
741-
Keep and return intermediate terms
742-
:math:`\mathcal{G}^{(g)}(\omega)` of the sum so that
743-
:math:`\mathcal{B}(\omega)=\sum_g\mathcal{G}^{(g)}(\omega)`.
744-
Otherwise the sum is performed in-place.
749+
Keep and return intermediate terms of the calculation that can
750+
be reused in other computations (second order or gradients).
751+
Otherwise the sum is performed in-place. The default is False.
745752
out: ndarray, optional
746753
A location into which the result is stored. See
747754
:func:`numpy.ufunc`.
@@ -934,7 +941,8 @@ def calculate_cumulant_function(
934941
decay_amplitudes: Optional[ndarray] = None,
935942
frequency_shifts: Optional[ndarray] = None,
936943
show_progressbar: bool = False,
937-
memory_parsimonious: bool = False
944+
memory_parsimonious: bool = False,
945+
cache_intermediates: bool = False
938946
) -> ndarray:
939947
r"""Calculate the cumulant function :math:`\mathcal{K}(\tau)`.
940948
@@ -984,6 +992,11 @@ def calculate_cumulant_function(
984992
Trade memory footprint for performance. See
985993
:func:`~numeric.calculate_decay_amplitudes`. The default is
986994
``False``.
995+
cache_intermediates: bool, optional
996+
Keep and return intermediate terms of the calculation of the
997+
control matrix that can be reused in other computations (second
998+
order or gradients). Otherwise the sum is performed in-place.
999+
The default is False.
9871000
9881001
Returns
9891002
-------
@@ -1063,9 +1076,8 @@ def calculate_cumulant_function(
10631076

10641077
if decay_amplitudes is None:
10651078
decay_amplitudes = calculate_decay_amplitudes(pulse, spectrum, omega, n_oper_identifiers,
1066-
which, show_progressbar,
1067-
cache_intermediates=second_order,
1068-
memory_parsimonious=memory_parsimonious)
1079+
which, show_progressbar, cache_intermediates,
1080+
memory_parsimonious)
10691081

10701082
if second_order:
10711083
if frequency_shifts is None:
@@ -1694,7 +1706,8 @@ def error_transfer_matrix(
16941706
second_order: bool = False,
16951707
cumulant_function: Optional[ndarray] = None,
16961708
show_progressbar: bool = False,
1697-
memory_parsimonious: bool = False
1709+
memory_parsimonious: bool = False,
1710+
cache_intermediates: bool = False
16981711
) -> ndarray:
16991712
r"""Compute the error transfer matrix up to unitary rotations.
17001713
@@ -1735,6 +1748,11 @@ def error_transfer_matrix(
17351748
Trade memory footprint for performance. See
17361749
:func:`~numeric.calculate_decay_amplitudes`. The default is
17371750
``False``.
1751+
cache_intermediates: bool, optional
1752+
Keep and return intermediate terms of the calculation of the
1753+
control matrix (if it is not already cached) that can be reused
1754+
for second order or gradients. Can consume large amount of
1755+
memory, but speed up the calculation.
17381756
17391757
Returns
17401758
-------
@@ -1787,7 +1805,8 @@ def error_transfer_matrix(
17871805
cumulant_function = calculate_cumulant_function(pulse, spectrum, omega,
17881806
n_oper_identifiers, 'total', second_order,
17891807
show_progressbar=show_progressbar,
1790-
memory_parsimonious=memory_parsimonious)
1808+
memory_parsimonious=memory_parsimonious,
1809+
cache_intermediates=cache_intermediates)
17911810

17921811
try:
17931812
# agnostic of the specific shape of cumulant_function, just sum over everything except for
@@ -1804,11 +1823,17 @@ def error_transfer_matrix(
18041823

18051824

18061825
@util.parse_optional_parameters({'which': ('total', 'correlations')})
1807-
def infidelity(pulse: 'PulseSequence', spectrum: Union[Coefficients, Callable],
1808-
omega: Union[Coefficients, Dict[str, Union[int, str]]],
1809-
n_oper_identifiers: Optional[Sequence[str]] = None,
1810-
which: str = 'total', return_smallness: bool = False,
1811-
test_convergence: bool = False) -> Union[ndarray, Any]:
1826+
def infidelity(
1827+
pulse: 'PulseSequence',
1828+
spectrum: Union[Coefficients, Callable],
1829+
omega: Union[Coefficients, Dict[str, Union[int, str]]],
1830+
n_oper_identifiers: Optional[Sequence[str]] = None,
1831+
which: str = 'total',
1832+
show_progressbar: bool = False,
1833+
cache_intermediates: bool = False,
1834+
return_smallness: bool = False,
1835+
test_convergence: bool = False
1836+
) -> Union[ndarray, Any]:
18121837
r"""Calculate the leading order entanglement infidelity.
18131838
18141839
This function calculates the infidelity approximately from the
@@ -1861,6 +1886,12 @@ def infidelity(pulse: 'PulseSequence', spectrum: Union[Coefficients, Callable],
18611886
functions have been computed during concatenation (see
18621887
:func:`calculate_pulse_correlation_filter_function` and
18631888
:func:`~filter_functions.pulse_sequence.concatenate`).
1889+
show_progressbar: bool, optional
1890+
Show a progressbar for the calculation of the control matrix.
1891+
cache_intermediates: bool, optional
1892+
Keep and return intermediate terms of the calculation of the
1893+
control matrix (if it is not already cached) that can be reused
1894+
for second order or gradients. The default is False.
18641895
return_smallness: bool, optional
18651896
Return the smallness parameter :math:`\xi` for the given
18661897
spectrum.
@@ -1998,7 +2029,8 @@ def infidelity(pulse: 'PulseSequence', spectrum: Union[Coefficients, Callable],
19982029
freqs = xspace(omega_IR, omega_UV, n//2)
19992030
convergence_infids[i] = infidelity(pulse, spectrum(freqs), freqs,
20002031
n_oper_identifiers=n_oper_identifiers,
2001-
which='total', return_smallness=False,
2032+
which='total', show_progressbar=show_progressbar,
2033+
cache_intermediates=False, return_smallness=False,
20022034
test_convergence=False)
20032035

20042036
return n_samples, convergence_infids
@@ -2012,11 +2044,13 @@ def infidelity(pulse: 'PulseSequence', spectrum: Union[Coefficients, Callable],
20122044
traces_diag = (sparse.diagonal(traces, axis1=2, axis2=3).sum(-1) -
20132045
sparse.diagonal(traces, axis1=1, axis2=3).sum(-1)).todense()
20142046

2015-
control_matrix = pulse.get_control_matrix(omega)
2047+
control_matrix = pulse.get_control_matrix(omega, show_progressbar, cache_intermediates)
20162048
filter_function = np.einsum('ako,blo,kl->abo',
20172049
control_matrix.conj(), control_matrix, traces_diag)/pulse.d
20182050
else:
2019-
filter_function = pulse.get_filter_function(omega)
2051+
filter_function = pulse.get_filter_function(omega, which='fidelity',
2052+
show_progressbar=show_progressbar,
2053+
cache_intermediates=cache_intermediates)
20202054
else:
20212055
# which == 'correlations'
20222056
if not pulse.basis.istraceless:

tests/test_gradient.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,38 @@ def test_gradient_calculation_random_pulse(self):
7272
)
7373
self.assertArrayAlmostEqual(ana_grad, fin_diff_grad, rtol=1e-6, atol=1e-8)
7474

75+
def test_caching(self):
76+
"""Make sure calculation works with or without cached intermediates."""
77+
78+
for d, n_dt in zip(testutil.rng.integers(2, 5, 5), testutil.rng.integers(2, 8, 5)):
79+
pulse = testutil.rand_pulse_sequence(d, n_dt)
80+
omega = ff.util.get_sample_frequencies(pulse, n_samples=27)
81+
spect = 1/omega
82+
83+
# Cache control matrix but not intermediates
84+
pulse.cache_control_matrix(omega, cache_intermediates=False)
85+
infid_nocache = ff.infidelity(pulse, spect, omega, cache_intermediates=False)
86+
infid_cache = ff.infidelity(pulse, spect, omega, cache_intermediates=True)
87+
88+
self.assertArrayAlmostEqual(infid_nocache, infid_cache)
89+
90+
cm_nocache = ff.gradient.calculate_derivative_of_control_matrix_from_scratch(
91+
omega, pulse.propagators, pulse.eigvals, pulse.eigvecs, pulse.basis, pulse.t,
92+
pulse.dt, pulse.n_opers, pulse.n_coeffs, pulse.c_opers, pulse.c_oper_identifiers,
93+
intermediates=dict()
94+
)
95+
96+
pulse.cleanup('frequency dependent')
97+
pulse.cache_control_matrix(omega, cache_intermediates=True)
98+
cm_cache = ff.gradient.calculate_derivative_of_control_matrix_from_scratch(
99+
omega, pulse.propagators, pulse.eigvals, pulse.eigvecs, pulse.basis, pulse.t,
100+
pulse.dt, pulse.n_opers, pulse.n_coeffs, pulse.c_opers, pulse.c_oper_identifiers,
101+
intermediates=pulse._intermediates
102+
)
103+
104+
self.assertArrayAlmostEqual(cm_nocache, cm_cache)
105+
106+
75107
def test_raises(self):
76108
pulse = testutil.rand_pulse_sequence(2, 3)
77109
omega = ff.util.get_sample_frequencies(pulse, n_samples=13)

0 commit comments

Comments
 (0)