Skip to content

Commit 0676468

Browse files
committed
Continuing migration of updated pygeosx tools
1 parent 19033e1 commit 0676468

File tree

8 files changed

+693
-528
lines changed

8 files changed

+693
-528
lines changed

pygeosx_tools_package/pygeosx_tools/geophysics/fiber.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88

99
class Fiber():
10-
def __init__(self):
10+
11+
def __init__( self ):
1112
self.time = []
1213
self.channel_position = []
1314
self.gage_length = -1.0
@@ -19,9 +20,9 @@ def __init__(self):
1920

2021

2122
class FiberAnalysis():
22-
def __init__(self):
23+
24+
def __init__( self ):
2325
"""
2426
InSAR Analysis class
2527
"""
2628
self.set_names = []
27-

pygeosx_tools_package/pygeosx_tools/geophysics/insar.py

Lines changed: 179 additions & 186 deletions
Large diffs are not rendered by default.

pygeosx_tools_package/pygeosx_tools/geophysics/microseismic.py

Lines changed: 246 additions & 262 deletions
Large diffs are not rendered by default.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from mpi4py import MPI
2+
import numpy as np
3+
4+
# Get the MPI rank
5+
comm = MPI.COMM_WORLD
6+
rank = comm.Get_rank()
7+
8+
9+
def get_global_array_range( local_values ):
10+
# 1D arrays will return a scalar, ND arrays an array
11+
N = np.shape( local_values )
12+
local_min = 1e100
13+
local_max = -1e100
14+
if ( len( N ) > 1 ):
15+
local_min = np.zeros( N[ 1 ] ) + 1e100
16+
local_max = np.zeros( N[ 1 ] ) - 1e100
17+
18+
# For >1D arrays, keep the last dimension
19+
query_axis = 0
20+
if ( len( N ) > 2 ):
21+
query_axis = tuple( [ ii for ii in range( 0, len( N ) - 1 ) ] )
22+
23+
# Ignore zero-length results
24+
if len( local_values ):
25+
local_min = np.amin( local_values, axis=query_axis )
26+
local_max = np.amax( local_values, axis=query_axis )
27+
28+
# Gather the results onto rank 0
29+
all_min = comm.gather( local_min, root=0 )
30+
all_max = comm.gather( local_max, root=0 )
31+
global_min = 1e100
32+
global_max = -1e100
33+
if ( rank == 0 ):
34+
global_min = np.amin( np.array( all_min ), axis=0 )
35+
global_max = np.amax( np.array( all_max ), axis=0 )
36+
37+
# Broadcast
38+
global_min = comm.bcast( global_min, root=0 )
39+
global_max = comm.bcast( global_max, root=0 )
40+
41+
return global_min, global_max
42+
43+
44+
def gather_array( local_values, allgather=False, concatenate=True ):
45+
# Find buffer size
46+
N = np.shape( local_values )
47+
M = np.prod( N )
48+
all_M = []
49+
max_M = 0
50+
if allgather:
51+
all_M = comm.allgather( M )
52+
max_M = np.amax( all_M )
53+
else:
54+
all_M = comm.gather( M, root=0 )
55+
if ( rank == 0 ):
56+
max_M = np.amax( all_M )
57+
max_M = comm.bcast( max_M, root=0 )
58+
59+
# Pack the array into a buffer
60+
send_buff = np.zeros( max_M )
61+
send_buff[ :M ] = np.reshape( local_values, ( -1 ) )
62+
receive_buff = np.zeros( ( comm.size, max_M ) )
63+
64+
# Gather the buffers
65+
if allgather:
66+
comm.Allgather( [ send_buff, MPI.DOUBLE ], [ receive_buff, MPI.DOUBLE ] )
67+
else:
68+
comm.Gather( [ send_buff, MPI.DOUBLE ], [ receive_buff, MPI.DOUBLE ], root=0 )
69+
70+
# Unpack the buffers
71+
all_values = []
72+
R = list( N )
73+
R[ 0 ] = -1
74+
if ( ( rank == 0 ) | allgather ):
75+
# Reshape each rank's contribution
76+
for ii in range( comm.size ):
77+
if ( all_M[ ii ] > 0 ):
78+
tmp = np.reshape( receive_buff[ ii, :all_M[ ii ] ], R )
79+
all_values.append( tmp )
80+
81+
# Concatenate into a single array
82+
if concatenate:
83+
if ( len( all_values ) ):
84+
all_values = np.concatenate( all_values, axis=0 )
85+
86+
return all_values
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
def patch_numpy_where():
2+
"""
3+
Some packages (numpy, etc.) use np.where in a way that isn't
4+
compliant with the API. In most cases this is fine, but when the
5+
geosx environment is initialized this can cause critical errors with
6+
nan inputs. This patch checks and updates the inputs to where
7+
"""
8+
import numpy as np
9+
np.where_original_fn = np.where
10+
print( 'patching np.where', flush=True )
11+
12+
def variable_as_array_type( x, target_shape ):
13+
if x is None:
14+
return
15+
16+
if isinstance( x, ( np.ndarray, list ) ):
17+
return x
18+
19+
y = np.empty( target_shape )
20+
y[...] = x
21+
return y
22+
23+
def flexible_where( condition, x=None, y=None ):
24+
s = np.shape( condition )
25+
x = variable_as_array_type( x, s )
26+
y = variable_as_array_type( y, s )
27+
return np.where_original_fn( condition, x, y )
28+
29+
np.where = flexible_where
30+
31+
32+
patch_numpy_where()
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import os
2+
import warnings
3+
import matplotlib.pyplot as plt
4+
from matplotlib import cm
5+
import matplotlib.colors as mcolors
6+
7+
8+
class HighResPlot():
9+
10+
def __init__( self ):
11+
self.handle = plt.figure()
12+
self.isAdjusted = False
13+
self.colorspace = 'RGB'
14+
self.compress = False
15+
self.output_format = 'png'
16+
self.font_weight = 'normal'
17+
self.axis_font_size = 8
18+
self.legend_font_size = 8
19+
self.size = ( 3.5, 2.625 )
20+
self.dpi = 200
21+
self.apply_tight_layout = True
22+
self.disable_antialiasing = False
23+
self.use_imagemagick = False
24+
self.apply_font_size()
25+
26+
def __del__( self ):
27+
try:
28+
self.handle.clf()
29+
plt.close( self.handle.number )
30+
except:
31+
pass
32+
33+
def set_active( self ):
34+
plt.figure( self.handle.number )
35+
36+
def reset( self ):
37+
tmp = self.handle.number
38+
self.set_active()
39+
self.handle.clf()
40+
del self.handle
41+
self.handle = plt.figure( tmp )
42+
self.apply_font_size()
43+
self.isAdjusted = False
44+
45+
def apply_font_size( self ):
46+
self.font = { 'weight': self.font_weight, 'size': self.axis_font_size }
47+
plt.rc( 'font', **self.font )
48+
49+
def apply_format( self ):
50+
self.set_active()
51+
self.handle.set_size_inches( self.size )
52+
self.apply_font_size()
53+
54+
for ax in self.handle.get_axes():
55+
tmp = ax.get_legend()
56+
if tmp:
57+
ax.legend( loc=tmp._loc, fontsize=self.legend_font_size )
58+
59+
if self.apply_tight_layout:
60+
with warnings.catch_warnings():
61+
warnings.simplefilter( 'ignore' )
62+
self.handle.tight_layout( pad=1.5 )
63+
self.apply_tight_layout = False
64+
65+
def save( self, fname ):
66+
self.apply_format()
67+
if plt.isinteractive():
68+
plt.draw()
69+
plt.pause( 0.1 )
70+
71+
if self.use_imagemagick:
72+
# Matplotlib's non-pdf layout is sometimes odd,
73+
# so use the imagemagick convert as a back-up option
74+
plt.savefig( '%s.pdf' % ( fname ), format='pdf', dpi=self.dpi )
75+
76+
if ( self.output_format != 'pdf' ):
77+
cmd = 'convert -density %i %s.pdf' % ( self.dpi, fname )
78+
79+
if ( self.colorspace == 'CMYK' ):
80+
cmd += ' -colorspace CMYK'
81+
if ( self.colorspace == 'gray' ):
82+
cmd += ' -colorspace gray'
83+
if ( self.compress ):
84+
cmd += ' -compress lzw'
85+
if ( self.disable_antialiasing ):
86+
cmd += ' +antialias'
87+
88+
cmd += ' %s.%s' % ( fname, self.output_format )
89+
os.system( cmd )
90+
os.system( 'rm %s.pdf' % fname )
91+
else:
92+
plt.savefig( '%s.%s' % ( fname, self.output_format ), format=self.output_format, dpi=self.dpi )
93+
94+
95+
def get_modified_jet_colormap():
96+
# Add an updated colormap
97+
cdict = cm.get_cmap( 'jet' ).__dict__[ '_segmentdata' ].copy()
98+
for k in cdict.keys():
99+
tmp_seq = list( cdict[ k ] )
100+
tmp_final = list( tmp_seq[ 0 ] )
101+
tmp_final[ 1 ] = 1.0
102+
tmp_final[ 2 ] = 1.0
103+
tmp_seq[ 0 ] = tuple( tmp_final )
104+
cdict[ k ] = tuple( tmp_seq )
105+
return mcolors.LinearSegmentedColormap( 'jet_mod', cdict )
106+
107+
108+
def get_periodic_line_style( ii ):
109+
col = [ 'k', 'b', 'c', 'g', 'y', 'r', 'm' ]
110+
mark = [ '-', '--', ':', '-.', 'o', '*', '^', '+' ]
111+
112+
jj = ii % ( len( col ) * len( mark ) )
113+
style = col[ jj % len( col ) ] + mark[ int( jj / len( col ) ) ]
114+
115+
return style

pygeosx_tools_package/pygeosx_tools/wrapper.py

Lines changed: 8 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
import sys
22
import numpy as np
3-
from mpi4py import MPI
43
import matplotlib.pyplot as plt
54
import pylvarray
65
import pygeosx
7-
8-
# Get the MPI rank
9-
comm = MPI.COMM_WORLD
10-
rank = comm.Get_rank()
6+
from pygeosx_tools import parallel_io
117

128

139
def get_wrapper( problem, target_key, write_flag=False ):
@@ -52,7 +48,7 @@ def get_wrapper_par( problem, target_key, allgather=False, ghost_key='' ):
5248
Returns:
5349
np.ndarray: The wrapper as a numpy ndarray
5450
"""
55-
if ( comm.size == 1 ):
51+
if ( parallel_io.comm.size == 1 ):
5652
# This is a serial problem
5753
return get_wrapper( problem, target_key )
5854

@@ -66,45 +62,7 @@ def get_wrapper_par( problem, target_key, allgather=False, ghost_key='' ):
6662
ghost_values = get_wrapper( problem, ghost_key )
6763
local_values = local_values[ ghost_values < -0.5 ]
6864

69-
# Find buffer size
70-
N = np.shape( local_values )
71-
M = np.prod( N )
72-
all_M = []
73-
max_M = 0
74-
if allgather:
75-
all_M = comm.allgather( M )
76-
max_M = np.amax( all_M )
77-
else:
78-
all_M = comm.gather( M, root=0 )
79-
if ( rank == 0 ):
80-
max_M = np.amax( all_M )
81-
max_M = comm.bcast( max_M, root=0 )
82-
83-
# Pack the array into a buffer
84-
send_buff = np.zeros( max_M )
85-
send_buff[ :M ] = np.reshape( local_values, ( -1 ) )
86-
receive_buff = np.zeros( ( comm.size, max_M ) )
87-
88-
# Gather the buffers
89-
if allgather:
90-
comm.Allgather( [ send_buff, MPI.DOUBLE ], [ receive_buff, MPI.DOUBLE ] )
91-
else:
92-
comm.Gather( [ send_buff, MPI.DOUBLE ], [ receive_buff, MPI.DOUBLE ], root=0 )
93-
94-
# Unpack the buffers
95-
all_values = []
96-
R = list( N )
97-
R[ 0 ] = -1
98-
if ( ( rank == 0 ) | allgather ):
99-
# Reshape each rank's contribution
100-
for ii in range( comm.size ):
101-
if ( all_M[ ii ] > 0 ):
102-
tmp = np.reshape( receive_buff[ ii, :all_M[ ii ] ], R )
103-
all_values.append( tmp )
104-
105-
# Concatenate into a single array
106-
all_values = np.concatenate( all_values, axis=0 )
107-
return all_values
65+
return parallel_io.gather_array( local_values, allgather=allgather )
10866

10967

11068
def gather_wrapper( problem, key, ghost_key='' ):
@@ -147,34 +105,7 @@ def get_global_value_range( problem, key ):
147105
tuple: The global min/max of the target
148106
"""
149107
local_values = get_wrapper( problem, key )
150-
151-
# 1D arrays will return a scalar, ND arrays an array
152-
N = np.shape( local_values )
153-
local_min = 1e100
154-
local_max = -1e100
155-
if ( len( N ) > 1 ):
156-
local_min = np.zeros( N[ 1 ] ) + 1e100
157-
local_max = np.zeros( N[ 1 ] ) - 1e100
158-
159-
# For >1D arrays, keep the last dimension
160-
query_axis = 0
161-
if ( len( N ) > 2 ):
162-
query_axis = tuple( [ ii for ii in range( 0, len( N ) - 1 ) ] )
163-
164-
# Ignore zero-length results
165-
if len( local_values ):
166-
local_min = np.amin( local_values, axis=query_axis )
167-
local_max = np.amax( local_values, axis=query_axis )
168-
169-
# Gather the results onto rank 0
170-
all_min = comm.gather( local_min, root=0 )
171-
all_max = comm.gather( local_max, root=0 )
172-
global_min = 1e100
173-
global_max = -1e100
174-
if ( rank == 0 ):
175-
global_min = np.amin( np.array( all_min ), axis=0 )
176-
global_max = np.amax( np.array( all_max ), axis=0 )
177-
return global_min, global_max
108+
return parallel_io.get_global_array_range( local_values )
178109

179110

180111
def print_global_value_range( problem, key, header, scale=1.0, precision='%1.4f' ):
@@ -195,7 +126,7 @@ def print_global_value_range( problem, key, header, scale=1.0, precision='%1.4f'
195126
global_min *= scale
196127
global_max *= scale
197128

198-
if ( rank == 0 ):
129+
if ( parallel_io.rank == 0 ):
199130
if isinstance( global_min, np.ndarray ):
200131
min_str = ', '.join( [ precision % ( x ) for x in global_min ] )
201132
max_str = ', '.join( [ precision % ( x ) for x in global_max ] )
@@ -313,12 +244,12 @@ def get_matching_wrapper_path( problem, filters ):
313244
search_datastructure_wrappers_recursive( problem, filters, matching_paths )
314245

315246
if ( len( matching_paths ) == 1 ):
316-
if ( rank == 0 ):
247+
if ( parallel_io.rank == 0 ):
317248
print( 'Found matching wrapper: %s' % ( matching_paths[ 0 ] ) )
318249
return matching_paths[ 0 ]
319250

320251
else:
321-
if ( rank == 0 ):
252+
if ( parallel_io.rank == 0 ):
322253
print( 'Error occured while looking for wrappers:' )
323254
print( 'Filters: [%s]' % ( ', '.join( filters ) ) )
324255
print( 'Matching wrappers: [%s]' % ( ', '.join( matching_paths ) ) )
@@ -360,7 +291,7 @@ def plot_history( records, output_root='.', save_figures=True, show_figures=True
360291
save_figures (bool): Flag to indicate whether figures should be saved (default = True)
361292
show_figures (bool): Flag to indicate whether figures should be drawn (default = False)
362293
"""
363-
if ( rank == 0 ):
294+
if ( parallel_io.rank == 0 ):
364295
for k in records.keys():
365296
if ( k != 'time' ):
366297
# Set the active figure

0 commit comments

Comments
 (0)