2222
2323print (device )
2424
25+ enable_dlpack = True
26+ # enable_dlpack = False # for old version of ML libs
27+ tf_device = "/GPU:0"
28+ # tf_device "/device:CPU:0"
29+
2530# dataset preparation
2631
2732(x_train , y_train ), (x_test , y_test ) = tf .keras .datasets .mnist .load_data ()
@@ -52,17 +57,18 @@ def filter_pair(x, y, a, b):
5257
5358
5459def qpreds (x , weights ):
55- c = tc .Circuit (n )
56- for i in range (n ):
57- c .rx (i , theta = x [i ])
58- for j in range (nlayers ):
59- for i in range (n - 1 ):
60- c .cnot (i , i + 1 )
60+ with tf .device (tf_device ):
61+ c = tc .Circuit (n )
6162 for i in range (n ):
62- c .rx (i , theta = weights [2 * j , i ])
63- c .ry (i , theta = weights [2 * j + 1 , i ])
63+ c .rx (i , theta = x [i ])
64+ for j in range (nlayers ):
65+ for i in range (n - 1 ):
66+ c .cnot (i , i + 1 )
67+ for i in range (n ):
68+ c .rx (i , theta = weights [2 * j , i ])
69+ c .ry (i , theta = weights [2 * j + 1 , i ])
6470
65- return K .stack ([K .real (c .expectation_ps (z = [i ])) for i in range (n )])
71+ return K .stack ([K .real (c .expectation_ps (z = [i ])) for i in range (n )])
6672
6773
6874# qpreds_vmap = K.vmap(qpreds, vectorized_argnums=0)
@@ -74,9 +80,8 @@ def qpreds(x, weights):
7480 use_vmap = True ,
7581 use_interface = True ,
7682 use_jit = True ,
77- enable_dlpack = True ,
83+ enable_dlpack = enable_dlpack ,
7884)
79- # enable_dlpack = False for old version of ML libs
8085
8186
8287model = torch .nn .Sequential (quantumnet , torch .nn .Linear (9 , 1 ), torch .nn .Sigmoid ())
0 commit comments