77import autofit as af
88import numpy as np
99from autolens import PointSolver , Tracer
10-
11- try :
12- from autoarray .structures .triangles .coordinate_array .jax_coordinate_array import (
13- CoordinateArrayTriangles ,
14- )
15-
16- except ImportError :
17- from autoarray .structures .triangles .coordinate_array import CoordinateArrayTriangles
18-
19- from autolens .mock import NullTracer
20-
21- pytest .importorskip ("jax" )
22-
23-
24- @pytest .fixture (autouse = True )
25- def register (tracer ):
26- af .Model .from_instance (tracer )
27-
28-
29- @pytest .fixture
30- def solver (grid ):
31- return PointSolver .for_grid (
32- grid = grid ,
33- pixel_scale_precision = 0.01 ,
34- array_triangles_cls = CoordinateArrayTriangles ,
35- )
36-
37-
38- def test_solver (solver ):
39- mass_profile = ag .mp .Isothermal (
40- centre = (0.0 , 0.0 ),
41- einstein_radius = 1.0 ,
42- )
43- tracer = Tracer (
44- galaxies = [ag .Galaxy (redshift = 0.5 , mass = mass_profile )],
45- )
46- result = solver .solve (
47- tracer ,
48- source_plane_coordinate = (0.0 , 0.0 ),
49- )
50- print (result )
51- assert result
52-
53-
54- @pytest .mark .parametrize (
55- "source_plane_coordinate" ,
56- [
57- (0.0 , 0.0 ),
58- (0.0 , 1.0 ),
59- (1.0 , 0.0 ),
60- (1.0 , 1.0 ),
61- (0.5 , 0.5 ),
62- (0.1 , 0.1 ),
63- (- 1.0 , - 1.0 ),
64- ],
65- )
66- def test_trivial (
67- source_plane_coordinate : Tuple [float , float ],
68- grid ,
69- solver ,
70- ):
71- coordinates = solver .solve (
72- NullTracer (),
73- source_plane_coordinate = source_plane_coordinate ,
74- )
75- coordinates = coordinates .array [~ np .isnan (coordinates .array ).any (axis = 1 )]
76- assert coordinates [0 ] == pytest .approx (source_plane_coordinate , abs = 1.0e-1 )
77-
78-
79- def test_real_example (grid , tracer ):
80- solver = PointSolver .for_grid (
81- grid = grid ,
82- pixel_scale_precision = 0.001 ,
83- array_triangles_cls = CoordinateArrayTriangles ,
84- )
85-
86- result = solver .solve (tracer , (0.07 , 0.07 ))
87- assert len (result ) == 5
88-
89-
90- def _test_jax (grid ):
91- sizes = (5 , 10 , 15 , 20 , 25 , 30 , 35 , 40 , 45 , 50 )
92- run_times = []
93- init_times = []
94-
95- for size in sizes :
96- start = time .time ()
97- solver = PointSolver .for_grid (
98- grid = grid ,
99- pixel_scale_precision = 0.001 ,
100- array_triangles_cls = CoordinateArrayTriangles ,
101- max_containing_size = size ,
102- )
103-
104- solver .solve (NullTracer (), (0.07 , 0.07 ))
105-
106- repeats = 100
107-
108- done_init_time = time .time ()
109- init_time = done_init_time - start
110- for _ in range (repeats ):
111- _ = solver .solve (NullTracer (), (0.07 , 0.07 ))
112-
113- # print(result)
114-
115- init_times .append (init_time )
116-
117- run_time = (time .time () - done_init_time ) / repeats
118- run_times .append (run_time )
119-
120- print (f"Time taken for { size } : { run_time } ({ init_time } to init)" )
121-
122- from matplotlib import pyplot as plt
123-
124- plt .plot (sizes , run_times )
125- plt .show ()
10+ #
11+ # try:
12+ # from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
13+ # CoordinateArrayTriangles,
14+ # )
15+ #
16+ # except ImportError:
17+ # from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
18+ #
19+ # from autolens.mock import NullTracer
20+ #
21+ # pytest.importorskip("jax")
22+ #
23+ #
24+ # @pytest.fixture(autouse=True)
25+ # def register(tracer):
26+ # af.Model.from_instance(tracer)
27+ #
28+ #
29+ # @pytest.fixture
30+ # def solver(grid):
31+ # return PointSolver.for_grid(
32+ # grid=grid,
33+ # pixel_scale_precision=0.01,
34+ # array_triangles_cls=CoordinateArrayTriangles,
35+ # )
36+ #
37+ #
38+ # def test_solver(solver):
39+ # mass_profile = ag.mp.Isothermal(
40+ # centre=(0.0, 0.0),
41+ # einstein_radius=1.0,
42+ # )
43+ # tracer = Tracer(
44+ # galaxies=[ag.Galaxy(redshift=0.5, mass=mass_profile)],
45+ # )
46+ # result = solver.solve(
47+ # tracer,
48+ # source_plane_coordinate=(0.0, 0.0),
49+ # )
50+ # print(result)
51+ # assert result
52+ #
53+ #
54+ # @pytest.mark.parametrize(
55+ # "source_plane_coordinate",
56+ # [
57+ # (0.0, 0.0),
58+ # (0.0, 1.0),
59+ # (1.0, 0.0),
60+ # (1.0, 1.0),
61+ # (0.5, 0.5),
62+ # (0.1, 0.1),
63+ # (-1.0, -1.0),
64+ # ],
65+ # )
66+ # def test_trivial(
67+ # source_plane_coordinate: Tuple[float, float],
68+ # grid,
69+ # solver,
70+ # ):
71+ # coordinates = solver.solve(
72+ # NullTracer(),
73+ # source_plane_coordinate=source_plane_coordinate,
74+ # )
75+ # coordinates = coordinates.array[~np.isnan(coordinates.array).any(axis=1)]
76+ # assert coordinates[0] == pytest.approx(source_plane_coordinate, abs=1.0e-1)
77+ #
78+ #
79+ # def test_real_example(grid, tracer):
80+ # solver = PointSolver.for_grid(
81+ # grid=grid,
82+ # pixel_scale_precision=0.001,
83+ # array_triangles_cls=CoordinateArrayTriangles,
84+ # )
85+ #
86+ # result = solver.solve(tracer, (0.07, 0.07))
87+ # assert len(result) == 5
88+ #
89+ #
90+ # def _test_jax(grid):
91+ # sizes = (5, 10, 15, 20, 25, 30, 35, 40, 45, 50)
92+ # run_times = []
93+ # init_times = []
94+ #
95+ # for size in sizes:
96+ # start = time.time()
97+ # solver = PointSolver.for_grid(
98+ # grid=grid,
99+ # pixel_scale_precision=0.001,
100+ # array_triangles_cls=CoordinateArrayTriangles,
101+ # max_containing_size=size,
102+ # )
103+ #
104+ # solver.solve(NullTracer(), (0.07, 0.07))
105+ #
106+ # repeats = 100
107+ #
108+ # done_init_time = time.time()
109+ # init_time = done_init_time - start
110+ # for _ in range(repeats):
111+ # _ = solver.solve(NullTracer(), (0.07, 0.07))
112+ #
113+ # # print(result)
114+ #
115+ # init_times.append(init_time)
116+ #
117+ # run_time = (time.time() - done_init_time) / repeats
118+ # run_times.append(run_time)
119+ #
120+ # print(f"Time taken for {size}: {run_time} ({init_time} to init)")
121+ #
122+ # from matplotlib import pyplot as plt
123+ #
124+ # plt.plot(sizes, run_times)
125+ # plt.show()
126+ #
127+ #
128+ # def test_real_example_jax(grid, tracer):
129+ # jax_solver = PointSolver.for_grid(
130+ # grid=grid,
131+ # pixel_scale_precision=0.001,
132+ # array_triangles_cls=CoordinateArrayTriangles,
133+ # )
134+ #
135+ # result = jax_solver.solve(
136+ # tracer=tracer,
137+ # source_plane_coordinate=(0.07, 0.07),
138+ # )
139+ #
140+ # assert len(result) == 5
0 commit comments