[Draft] Consolidating OpenACC device-host memory transfers #1307

Closed
wants to merge 88 commits into from
5c2fa59
Add parallel and loop directives to atm_bdy_adjust_scalars_work
gdicker1 Jan 16, 2025
bd8f074
Copy invariant fields used in atm_bdy_adjust_scalars_work
gdicker1 Jan 16, 2025
143585d
Add acc data movement to atm_bdy_adjust_scalars_work
gdicker1 Jan 16, 2025
80fa8f7
initial OpenACC port of atm_rk_dynamics_substep_finish
abishekg7 Aug 1, 2024
8c48f75
Moving parallel directives to inside if-else conditions
abishekg7 Jan 17, 2025
ed2db57
Initializing the garbage cells for theta_m_1 in atm_rk_dynamics_subst…
abishekg7 Jan 17, 2025
d45b96a
Clean up to address review comments
abishekg7 Jan 22, 2025
852fa34
Adding default(present) clauses to the parallel regions
abishekg7 Jan 31, 2025
8ae0a8f
First working version
abishekg7 Jan 17, 2025
cef6ed8
Removing scalars from private clause and reverting exponent operation
abishekg7 Jan 22, 2025
5c7d198
adding comment
abishekg7 Jan 31, 2025
a6f0b7f
fixup! Add parallel and loop directives to atm_bdy_adjust_scalars_work
gdicker1 Feb 3, 2025
8d55069
fixup! Add parallel and loop directives to atm_bdy_adjust_scalars_work
gdicker1 Feb 3, 2025
5c41e0e
Prepare for OpenACC porting in mpas_reconstruct_2d
gdicker1 Feb 7, 2025
e8cf29b
Initial OpenACC directives added to mpas_vector_reconstruct_2d
gdicker1 Feb 14, 2025
b7a3cbc
Add OpenACC transfer of invariant fields within mpas_reconstruct_2d
gdicker1 Feb 18, 2025
96fe04e
Add OpenACC data movement to mpas_reconstruct_2d
gdicker1 Feb 15, 2025
1b2dae1
WIP: porting compute_dyn_tend_work
mgduda Feb 26, 2025
f5ada69
WIP: checkpoint compute_dyn_tend work, 2 Jan 2025, 6:05pm
mgduda Jan 3, 2025
5ed529e
WIP: checkpoint compute_dyn_tend work
mgduda Jan 3, 2025
b4fa456
WIP: checkpoint changes in time integration
mgduda Jan 6, 2025
8335cfd
WIP: finished porting w tendencies in mpas_atm_time_integration
mgduda Jan 7, 2025
7be1861
WIP: completed porting of theta tendency terms
mgduda Jan 7, 2025
6e25210
Fuse data movement for tend_u_euler
mgduda Jan 17, 2025
0c80909
Fuse data movement for kdiff
mgduda Jan 17, 2025
a92e78f
Fuse data movement for u, v, h_divergence, ru
mgduda Jan 17, 2025
44e84f0
Fuse data movement for tend_rho and tend_rho_physics
mgduda Jan 17, 2025
ad7c88c
Fuse data movement for rw
mgduda Jan 17, 2025
ca61366
FIXUP for 167dd4de75 (remove redundant rdzu,rdzw)
mgduda Jan 17, 2025
32f71f5
fixup: remove redundant (after rebase) code for moving 'zxu'
mgduda Feb 26, 2025
54fecf5
data movement: rdzw
mgduda Feb 27, 2025
571c52d
data merge: rb, qtot, rr_save
mgduda Feb 27, 2025
5dc6856
data movement: dpdz
mgduda Feb 27, 2025
3ee0699
data movement: tend_u
mgduda Feb 27, 2025
7e76e54
data movement: cqu, pp, u, w, pv_edge, rho_edge, ke
mgduda Feb 27, 2025
8e2d056
data movement: divergence, vorticity
mgduda Mar 4, 2025
b25cf7e
data movement: delsq_u, delsq_vorticity, delsq_divergence
mgduda Mar 4, 2025
acd3e52
data movement: u_init, v_init, rayleigh_damp_coef, tend_ru_physics
mgduda Mar 4, 2025
cf7c2ff
data movement: delsq_w, tend_w_euler, tend_w_euler, delsq_w, tend_w, …
mgduda Mar 4, 2025
4af60d9
data movement: tend_w_euler
mgduda Mar 8, 2025
c8908f8
data movement: tend_w
mgduda Mar 8, 2025
d422801
data movement: rho_zz
mgduda Mar 8, 2025
d819324
data movement: tend_theta
mgduda Mar 8, 2025
ad181fa
data movement: theta_m, ru_save, theta_m_save
mgduda Mar 8, 2025
93df11a
data movement: tend_theta_euler
mgduda Mar 8, 2025
e623406
Dereference single-value integer pointers to integers in summarize_ti…
gdicker1 Mar 12, 2025
cb2ba6a
Move mpas_log_write outside Check for NaNs loops in summarize_timestep
gdicker1 Mar 20, 2025
7c3b1fe
Use 2 loop search for min/max with location in summarize_timestep
gdicker1 Mar 17, 2025
6802185
Add OpenACC for config_print_global_minmax_sca in summarize_timestep
gdicker1 Mar 20, 2025
ee1d77b
broken wip: trying to develop checkpointing infrastructure
mgduda Mar 21, 2025
d7d40bf
WIP: just copy most fields to work around answer differences
mgduda Mar 22, 2025
ed9b10f
WIP: add more data directives within (rk_step == 1) test
mgduda Mar 22, 2025
46fd13f
Add OpenACC for config_print_global_minmax_vel in summarize_timestep
gdicker1 Mar 25, 2025
a21db2c
Add OpenACC for config_print_detailed_minmax_vel in summarize_timestep
gdicker1 Mar 25, 2025
3f1965e
Clean up changes for initial profiling
mgduda Mar 26, 2025
8414225
Remove code related to checkpointing
mgduda Mar 26, 2025
c912af2
Merge remote-tracking branch 'dylan/atmosphere/acc_atm_bdy_adjust_sca…
abishekg7 Apr 9, 2025
8e876cc
Merge remote-tracking branch 'mine/atmosphere/port_atm_rk_dynamics_su…
abishekg7 Apr 9, 2025
e6e74a6
Merge remote-tracking branch 'dylan/atmosphere/acc_atm_bdy_adjust_dyn…
abishekg7 Apr 9, 2025
cacb8ab
Merge remote-tracking branch 'mine/atmosphere/port_atm_bdy_adjust_dyn…
abishekg7 Apr 9, 2025
8fbb9f5
Merge remote-tracking branch 'mine/atmosphere/port_atm_bdy_reset_spec…
abishekg7 Apr 9, 2025
621879d
Merge remote-tracking branch 'jim/atmosphere/mpas_atm_get_bdy_tend' i…
abishekg7 Apr 9, 2025
312afd0
Merge remote-tracking branch 'dylan/framework/acc_mpas_reconstruct_2d…
abishekg7 Apr 9, 2025
29c72b6
Merge remote-tracking branch 'dylan/atmosphere/acc_summarize_timestep…
abishekg7 Apr 9, 2025
b497a71
first attempt
abishekg7 Apr 17, 2025
b9de954
first working
abishekg7 Apr 18, 2025
08b212f
mesh fields need to be uploaded to device after atm_mpas_init_block
abishekg7 Apr 18, 2025
2654783
second working
abishekg7 Apr 18, 2025
4416bee
still working jw
abishekg7 Apr 18, 2025
01001e4
some changes
abishekg7 Apr 23, 2025
e8d480b
Lam no physics is correct on 1mpi rank
abishekg7 Apr 23, 2025
3f31419
data movement around halo exchanges
abishekg7 Apr 24, 2025
63a88b3
restart files producing correct results for no_physics case with 1 rank
abishekg7 Apr 24, 2025
b3bc23d
Initial OpenACC port of mpas_atm_update_bdy_tend
abishekg7 Apr 29, 2025
8bc1e96
Adding timers mpas_atm_update_bdy_tend [ACC_data_xfer]
abishekg7 Apr 29, 2025
f08acca
Merge remote-tracking branch 'origin/atmosphere/port_mpas_atm_update_…
abishekg7 Apr 29, 2025
2720fcc
consolidating data movements used in mpas_atm_boundaries
abishekg7 May 1, 2025
508f551
working lam with physics, need to check radiation
abishekg7 May 6, 2025
e7a9142
some more fixes + cleanup
abishekg7 May 6, 2025
936765b
more cleanup
abishekg7 May 7, 2025
e99150c
functions to subroutines
abishekg7 May 7, 2025
aa7aa85
data creates for some module variables
abishekg7 May 7, 2025
2bdcee7
more data movements around compute solve diagnostics
abishekg7 May 7, 2025
6c3963c
moving copies of tendency terms consumed in atm_compute_dyn_tend
abishekg7 May 8, 2025
ca1d38e
using delete for lbc state and tend in post
abishekg7 May 8, 2025
d4de3f2
using delete instead of copyout - 1
abishekg7 May 8, 2025
a8561cc
using delete instead of copyout - 2
abishekg7 May 8, 2025
81815ba
add new ACC_xfer timers and remove empty ones
abishekg7 May 8, 2025
80 changes: 35 additions & 45 deletions src/core_atmosphere/dynamics/mpas_atm_boundaries.F
@@ -287,7 +287,7 @@ end subroutine mpas_atm_update_bdy_tend
!> tend_scalars(1,:,:) = mpas_atm_get_bdy_tend(clock, domain % blocklist, nVertLevels, nCells, 'qv', 0.0_RKIND)
!
!-----------------------------------------------------------------------
function mpas_atm_get_bdy_tend(clock, block, vertDim, horizDim, field, delta_t) result(return_tend)
subroutine mpas_atm_get_bdy_tend(clock, block, vertDim, horizDim, field, delta_t, return_tend)

implicit none

@@ -296,31 +296,49 @@ function mpas_atm_get_bdy_tend(clock, block, vertDim, horizDim, field, delta_t)
integer, intent(in) :: vertDim, horizDim
character(len=*), intent(in) :: field
real (kind=RKIND), intent(in) :: delta_t

real (kind=RKIND), dimension(vertDim,horizDim+1) :: return_tend
real (kind=RKIND), dimension(vertDim,horizDim+1), intent(out) :: return_tend

type (mpas_pool_type), pointer :: lbc
integer, pointer :: idx
integer, pointer :: idx_ptr
real (kind=RKIND), dimension(:,:), pointer :: tend
real (kind=RKIND), dimension(:,:,:), pointer :: tend_scalars
integer :: ierr
integer :: idx, i, j


call mpas_pool_get_subpool(block % structs, 'lbc', lbc)

nullify(tend)
call mpas_pool_get_array(lbc, 'lbc_'//trim(field), tend, 1)

if (associated(tend)) then
return_tend(:,:) = tend(:,:)
else
call mpas_pool_get_array(lbc, 'lbc_scalars', tend_scalars, 1)
call mpas_pool_get_dimension(lbc, 'index_'//trim(field), idx)

return_tend(:,:) = tend_scalars(idx,:,:)
! Ensure the integer pointed to by idx_ptr is copied to the gpu device
call mpas_pool_get_dimension(lbc, 'index_'//trim(field), idx_ptr)
idx = idx_ptr
end if

!$acc parallel default(present)
if (associated(tend)) then
!$acc loop gang vector collapse(2)
do j=1,horizDim+1
do i=1,vertDim
return_tend(i,j) = tend(i,j)
end do
end do
else
!$acc loop gang vector collapse(2)
do j=1,horizDim+1
do i=1,vertDim
return_tend(i,j) = tend_scalars(idx,i,j)
end do
end do
end if
!$acc end parallel


end function mpas_atm_get_bdy_tend
end subroutine mpas_atm_get_bdy_tend


!***********************************************************************
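The hunk above replaces a function result with an `intent(out)` dummy argument and copies the scalar pointer `idx_ptr` into a plain integer before the parallel region. A minimal sketch of that pattern follows. The module name `bdy_copy_sketch`, the routine `get_slice`, and the arguments `src` and `dst` are hypothetical and not part of MPAS. The sketch also uses `copyin`/`copyout` clauses so that it is self-contained, whereas the PR relies on `default(present)` and on data that is already resident on the device.

```fortran
module bdy_copy_sketch
   implicit none
   integer, parameter :: RKIND = selected_real_kind(12)
contains
   ! Hypothetical stand-in for the converted routine: the result is an
   ! intent(out) dummy argument instead of a function result, so the
   ! caller owns the array that the kernel writes into.
   subroutine get_slice(vertDim, horizDim, src, idx_ptr, dst)
      integer, intent(in) :: vertDim, horizDim
      real (kind=RKIND), dimension(:,:,:), intent(in) :: src
      integer, pointer, intent(in) :: idx_ptr
      real (kind=RKIND), dimension(vertDim,horizDim+1), intent(out) :: dst
      integer :: idx, i, j

      ! Copy the pointed-to value into a plain integer so the kernel
      ! references an ordinary scalar rather than a scalar pointer.
      idx = idx_ptr

      ! Assumption: explicit data clauses keep this sketch standalone;
      ! the PR uses default(present) instead.
      !$acc parallel loop gang vector collapse(2) copyin(src) copyout(dst)
      do j = 1, horizDim+1
         do i = 1, vertDim
            dst(i,j) = src(idx,i,j)
         end do
      end do
   end subroutine get_slice
end module bdy_copy_sketch
```

OpenACC directives are comments to a compiler that does not enable OpenACC, so the module also builds and runs on the host, which makes it easy to compare host and device answers.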
@@ -356,10 +374,11 @@ end function mpas_atm_get_bdy_tend
!> scalars(1,:,:) = mpas_atm_get_bdy_state(clock, domain % blocklist, nVertLevels, nCells, 'qv', 0.0_RKIND)
!
!-----------------------------------------------------------------------
function mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, delta_t) result(return_state)
subroutine mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, delta_t, return_state)

use mpas_pool_routines, only : mpas_pool_get_error_level, mpas_pool_set_error_level
use mpas_derived_types, only : MPAS_POOL_SILENT
use mpas_log, only : mpas_log_write

implicit none

@@ -368,8 +387,7 @@ function mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, delta
integer, intent(in) :: vertDim, horizDim
character(len=*), intent(in) :: field
real (kind=RKIND), intent(in) :: delta_t

real (kind=RKIND), dimension(vertDim,horizDim+1) :: return_state
real (kind=RKIND), dimension(vertDim,horizDim+1), intent(out) :: return_state

type (mpas_pool_type), pointer :: lbc
integer, pointer :: idx_ptr
@@ -420,10 +438,6 @@ function mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, delta
! query the field as a scalar constituent
!
if (associated(tend) .and. associated(state)) then
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
!$acc enter data create(return_state) &
!$acc copyin(tend, state)
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')

!$acc parallel default(present)
!$acc loop gang vector collapse(2)
@@ -434,22 +448,13 @@ function mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, delta
end do
!$acc end parallel

MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
!$acc exit data copyout(return_state) &
!$acc delete(tend, state)
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
else
call mpas_pool_get_array(lbc, 'lbc_scalars', tend_scalars, 1)
call mpas_pool_get_array(lbc, 'lbc_scalars', state_scalars, 2)
call mpas_pool_get_dimension(lbc, 'index_'//trim(field), idx_ptr)

idx=idx_ptr ! Avoid non-array pointer for OpenACC

MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
!$acc enter data create(return_state) &
!$acc copyin(tend_scalars, state_scalars)
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')

!$acc parallel default(present)
!$acc loop gang vector collapse(2)
do i=1, horizDim+1
@@ -459,13 +464,9 @@ function mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, delta
end do
!$acc end parallel

MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
!$acc exit data copyout(return_state) &
!$acc delete(tend_scalars, state_scalars)
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
end if

end function mpas_atm_get_bdy_state_2d
end subroutine mpas_atm_get_bdy_state_2d


!***********************************************************************
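The `mpas_atm_get_bdy_state_2d` hunks above drop the per-call `enter data` / `exit data` pairs and keep only the `default(present)` kernels, which implies that a caller is now responsible for device residency. A rough sketch of that consolidation, under stated assumptions: `consolidated_xfer_sketch`, `scale_field`, and `run_steps` are hypothetical names, and the real PR places its data directives elsewhere in the model rather than in a single driver like this.

```fortran
module consolidated_xfer_sketch
   implicit none
   integer, parameter :: RKIND = selected_real_kind(12)
contains
   ! Kernel-only routine: it assumes the caller has already placed
   ! a and b on the device, so it performs no transfers of its own.
   subroutine scale_field(n, factor, a, b)
      integer, intent(in) :: n
      real (kind=RKIND), intent(in) :: factor
      real (kind=RKIND), dimension(n), intent(in) :: a
      real (kind=RKIND), dimension(n), intent(out) :: b
      integer :: i

      !$acc parallel loop default(present)
      do i = 1, n
         b(i) = factor * a(i)
      end do
   end subroutine scale_field

   ! Hypothetical driver: one transfer in, many kernels, one transfer out.
   subroutine run_steps(n, nsteps, a, b)
      integer, intent(in) :: n, nsteps
      real (kind=RKIND), dimension(n), intent(inout) :: a
      real (kind=RKIND), dimension(n), intent(out) :: b
      integer :: step

      !$acc enter data copyin(a) create(b)
      do step = 1, nsteps
         call scale_field(n, 2.0_RKIND, a, b)   ! b = 2*a
         call scale_field(n, 1.0_RKIND, b, a)   ! a = b
      end do
      !$acc exit data copyout(a, b)
   end subroutine run_steps
end module consolidated_xfer_sketch
```

With this layout the host-device traffic no longer scales with the number of kernel calls. The cost is that `scale_field` fails at run time under OpenACC if a caller forgets to make its arguments present, which is the trade-off `default(present)` makes explicit.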
@@ -498,7 +499,7 @@ end function mpas_atm_get_bdy_state_2d
!> num_scalars, nVertLevels, nCells, 'scalars', 0.0_RKIND)
!
!-----------------------------------------------------------------------
function mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim, field, delta_t) result(return_state)
subroutine mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim, field, delta_t, return_state)

use mpas_pool_routines, only : mpas_pool_get_error_level, mpas_pool_set_error_level
use mpas_derived_types, only : MPAS_POOL_SILENT
@@ -510,8 +511,7 @@ function mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim, fi
integer, intent(in) :: innerDim, vertDim, horizDim
character(len=*), intent(in) :: field
real (kind=RKIND), intent(in) :: delta_t

real (kind=RKIND), dimension(innerDim,vertDim,horizDim+1) :: return_state
real (kind=RKIND), dimension(innerDim,vertDim,horizDim+1), intent(out) :: return_state

type (mpas_pool_type), pointer :: lbc
real (kind=RKIND), dimension(:,:,:), pointer :: tend
@@ -543,11 +543,6 @@ function mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim, fi
call mpas_pool_get_array(lbc, 'lbc_'//trim(field), tend, 1)
call mpas_pool_get_array(lbc, 'lbc_'//trim(field), state, 2)

MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')
!$acc enter data create(return_state) &
!$acc copyin(tend, state)
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')

!$acc parallel default(present)
!$acc loop gang vector collapse(3)
do i=1, horizDim+1
@@ -559,12 +554,7 @@ function mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim, fi
end do
!$acc end parallel

MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')
!$acc exit data copyout(return_state) &
!$acc delete(tend, state)
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')

end function mpas_atm_get_bdy_state_3d
end subroutine mpas_atm_get_bdy_state_3d


!***********************************************************************
2 changes: 2 additions & 0 deletions src/core_atmosphere/dynamics/mpas_atm_iau.F
@@ -137,6 +137,7 @@ subroutine atm_add_tend_anal_incr (configs, structs, itimestep, dt, tend_ru, ten
call mpas_pool_get_array(state, 'scalars', scalars, 1)
call mpas_pool_get_array(state, 'rho_zz', rho_zz, 2)
call mpas_pool_get_array(diag , 'rho_edge', rho_edge)
!$acc update self(theta_m, scalars, rho_zz, rho_edge)

call mpas_pool_get_dimension(state, 'moist_start', moist_start)
call mpas_pool_get_dimension(state, 'moist_end', moist_end)
@@ -149,6 +150,7 @@ subroutine atm_add_tend_anal_incr (configs, structs, itimestep, dt, tend_ru, ten
! call mpas_pool_get_array(tend, 'rho_zz', tend_rho)
! call mpas_pool_get_array(tend, 'theta_m', tend_theta)
call mpas_pool_get_array(tend, 'scalars_tend', tend_scalars)
!$acc update self(tend_scalars)

call mpas_pool_get_array(tend_iau, 'theta', theta_amb)
call mpas_pool_get_array(tend_iau, 'rho', rho_amb)
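The two `!$acc update self(...)` lines added to `atm_add_tend_anal_incr` suggest that this routine still runs on the host and therefore needs the host copies of device-resident fields refreshed first. A small sketch of that situation, assuming a hypothetical module `update_self_sketch` with a routine `device_then_host` that is not part of MPAS:

```fortran
module update_self_sketch
   implicit none
   integer, parameter :: RKIND = selected_real_kind(12)
contains
   ! Hypothetical routine: a device kernel modifies field, then
   ! host-only code needs the updated values.
   subroutine device_then_host(n, field, host_sum)
      integer, intent(in) :: n
      real (kind=RKIND), dimension(n), intent(inout) :: field
      real (kind=RKIND), intent(out) :: host_sum
      integer :: i

      !$acc enter data copyin(field)

      ! This kernel changes only the device copy of field.
      !$acc parallel loop default(present)
      do i = 1, n
         field(i) = field(i) + 1.0_RKIND
      end do

      ! Refresh the host copy; without this the host code below
      ! would read stale values when built with OpenACC.
      !$acc update self(field)

      ! Host-only computation, outside any compute region.
      host_sum = sum(field)

      !$acc exit data delete(field)
   end subroutine device_then_host
end module update_self_sketch
```

`update self` leaves the device copy in place, so later kernels can keep using it. If the host code modified the field, a matching `update device` would be needed before the next kernel, and the excerpt shown here does not reveal whether the PR adds one.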