Skip to content

Commit b3c8707

Browse files
restrict classical action of certain arithmetic bloqs (#1518)
1 parent 1bd4f70 commit b3c8707

File tree

4 files changed

+59
-7
lines changed

4 files changed

+59
-7
lines changed

qualtran/bloqs/arithmetic/subtraction.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import math
1615
from typing import Dict, Optional, Tuple, TYPE_CHECKING, Union
1716

1817
import numpy as np
@@ -40,6 +39,7 @@
4039
from qualtran.bloqs.bookkeeping import Allocate, Cast, Free
4140
from qualtran.bloqs.mcmt.multi_target_cnot import MultiTargetCNOT
4241
from qualtran.drawing import Text
42+
from qualtran.simulation.classical_sim import add_ints
4343

4444
if TYPE_CHECKING:
4545
from qualtran.drawing import WireSymbol
@@ -270,10 +270,15 @@ def signature(self):
270270
def on_classical_vals(
271271
self, a: 'ClassicalValT', b: 'ClassicalValT'
272272
) -> Dict[str, 'ClassicalValT']:
273-
unsigned = isinstance(self.dtype, (QUInt, QMontgomeryUInt))
274-
bitsize = self.dtype.bitsize
275-
N = 2**bitsize if unsigned else 2 ** (bitsize - 1)
276-
return {'a': a, 'b': int(math.fmod(b - a, N))}
273+
return {
274+
'a': a,
275+
'b': add_ints(
276+
int(b),
277+
-int(a),
278+
num_bits=int(self.dtype.bitsize),
279+
is_signed=isinstance(self.dtype, QInt),
280+
),
281+
}
277282

278283
def wire_symbol(
279284
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()

qualtran/bloqs/arithmetic/subtraction_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,12 @@ def test_subtract_from_bloq_decomposition():
160160
want[(a << 4) | c][a_b] = 1
161161
got = gate.tensor_contract()
162162
np.testing.assert_allclose(got, want)
163+
164+
165+
@pytest.mark.parametrize('bitsize', range(2, 5))
166+
def test_subtractfrom_classical_action(bitsize):
167+
dtype = QInt(bitsize)
168+
blq = SubtractFrom(dtype)
169+
qlt_testing.assert_consistent_classical_action(
170+
blq, a=tuple(dtype.get_classical_domain()), b=tuple(dtype.get_classical_domain())
171+
)

qualtran/bloqs/mod_arithmetic/mod_addition.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,17 @@ def signature(self) -> 'Signature':
8787
def on_classical_vals(
8888
self, x: 'ClassicalValT', y: 'ClassicalValT'
8989
) -> Dict[str, 'ClassicalValT']:
90-
return {'x': x, 'y': (x + y) % self.mod}
90+
if not (0 <= x < self.mod):
91+
raise ValueError(
92+
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
93+
)
94+
if not (0 <= y < self.mod):
95+
raise ValueError(
96+
f'{y=} is outside the valid interval for modular addition [0, {self.mod})'
97+
)
98+
99+
y = (x + y) % self.mod
100+
return {'x': x, 'y': y}
91101

92102
def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']:
93103
if is_symbolic(self.bitsize):
@@ -307,6 +317,12 @@ def on_classical_vals(
307317
return {'ctrl': 0, 'x': x}
308318

309319
assert ctrl == 1, 'Bad ctrl value.'
320+
321+
if not (0 <= x < self.mod):
322+
raise ValueError(
323+
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
324+
)
325+
310326
x = (x + self.k) % self.mod
311327
return {'ctrl': ctrl, 'x': x}
312328

@@ -492,7 +508,17 @@ def on_classical_vals(
492508
if ctrl != self.cv:
493509
return {'ctrl': ctrl, 'x': x, 'y': y}
494510

495-
return {'ctrl': ctrl, 'x': x, 'y': (x + y) % self.mod}
511+
if not (0 <= x < self.mod):
512+
raise ValueError(
513+
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
514+
)
515+
if not (0 <= y < self.mod):
516+
raise ValueError(
517+
f'{y=} is outside the valid interval for modular addition [0, {self.mod})'
518+
)
519+
520+
y = (x + y) % self.mod
521+
return {'ctrl': ctrl, 'x': x, 'y': y}
496522

497523
def build_composite_bloq(
498524
self, bb: 'BloqBuilder', ctrl, x: Soquet, y: Soquet

qualtran/bloqs/mod_arithmetic/mod_addition_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,15 @@ def test_cmod_add_complexity_vs_ref():
208208
def test_mod_add_classical_action(bitsize, prime):
209209
b = ModAdd(bitsize, prime)
210210
assert_consistent_classical_action(b, x=range(prime), y=range(prime))
211+
212+
213+
def test_cmodadd_tensor():
214+
blq = CModAddK(bitsize=4, mod=7, k=1)
215+
want = np.zeros((7, 7))
216+
for i in range(7):
217+
j = (i + 1) % 7
218+
want[j, i] = 1
219+
220+
tn = blq.tensor_contract()
221+
np.testing.assert_allclose(tn[:7, :7], np.eye(7)) # ctrl = 0
222+
np.testing.assert_allclose(tn[16 : 16 + 7, 16 : 16 + 7], want) # ctrl = 1

0 commit comments

Comments
 (0)