28
28
from pytools .lex import ParseError
29
29
from pymbolic .mapper import IdentityMapper
30
30
31
+ import logging
32
+ logger = logging .getLogger (__name__ )
33
+
31
34
32
35
# {{{ utilities
33
36
34
37
def assert_parsed_same_as_python (expr_str ):
35
38
# makes sure that has only one line
36
39
expr_str , = expr_str .split ("\n " )
37
- from pymbolic . interop . ast import ASTToPymbolic
40
+
38
41
import ast
42
+ from pymbolic .interop .ast import ASTToPymbolic
39
43
ast2p = ASTToPymbolic ()
44
+
40
45
try :
41
46
expr_parsed_by_python = ast2p (ast .parse (expr_str ).body [0 ].value )
42
47
except SyntaxError :
@@ -48,9 +53,10 @@ def assert_parsed_same_as_python(expr_str):
48
53
49
54
50
55
def assert_parse_roundtrip (expr_str ):
51
- expr = parse (expr_str )
52
56
from pymbolic .mapper .stringifier import StringifyMapper
57
+ expr = parse (expr_str )
53
58
strified = StringifyMapper ()(expr )
59
+
54
60
assert strified == expr_str , (strified , expr_str )
55
61
56
62
# }}}
@@ -123,7 +129,6 @@ def expect_typeerror(f):
123
129
124
130
def test_structure_preservation ():
125
131
x = prim .Sum ((5 , 7 ))
126
- from pymbolic .mapper import IdentityMapper
127
132
x2 = IdentityMapper ()(x )
128
133
assert x == x2
129
134
@@ -200,9 +205,9 @@ def test_fft():
200
205
from pymbolic .algorithm import fft , sym_fft
201
206
202
207
vars = numpy .array ([var (chr (97 + i )) for i in range (16 )], dtype = object )
203
- print ( vars )
208
+ logger . info ( "vars: %s" , vars )
204
209
205
- print ( fft (vars ))
210
+ logger . info ( "fft: %s" , fft (vars ))
206
211
traced_fft = sym_fft (vars )
207
212
208
213
from pymbolic .mapper .stringifier import PREC_NONE
@@ -212,10 +217,10 @@ def test_fft():
212
217
code = [ccm (tfi , PREC_NONE ) for tfi in traced_fft ]
213
218
214
219
for cse_name , cse_str in enumerate (ccm .cse_name_list ):
215
- print ( f" { cse_name } = { cse_str } " )
220
+ logger . info ( "%s = %s" , cse_name , cse_str )
216
221
217
222
for i , line in enumerate (code ):
218
- print ("result[%d] = %s" % ( i , line ) )
223
+ logger . info ("result[%d] = %s" , i , line )
219
224
220
225
# }}}
221
226
@@ -250,25 +255,25 @@ def test_parser():
250
255
parse ("(2*a[1]*b[1]+2*a[0]*b[0])*(hankel_1(-1,sqrt(a[1]**2+a[0]**2)*k) "
251
256
"-hankel_1(1,sqrt(a[1]**2+a[0]**2)*k))*k /(4*sqrt(a[1]**2+a[0]**2)) "
252
257
"+hankel_1(0,sqrt(a[1]**2+a[0]**2)*k)" )
253
- print ( repr ( parse ("d4knl0" ) ))
254
- print ( repr ( parse ("0." ) ))
255
- print ( repr ( parse ("0.e1" ) ))
258
+ logger . info ( "%r" , parse ("d4knl0" ))
259
+ logger . info ( "%r" , parse ("0." ))
260
+ logger . info ( "%r" , parse ("0.e1" ))
256
261
assert parse ("0.e1" ) == 0
257
262
assert parse ("1e-12" ) == 1e-12
258
- print ( repr ( parse ("a >= 1" ) ))
259
- print ( repr ( parse ("a <= 1" ) ))
260
-
261
- print ( repr ( parse (":" ) ))
262
- print ( repr ( parse ("1:" ) ))
263
- print ( repr ( parse (":2" ) ))
264
- print ( repr ( parse ("1:2" ) ))
265
- print ( repr ( parse ("::" ) ))
266
- print ( repr ( parse ("1::" ) ))
267
- print ( repr ( parse (":1:" ) ))
268
- print ( repr ( parse ("::1" ) ))
269
- print ( repr ( parse ("3::1" ) ))
270
- print ( repr ( parse (":5:1" ) ))
271
- print ( repr ( parse ("3:5:1" ) ))
263
+ logger . info ( "%r" , parse ("a >= 1" ))
264
+ logger . info ( "%r" , parse ("a <= 1" ))
265
+
266
+ logger . info ( "%r" , parse (":" ))
267
+ logger . info ( "%r" , parse ("1:" ))
268
+ logger . info ( "%r" , parse (":2" ))
269
+ logger . info ( "%r" , parse ("1:2" ))
270
+ logger . info ( "%r" , parse ("::" ))
271
+ logger . info ( "%r" , parse ("1::" ))
272
+ logger . info ( "%r" , parse (":1:" ))
273
+ logger . info ( "%r" , parse ("::1" ))
274
+ logger . info ( "%r" , parse ("3::1" ))
275
+ logger . info ( "%r" , parse (":5:1" ))
276
+ logger . info ( "%r" , parse ("3:5:1" ))
272
277
273
278
assert_parse_roundtrip ("()" )
274
279
assert_parse_roundtrip ("(3,)" )
@@ -280,17 +285,17 @@ def test_parser():
280
285
assert_parse_roundtrip ("g[i, k] + 2.0*h[i, k]" )
281
286
parse ("g[i,k]+(+2.0)*h[i, k]" )
282
287
283
- print ( repr ( parse ("a - b - c" ) ))
284
- print ( repr ( parse ("-a - -b - -c" ) ))
285
- print ( repr ( parse ("- - - a - - - - b - - - - - c" ) ))
288
+ logger . info ( "%r" , parse ("a - b - c" ))
289
+ logger . info ( "%r" , parse ("-a - -b - -c" ))
290
+ logger . info ( "%r" , parse ("- - - a - - - - b - - - - - c" ))
286
291
287
- print ( repr ( parse ("~(a ^ b)" ) ))
288
- print ( repr ( parse ("(a | b) | ~(~a & ~b)" ) ))
292
+ logger . info ( "%r" , parse ("~(a ^ b)" ))
293
+ logger . info ( "%r" , parse ("(a | b) | ~(~a & ~b)" ))
289
294
290
- print ( repr ( parse ("3 << 1" ) ))
291
- print ( repr ( parse ("1 >> 3" ) ))
295
+ logger . info ( "%r" , parse ("3 << 1" ))
296
+ logger . info ( "%r" , parse ("1 >> 3" ))
292
297
293
- print (parse ("3::1" ))
298
+ logger . info (parse ("3::1" ))
294
299
295
300
assert parse ("e1" ) == prim .Variable ("e1" )
296
301
assert parse ("d1" ) == prim .Variable ("d1" )
@@ -374,7 +379,7 @@ def test_graphviz():
374
379
from pymbolic .mapper .graphviz import GraphvizMapper
375
380
gvm = GraphvizMapper ()
376
381
gvm (expr )
377
- print ( gvm .get_dot_code ())
382
+ logger . info ( "%s" , gvm .get_dot_code ())
378
383
379
384
# }}}
380
385
@@ -495,7 +500,7 @@ def f():
495
500
import ast
496
501
mod = ast .parse (src .replace ("\n " , "\n " ))
497
502
498
- print ( ast .dump (mod ))
503
+ logger . info ( "%s" , ast .dump (mod ))
499
504
500
505
from pymbolic .interop .ast import ASTToPymbolic
501
506
ast2p = ASTToPymbolic ()
@@ -512,7 +517,7 @@ def f():
512
517
lhs = ast2p (lhs )
513
518
rhs = ast2p (stmt .value )
514
519
515
- print ( lhs , rhs )
520
+ logger . info ( "lhs %s rhs %s" , lhs , rhs )
516
521
517
522
# }}}
518
523
0 commit comments