Skip to content

Commit fe2e528

Browse files
committed
Initial OpenACC port of the mpas_atm_update_bdy_tend subroutine
This commit enables the GPU execution of the mpas_atm_update_bdy_tend routine using OpenACC directives for data movement and loops. A new timer has been added to time the host-device data transfers in this subroutine, with the label 'mpas_atm_update_bdy_tend [ACC_data_xfer]' This commit also introduces some integers for loop bounds, so as to dereference scalar integer pointers which the OpenACC parallel regions do not correctly copy to device memory.
1 parent fe766bf commit fe2e528

File tree

1 file changed

+125
-40
lines changed

1 file changed

+125
-40
lines changed

src/core_atmosphere/dynamics/mpas_atm_boundaries.F

Lines changed: 125 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,12 @@ subroutine mpas_atm_update_bdy_tend(clock, streamManager, block, firstCall, ierr
9999
type (mpas_pool_type), pointer :: lbc
100100
real (kind=RKIND) :: dt
101101

102-
integer, pointer :: nCells
103-
integer, pointer :: nEdges
104-
integer, pointer :: index_qv
102+
integer, pointer :: nCells_ptr
103+
integer, pointer :: nEdges_ptr
104+
integer, pointer :: nVertLevels_ptr
105+
integer, pointer :: index_qv_ptr
106+
integer, pointer :: nScalars_ptr
107+
integer :: nCells, nEdges, nVertLevels, index_qv, nScalars
105108

106109
real (kind=RKIND), dimension(:,:), pointer :: u
107110
real (kind=RKIND), dimension(:,:), pointer :: ru
@@ -129,7 +132,7 @@ subroutine mpas_atm_update_bdy_tend(clock, streamManager, block, firstCall, ierr
129132
type (MPAS_Time_Type) :: currTime
130133
type (MPAS_TimeInterval_Type) :: lbc_interval
131134
character(len=StrKIND) :: read_time
132-
integer :: iEdge
135+
integer :: iEdge, iCell, k, j
133136
integer :: cell1, cell2
134137

135138

@@ -169,71 +172,142 @@ subroutine mpas_atm_update_bdy_tend(clock, streamManager, block, firstCall, ierr
169172
call mpas_pool_get_array(lbc, 'lbc_u', u, 2)
170173
call mpas_pool_get_array(lbc, 'lbc_ru', ru, 2)
171174
call mpas_pool_get_array(lbc, 'lbc_rho_edge', rho_edge, 2)
175+
call mpas_pool_get_array(lbc, 'lbc_w', w, 2)
172176
call mpas_pool_get_array(lbc, 'lbc_theta', theta, 2)
173177
call mpas_pool_get_array(lbc, 'lbc_rtheta_m', rtheta_m, 2)
174178
call mpas_pool_get_array(lbc, 'lbc_rho_zz', rho_zz, 2)
175179
call mpas_pool_get_array(lbc, 'lbc_rho', rho, 2)
176180
call mpas_pool_get_array(lbc, 'lbc_scalars', scalars, 2)
177181

178182
call mpas_pool_get_array(mesh, 'cellsOnEdge', cellsOnEdge)
179-
call mpas_pool_get_dimension(mesh, 'nCells', nCells)
180-
call mpas_pool_get_dimension(mesh, 'nEdges', nEdges)
181-
call mpas_pool_get_dimension(lbc, 'index_qv', index_qv)
183+
call mpas_pool_get_dimension(mesh, 'nCells', nCells_ptr)
184+
call mpas_pool_get_dimension(mesh, 'nEdges', nEdges_ptr)
185+
call mpas_pool_get_dimension(mesh, 'nVertLevels', nVertLevels_ptr)
186+
call mpas_pool_get_dimension(state, 'num_scalars', nScalars_ptr)
187+
call mpas_pool_get_dimension(lbc, 'index_qv', index_qv_ptr)
182188
call mpas_pool_get_array(mesh, 'zz', zz)
183189

190+
MPAS_ACC_TIMER_START('mpas_atm_update_bdy_tend [ACC_data_xfer]')
191+
if (.not. firstCall) then
192+
call mpas_pool_get_array(lbc, 'lbc_u', lbc_tend_u, 1)
193+
call mpas_pool_get_array(lbc, 'lbc_ru', lbc_tend_ru, 1)
194+
call mpas_pool_get_array(lbc, 'lbc_rho_edge', lbc_tend_rho_edge, 1)
195+
call mpas_pool_get_array(lbc, 'lbc_w', lbc_tend_w, 1)
196+
call mpas_pool_get_array(lbc, 'lbc_theta', lbc_tend_theta, 1)
197+
call mpas_pool_get_array(lbc, 'lbc_rtheta_m', lbc_tend_rtheta_m, 1)
198+
call mpas_pool_get_array(lbc, 'lbc_rho_zz', lbc_tend_rho_zz, 1)
199+
call mpas_pool_get_array(lbc, 'lbc_rho', lbc_tend_rho, 1)
200+
call mpas_pool_get_array(lbc, 'lbc_scalars', lbc_tend_scalars, 1)
201+
202+
!$acc enter data copyin(lbc_tend_u, lbc_tend_ru, lbc_tend_rho_edge, lbc_tend_w, &
203+
!$acc lbc_tend_theta, lbc_tend_rtheta_m, lbc_tend_rho_zz, &
204+
!$acc lbc_tend_rho, lbc_tend_scalars)
205+
end if
206+
!$acc enter data copyin(u, w, theta, rho, scalars)
207+
!$acc enter data create(ru, rho_edge, rtheta_m, rho_zz)
208+
MPAS_ACC_TIMER_STOP('mpas_atm_update_bdy_tend [ACC_data_xfer]')
209+
210+
! Dereference the pointers to avoid non-array pointer for OpenACC
211+
nCells = nCells_ptr
212+
nEdges = nEdges_ptr
213+
nVertLevels = nVertLevels_ptr
214+
nScalars = nScalars_ptr
215+
index_qv = index_qv_ptr
216+
184217
! Compute lbc_rho_zz
185-
zz(:,nCells+1) = 1.0_RKIND ! Avoid potential division by zero in the following line
186-
rho_zz(:,:) = rho(:,:) / zz(:,:)
218+
!$acc parallel default(present)
219+
!$acc loop vector
220+
do k=1,nVertLevels
221+
zz(k,nCells+1) = 1.0_RKIND ! Avoid potential division by zero in the following line
222+
end do
223+
!$acc end parallel
224+
225+
!$acc parallel default(present)
226+
!$acc loop gang vector collapse(2)
227+
do iCell=1,nCells+1
228+
do k=1,nVertLevels
229+
rho_zz(k,iCell) = rho(k,iCell) / zz(k,iCell)
230+
end do
231+
end do
232+
!$acc end parallel
187233

188234
! Average lbc_rho_zz to edges
235+
!$acc parallel default(present)
236+
!$acc loop gang worker
189237
do iEdge=1,nEdges
190238
cell1 = cellsOnEdge(1,iEdge)
191239
cell2 = cellsOnEdge(2,iEdge)
192240
if (cell1 > 0 .and. cell2 > 0) then
193-
rho_edge(:,iEdge) = 0.5_RKIND * (rho_zz(:,cell1) + rho_zz(:,cell2))
241+
!$acc loop vector
242+
do k = 1, nVertLevels
243+
rho_edge(k,iEdge) = 0.5_RKIND * (rho_zz(k,cell1) + rho_zz(k,cell2))
244+
end do
194245
end if
195246
end do
247+
!$acc end parallel
196248

197-
ru(:,:) = u(:,:) * rho_edge(:,:)
198-
rtheta_m(:,:) = theta(:,:) * rho_zz(:,:) * (1.0_RKIND + rvord * scalars(index_qv,:,:))
249+
!$acc parallel default(present)
250+
!$acc loop gang vector collapse(2)
251+
do iEdge=1,nEdges+1
252+
do k=1,nVertLevels
253+
ru(k,iEdge) = u(k,iEdge) * rho_edge(k,iEdge)
254+
end do
255+
end do
256+
257+
!$acc loop gang vector collapse(2)
258+
do iCell=1,nCells+1
259+
do k=1,nVertLevels
260+
rtheta_m(k,iCell) = theta(k,iCell) * rho_zz(k,iCell) * (1.0_RKIND + rvord * scalars(index_qv,k,iCell))
261+
end do
262+
end do
263+
!$acc end parallel
199264

200265
if (.not. firstCall) then
201266
lbc_interval = currTime - LBC_intv_end
202267
call mpas_get_timeInterval(interval=lbc_interval, DD=dd_intv, S=s_intv, S_n=sn_intv, S_d=sd_intv, ierr=ierr)
203268
dt = 86400.0_RKIND * real(dd_intv, kind=RKIND) + real(s_intv, kind=RKIND) &
204269
+ (real(sn_intv, kind=RKIND) / real(sd_intv, kind=RKIND))
205270

206-
call mpas_pool_get_array(lbc, 'lbc_u', u, 2)
207-
call mpas_pool_get_array(lbc, 'lbc_ru', ru, 2)
208-
call mpas_pool_get_array(lbc, 'lbc_rho_edge', rho_edge, 2)
209-
call mpas_pool_get_array(lbc, 'lbc_w', w, 2)
210-
call mpas_pool_get_array(lbc, 'lbc_theta', theta, 2)
211-
call mpas_pool_get_array(lbc, 'lbc_rtheta_m', rtheta_m, 2)
212-
call mpas_pool_get_array(lbc, 'lbc_rho_zz', rho_zz, 2)
213-
call mpas_pool_get_array(lbc, 'lbc_rho', rho, 2)
214-
call mpas_pool_get_array(lbc, 'lbc_scalars', scalars, 2)
215271

216-
call mpas_pool_get_array(lbc, 'lbc_u', lbc_tend_u, 1)
217-
call mpas_pool_get_array(lbc, 'lbc_ru', lbc_tend_ru, 1)
218-
call mpas_pool_get_array(lbc, 'lbc_rho_edge', lbc_tend_rho_edge, 1)
219-
call mpas_pool_get_array(lbc, 'lbc_w', lbc_tend_w, 1)
220-
call mpas_pool_get_array(lbc, 'lbc_theta', lbc_tend_theta, 1)
221-
call mpas_pool_get_array(lbc, 'lbc_rtheta_m', lbc_tend_rtheta_m, 1)
222-
call mpas_pool_get_array(lbc, 'lbc_rho_zz', lbc_tend_rho_zz, 1)
223-
call mpas_pool_get_array(lbc, 'lbc_rho', lbc_tend_rho, 1)
224-
call mpas_pool_get_array(lbc, 'lbc_scalars', lbc_tend_scalars, 1)
272+
dt = 1.0_RKIND / dt
273+
274+
!$acc parallel default(present)
275+
!$acc loop gang vector collapse(2)
276+
do iEdge=1,nEdges+1
277+
do k=1,nVertLevels
278+
lbc_tend_u(k,iEdge) = (u(k,iEdge) - lbc_tend_u(k,iEdge)) * dt
279+
lbc_tend_ru(k,iEdge) = (ru(k,iEdge) - lbc_tend_ru(k,iEdge)) * dt
280+
lbc_tend_rho_edge(k,iEdge) = (rho_edge(k,iEdge) - lbc_tend_rho_edge(k,iEdge)) * dt
281+
end do
282+
end do
225283

284+
!$acc loop gang vector collapse(2)
285+
do iCell=1,nCells+1
286+
do k=1,nVertLevels+1
287+
lbc_tend_w(k,iCell) = (w(k,iCell) - lbc_tend_w(k,iCell)) * dt
288+
end do
289+
end do
226290

227-
dt = 1.0_RKIND / dt
228-
lbc_tend_u(:,:) = (u(:,:) - lbc_tend_u(:,:)) * dt
229-
lbc_tend_ru(:,:) = (ru(:,:) - lbc_tend_ru(:,:)) * dt
230-
lbc_tend_rho_edge(:,:) = (rho_edge(:,:) - lbc_tend_rho_edge(:,:)) * dt
231-
lbc_tend_w(:,:) = (w(:,:) - lbc_tend_w(:,:)) * dt
232-
lbc_tend_theta(:,:) = (theta(:,:) - lbc_tend_theta(:,:)) * dt
233-
lbc_tend_rtheta_m(:,:) = (rtheta_m(:,:) - lbc_tend_rtheta_m(:,:)) * dt
234-
lbc_tend_rho_zz(:,:) = (rho_zz(:,:) - lbc_tend_rho_zz(:,:)) * dt
235-
lbc_tend_rho(:,:) = (rho(:,:) - lbc_tend_rho(:,:)) * dt
236-
lbc_tend_scalars(:,:,:) = (scalars(:,:,:) - lbc_tend_scalars(:,:,:)) * dt
291+
!$acc loop gang vector collapse(2)
292+
do iCell=1,nCells+1
293+
do k=1,nVertLevels
294+
lbc_tend_theta(k,iCell) = (theta(k,iCell) - lbc_tend_theta(k,iCell)) * dt
295+
lbc_tend_rtheta_m(k,iCell) = (rtheta_m(k,iCell) - lbc_tend_rtheta_m(k,iCell)) * dt
296+
lbc_tend_rho_zz(k,iCell) = (rho_zz(k,iCell) - lbc_tend_rho_zz(k,iCell)) * dt
297+
lbc_tend_rho(k,iCell) = (rho(k,iCell) - lbc_tend_rho(k,iCell)) * dt
298+
end do
299+
end do
300+
301+
!$acc loop gang
302+
do iCell=1,nCells+1
303+
!$acc loop vector collapse(2)
304+
do k=1,nVertLevels
305+
do j = 1,nScalars
306+
lbc_tend_scalars(j,k,iCell) = (scalars(j,k,iCell) - lbc_tend_scalars(j,k,iCell)) * dt
307+
end do
308+
end do
309+
end do
310+
!$acc end parallel
237311

238312
!
239313
! Logging the lbc start and end times appears to be backwards, but
@@ -249,6 +323,17 @@ subroutine mpas_atm_update_bdy_tend(clock, streamManager, block, firstCall, ierr
249323

250324
end if
251325

326+
MPAS_ACC_TIMER_START('mpas_atm_update_bdy_tend [ACC_data_xfer]')
327+
if (.not. firstCall) then
328+
!$acc exit data copyout(lbc_tend_u, lbc_tend_ru, lbc_tend_rho_edge, lbc_tend_w, &
329+
!$acc lbc_tend_theta, lbc_tend_rtheta_m, lbc_tend_rho_zz, &
330+
!$acc lbc_tend_rho, lbc_tend_scalars)
331+
end if
332+
333+
!$acc exit data copyout(ru, rho_edge, rtheta_m, rho_zz)
334+
!$acc exit data delete(u, w, theta, rho, scalars)
335+
MPAS_ACC_TIMER_STOP('mpas_atm_update_bdy_tend [ACC_data_xfer]')
336+
252337
LBC_intv_end = currTime
253338

254339
end subroutine mpas_atm_update_bdy_tend

0 commit comments

Comments
 (0)