Skip to content

Commit

Permalink
update to pythoncall , judi v4
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Nov 1, 2024
1 parent f511360 commit 41b22eb
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 25 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
name = "ImageGather"
uuid = "355d8124-6b2e-49b5-aab5-cdbc0a5fccbe"
authors = ["mloubout <[email protected]>"]
version = "0.3.0"
version = "0.4.0"

[deps]
JUDI = "f3b833dc-6b2e-5b9c-b940-873ed6319979"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
JUDI = "3.4.6"
JUDI = "4"
julia = "1"

[extras]
Expand Down
Binary file modified docs/img/cig_cdp.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/img/cig_line.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 7 additions & 4 deletions src/ImageGather.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
module ImageGather

using JUDI
using JUDI.DSP, JUDI.PyCall
using JUDI.DSP, JUDI.PythonCall

import Base: getindex, *
import JUDI: judiAbstractJacobian, judiMultiSourceVector, judiComposedPropagator, judiJacobian, make_input, propagate
import JUDI.LinearAlgebra: adjoint

const impl = PyNULL()
const impl = PythonCall.pynew()

IGPath = pathof(ImageGather)

function __init__()
pushfirst!(PyVector(pyimport("sys")."path"),dirname(pathof(ImageGather)))
copy!(impl, pyimport("implementation"))
pyimport("sys").path.append(dirname(IGPath))
PythonCall.pycopy!(impl, pyimport("implementation"))
set_devito_config("autopadding", false)
end
# Utility functions
include("utils.jl")
Expand Down
11 changes: 8 additions & 3 deletions src/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from kernels import wave_kernel
from utils import opt_op

try:
from devitopro import * # noqa
except ImportError:
pass


def double_rtm(model, wavelet, src_coords, res, res_o, rec_coords, ic="as"):
"""
Expand All @@ -24,12 +29,12 @@ def double_rtm(model, wavelet, src_coords, res, res_o, rec_coords, ic="as"):


def cig_grad(model, src_coords, wavelet, rec_coords, res, offsets, ic="as",
space_order=8, dims=None, illum=False):
space_order=8, dims=None, illum=False, t_sub=1):
"""
"""
so = max(space_order, np.max(np.abs(offsets)) // model.grid.spacing[0])
u = forward(model, src_coords, None, wavelet,
space_order=(space_order, so, so), save=True)[1]
_, u, _, _ = forward(model, src_coords, None, wavelet, t_sub=t_sub,
illum=illum, space_order=(space_order, so, so), save=True)
# Setting adjoint wavefield
v = wavefield(model, space_order, fw=False, tfull=True)

Expand Down
19 changes: 9 additions & 10 deletions src/subsurface_gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function propagate(J::judiExtendedJacobian{T, :born, O}, q::AbstractArray{T}, il
modelPy = devito_model(J.model, J.options)
nh = [length(J.offsets) for _=1:length(J.dims)]
dmd = reshape(dm, nh..., J.model.n...)
dtComp = convert(Float32, modelPy."critical_dt")
dtComp = pyconvert(Float32, modelPy.critical_dt)

# Extrapolate input data to computational grid
qIn = time_resample(srcData, srcGeometry, dtComp)
Expand All @@ -92,9 +92,9 @@ function propagate(J::judiExtendedJacobian{T, :born, O}, q::AbstractArray{T}, il
rec_coords = setup_grid(recGeometry, J.model.n) # shifts rec coordinates by origin

# Devito interface
dD = JUDI.wrapcall_data(impl."cig_lin", modelPy, src_coords, qIn, rec_coords,
dmd, J.offsets, ic=J.options.IC, space_order=J.options.space_order, dims=J.dims)
dD = time_resample(dD, dtComp, recGeometry)
dD = impl.cig_lin(modelPy, src_coords, qIn, rec_coords, dmd, J.offsets,
ic=J.options.IC, space_order=J.options.space_order, dims=J.dims)
dD = time_resample(PyArray(dD), dtComp, recGeometry)
# Output shot record as judiVector
return judiVector{Float32, Matrix{Float32}}(1, recGeometry, [dD])
end
Expand All @@ -107,7 +107,7 @@ function propagate(J::judiExtendedJacobian{T, :adjoint_born, O}, q::AbstractArra

# Set up Python model
modelPy = devito_model(J.model, J.options)
dtComp = convert(Float32, modelPy."critical_dt")
dtComp = pyconvert(Float32, modelPy.critical_dt)

# Extrapolate input data to computational grid
qIn = time_resample(srcData, srcGeometry, dtComp)
Expand All @@ -118,11 +118,10 @@ function propagate(J::judiExtendedJacobian{T, :adjoint_born, O}, q::AbstractArra
rec_coords = setup_grid(recGeometry, J.model.n) # shifts rec coordinates by origin

# Devito
g = JUDI.pylock() do
pycall(impl."cig_grad", PyArray, modelPy, src_coords, qIn, rec_coords, dObserved, J.offsets,
illum=false, ic=J.options.IC, space_order=J.options.space_order, dims=J.dims)
end
g = remove_padding_cig(g, modelPy.padsizes; true_adjoint=J.options.sum_padding)
g = impl.cig_grad(modelPy, src_coords, qIn, rec_coords, dObserved, J.offsets,
illum=false, ic=J.options.IC, space_order=J.options.space_order, dims=J.dims,
t_sub=J.options.subsampling_factor)
g = remove_padding_cig(PyArray(g), pyconvert(Tuple, modelPy.padsizes); true_adjoint=J.options.sum_padding)
return g
end

Expand Down
11 changes: 5 additions & 6 deletions src/surface_gather.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import JUDI: AbstractModel, rlock_pycall, devito
import JUDI: AbstractModel, devito, wrapcall_data

export surface_gather, double_rtm_cig

Expand Down Expand Up @@ -60,7 +60,7 @@ function double_rtm_cig(model_full, q::judiVector, data::judiVector, offs, optio

# Set up Python model
modelPy = devito_model(model, options)
dtComp = convert(Float32, modelPy."critical_dt")
dtComp = pyconvert(Float32, modelPy.critical_dt)

# Extrapolate input data to computational grid
qIn = time_resample(make_input(q), q.geometry, dtComp)
Expand All @@ -71,7 +71,7 @@ function double_rtm_cig(model_full, q::judiVector, data::judiVector, offs, optio
rec_coords = setup_grid(data.geometry, model.n) # shifts rec coordinates by origin

# Src-rec offsets
scale = 1f1
scale = 1f2
off_r = log.(abs.(data.geometry.xloc[1] .- q.geometry.xloc[1]) .+ scale)
inv_off(x) = exp.(x) .- scale

Expand All @@ -82,9 +82,8 @@ function double_rtm_cig(model_full, q::judiVector, data::judiVector, offs, optio
res_o = res .* off_r'
# Double rtm

rtm, rtmo, illum = rlock_pycall(impl."double_rtm", Tuple{PyArray, PyArray, PyArray},
modelPy, qIn, src_coords, res, res_o, rec_coords,
ic=options.IC)
rtm, rtmo, illum = wrapcall_data(impl."double_rtm", modelPy, qIn, src_coords, res, res_o, rec_coords,
ic=options.IC)

rtm = remove_padding(rtm, modelPy.padsizes)
rtmo = remove_padding(rtmo, modelPy.padsizes)
Expand Down

0 comments on commit 41b22eb

Please sign in to comment.