Skip to content

Commit d258389

Browse files
authored
sparse state prep: allow user to pick target bitsize if needed (#1430)
* sparse state prep: allow user to pick target bitsize if needed * add test, add custom bitsize support in permute bloqs * notebooks * add more simulation tests * assert target bitsize is large enough
1 parent 542ffc5 commit d258389

File tree

5 files changed

+55
-15
lines changed

5 files changed

+55
-15
lines changed

qualtran/bloqs/arithmetic/permutation.ipynb

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@
5050
"\n",
5151
"#### Parameters\n",
5252
" - `N`: the total size the permutation acts on.\n",
53-
" - `cycles`: a sequence of permutation cycles that form the permutation. \n",
53+
" - `cycles`: a sequence of permutation cycles that form the permutation.\n",
54+
" - `bitsize`: number of bits to store the indices, defaults to $\\ceil(\\log_2(N))$. \n",
5455
"\n",
5556
"#### Registers\n",
5657
" - `x`: integer register storing a value in [0, ..., N - 1] \n",
@@ -235,7 +236,8 @@
235236
"\n",
236237
"#### Parameters\n",
237238
" - `N`: the total size the permutation acts on.\n",
238-
" - `cycle`: the permutation cycle to apply. \n",
239+
" - `cycle`: the permutation cycle to apply.\n",
240+
" - `bitsize`: number of bits to store the indices, defaults to $\\ceil(\\log_2(N))$. \n",
239241
"\n",
240242
"#### Registers\n",
241243
" - `x`: integer register storing a value in [0, ..., N - 1] \n",

qualtran/bloqs/arithmetic/permutation.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ class PermutationCycle(Bloq):
8484
Args:
8585
N: the total size the permutation acts on.
8686
cycle: the permutation cycle to apply.
87+
bitsize: number of bits to store the indices, defaults to $\ceil(\log_2(N))$.
8788
8889
Registers:
8990
x: integer register storing a value in [0, ..., N - 1]
@@ -95,13 +96,14 @@ class PermutationCycle(Bloq):
9596

9697
N: SymbolicInt
9798
cycle: Union[tuple[int, ...], Shaped] = field(converter=_convert_cycle)
99+
bitsize: SymbolicInt = field()
98100

99101
@cached_property
100102
def signature(self) -> Signature:
101103
return Signature.build_from_dtypes(x=BQUInt(self.bitsize, self.N))
102104

103-
@cached_property
104-
def bitsize(self):
105+
@bitsize.default
106+
def _default_bitsize(self):
105107
return bit_length(self.N - 1)
106108

107109
def build_composite_bloq(self, bb: 'BloqBuilder', x: 'SoquetT') -> dict[str, 'SoquetT']:
@@ -194,6 +196,7 @@ class Permutation(Bloq):
194196
Args:
195197
N: the total size the permutation acts on.
196198
cycles: a sequence of permutation cycles that form the permutation.
199+
bitsize: number of bits to store the indices, defaults to $\ceil(\log_2(N))$.
197200
198201
Registers:
199202
x: integer register storing a value in [0, ..., N - 1]
@@ -205,13 +208,14 @@ class Permutation(Bloq):
205208

206209
N: SymbolicInt
207210
cycles: Union[tuple[SymbolicCycleT, ...], Shaped] = field(converter=_convert_cycles)
211+
bitsize: SymbolicInt = field()
208212

209213
@cached_property
210214
def signature(self) -> Signature:
211215
return Signature.build_from_dtypes(x=BQUInt(self.bitsize, self.N))
212216

213-
@cached_property
214-
def bitsize(self):
217+
@bitsize.default
218+
def _default_bitsize(self):
215219
return bit_length(self.N - 1)
216220

217221
def is_symbolic(self):
@@ -265,7 +269,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: 'Soquet') -> dict[str, 'Soq
265269
raise DecomposeTypeError(f"cannot decompose symbolic {self}")
266270

267271
for cycle in self.cycles:
268-
x = bb.add(PermutationCycle(self.N, cycle), x=x)
272+
x = bb.add(PermutationCycle(self.N, cycle, self.bitsize), x=x)
269273

270274
return {'x': x}
271275

@@ -275,7 +279,7 @@ def build_call_graph(
275279
if is_symbolic(self.cycles):
276280
# worst case cost: single cycle of length N
277281
cycle = Shaped((self.N,))
278-
return {PermutationCycle(self.N, cycle): 1}
282+
return {PermutationCycle(self.N, cycle, self.bitsize): 1}
279283

280284
return super().build_call_graph(ssa)
281285

qualtran/bloqs/state_preparation/sparse_state_preparation_via_rotations.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from typing import Sequence, TYPE_CHECKING, Union
1515

16+
import attrs
1617
import numpy as np
1718
import sympy
1819
from attrs import field, frozen
@@ -25,6 +26,7 @@
2526
_to_tuple_or_has_length,
2627
StatePreparationViaRotations,
2728
)
29+
from qualtran.resource_counting.generalizers import ignore_split_join
2830
from qualtran.symbolics import bit_length, HasLength, is_symbolic, slen, SymbolicInt
2931

3032
if TYPE_CHECKING:
@@ -58,21 +60,24 @@ class SparseStatePreparationViaRotations(Bloq):
5860
nonzero_coeffs: Union[tuple[complex, ...], HasLength] = field(converter=_to_tuple_or_has_length)
5961
N: SymbolicInt
6062
phase_bitsize: SymbolicInt
63+
target_bitsize: SymbolicInt = field()
6164

6265
def __attrs_post_init__(self):
6366
n_idx = slen(self.sparse_indices)
6467
n_coeff = slen(self.nonzero_coeffs)
6568
if not is_symbolic(n_idx, n_coeff) and n_idx != n_coeff:
6669
raise ValueError(f"Number of indices {n_idx} must equal number of coeffs {n_coeff}")
70+
if not is_symbolic(self.target_bitsize, self.N):
71+
assert 2**self.target_bitsize >= self.N
6772

6873
@property
6974
def signature(self) -> Signature:
7075
return Signature.build_from_dtypes(
7176
target_state=QUInt(self.target_bitsize), phase_gradient=QAny(self.phase_bitsize)
7277
)
7378

74-
@property
75-
def target_bitsize(self) -> SymbolicInt:
79+
@target_bitsize.default
80+
def _default_target_bitsize(self) -> SymbolicInt:
7681
return bit_length(self.N - 1)
7782

7883
@property
@@ -160,6 +165,7 @@ def _dense_stateprep_bloq(self) -> StatePreparationViaRotations:
160165
dense_coeffs_padded = np.pad(
161166
list(self.nonzero_coeffs), (0, 2**self.dense_bitsize - len(self.nonzero_coeffs))
162167
)
168+
dense_coeffs_padded = dense_coeffs_padded / np.linalg.norm(dense_coeffs_padded)
163169
return StatePreparationViaRotations(tuple(dense_coeffs_padded.tolist()), self.phase_bitsize)
164170

165171
@property
@@ -170,9 +176,10 @@ def _basis_permutation_bloq(self) -> Permutation:
170176

171177
assert isinstance(self.sparse_indices, tuple)
172178

173-
return Permutation.from_partial_permutation_map(
179+
permute_bloq = Permutation.from_partial_permutation_map(
174180
self.N, dict(enumerate(self.sparse_indices))
175181
)
182+
return attrs.evolve(permute_bloq, bitsize=self.target_bitsize)
176183

177184
def build_composite_bloq(
178185
self, bb: 'BloqBuilder', target_state: 'SoquetT', phase_gradient: 'SoquetT'
@@ -198,10 +205,24 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
198205
return {self._dense_stateprep_bloq: 1, self._basis_permutation_bloq: 1}
199206

200207

201-
@bloq_example
208+
@bloq_example(generalizer=ignore_split_join)
202209
def _sparse_state_prep_via_rotations() -> SparseStatePreparationViaRotations:
203210
sparse_state_prep_via_rotations = SparseStatePreparationViaRotations.from_sparse_array(
204211
[0.70914953, 0, 0, 0, 0.46943701, 0, 0.2297245, 0, 0, 0.32960471, 0, 0, 0.33959273, 0, 0],
205212
phase_bitsize=2,
206213
)
207214
return sparse_state_prep_via_rotations
215+
216+
217+
@bloq_example(generalizer=ignore_split_join)
218+
def _sparse_state_prep_via_rotations_with_large_target_bitsize() -> (
219+
SparseStatePreparationViaRotations
220+
):
221+
sparse_state_prep_via_rotations = SparseStatePreparationViaRotations.from_sparse_array(
222+
[0.70914953, 0, 0, 0, 0.46943701, 0, 0.2297245, 0, 0, 0.32960471, 0, 0, 0.33959273, 0, 0],
223+
phase_bitsize=2,
224+
)
225+
sparse_state_prep_via_rotations_with_large_target_bitsize = attrs.evolve(
226+
sparse_state_prep_via_rotations, target_bitsize=6
227+
)
228+
return sparse_state_prep_via_rotations_with_large_target_bitsize

qualtran/bloqs/state_preparation/sparse_state_preparation_via_rotations_test.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from typing import Optional
15+
16+
import attrs
1417
import numpy as np
1518
import pytest
1619
from numpy.typing import NDArray
@@ -20,12 +23,17 @@
2023
from qualtran.bloqs.rotations import PhaseGradientState
2124
from qualtran.bloqs.state_preparation.sparse_state_preparation_via_rotations import (
2225
_sparse_state_prep_via_rotations,
26+
_sparse_state_prep_via_rotations_with_large_target_bitsize,
2327
SparseStatePreparationViaRotations,
2428
)
2529

2630

27-
def test_examples(bloq_autotester):
28-
bloq_autotester(_sparse_state_prep_via_rotations)
31+
@pytest.mark.parametrize(
32+
"bloq_ex",
33+
[_sparse_state_prep_via_rotations, _sparse_state_prep_via_rotations_with_large_target_bitsize],
34+
)
35+
def test_examples(bloq_autotester, bloq_ex):
36+
bloq_autotester(bloq_ex)
2937

3038

3139
def get_prepared_state_vector(bloq: SparseStatePreparationViaRotations) -> NDArray[np.complex128]:
@@ -40,7 +48,8 @@ def get_prepared_state_vector(bloq: SparseStatePreparationViaRotations) -> NDArr
4048

4149

4250
@pytest.mark.slow
43-
def test_prepared_state():
51+
@pytest.mark.parametrize("target_bitsize", [None, 4, 6])
52+
def test_prepared_state(target_bitsize: Optional[int]):
4453
expected_state = np.array(
4554
[
4655
(-0.42677669529663675 - 0.1767766952966366j),
@@ -63,6 +72,9 @@ def test_prepared_state():
6372
N = len(expected_state)
6473

6574
bloq = SparseStatePreparationViaRotations.from_sparse_array(expected_state, phase_bitsize=3)
75+
if target_bitsize is not None:
76+
bloq = attrs.evolve(bloq, target_bitsize=target_bitsize)
77+
6678
actual_state = get_prepared_state_vector(bloq)
6779
np.testing.assert_allclose(np.linalg.norm(actual_state), 1)
6880
np.testing.assert_allclose(actual_state[:N], expected_state)

qualtran/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def assert_bloq_example_serializes_for_pytest(bloq_ex: BloqExample):
116116
'state_prep_via_rotation_symb', # cannot serialize HasLength
117117
'state_prep_via_rotation_symb_phasegrad', # cannot serialize Shaped
118118
'sparse_state_prep_via_rotations', # cannot serialize Permutation
119+
'sparse_state_prep_via_rotations_with_large_target_bitsize', # setting an array element with a sequence.
119120
'explicit_matrix_block_encoding', # cannot serialize AutoPartition
120121
'symmetric_banded_matrix_block_encoding', # cannot serialize AutoPartition
121122
'chebyshev_poly_even',

0 commit comments

Comments
 (0)