Skip to content

Commit bb23a92

Browse files
committed
Fix Jacobian-free Newton Krylov method on GPUs
1 parent c19e001 commit bb23a92

File tree

2 files changed

+21
-16
lines changed

2 files changed

+21
-16
lines changed

ext/KrylovExt.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
module KrylovExt
22

3-
import ClimaComms
4-
import ClimaCore: Fields
3+
import ClimaCore: DataLayouts, Fields
54
import Krylov
65

7-
Krylov.ktypeof(x::Fields.FieldVector) =
8-
ClimaComms.array_type(x){eltype(parent(x)), 1}
6+
function Krylov.ktypeof(x::Fields.FieldVector)
7+
representative_data = Fields.field_values(Fields.representative_field(x))
8+
array_type_with_N_var = DataLayouts.parent_array_type(representative_data)
9+
return typeintersect(array_type_with_N_var, AbstractVector) # Set N = 1.
10+
end
911

1012
end

src/Fields/fieldvector.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -456,20 +456,23 @@ end
456456

457457
import ClimaComms
458458

459-
ClimaComms.array_type(x::FieldVector) =
460-
promote_type(unrolled_map(ClimaComms.array_type, _values(x))...)
461-
462-
ClimaComms.device(x::FieldVector) = ClimaComms.device(ClimaComms.context(x))
463-
function ClimaComms.context(x::FieldVector)
464-
isempty(_values(x)) && error("Empty FieldVector has no device or context")
465-
# We don't have promotion for devices or contexts, so we use the first value
466-
# that isn't a PointField (a PointField's data can be stored on a different
467-
# device from other Fields to avoid scalar indexing on GPUs). If there is no
468-
# such value, fall back to using the first PointField.
469-
index = unrolled_findfirst(Base.Fix1(!isa, PointField), _values(x))
470-
return ClimaComms.context(_values(x)[isnothing(index) ? 1 : index])
459+
# To infer the ClimaComms device and its properties, use the first Field in a
460+
# FieldVector that isn't a PointField, since a PointField's data can be stored
461+
# on a different device from other Fields to avoid scalar indexing on GPUs. If
462+
# the FieldVector only contains PointFields, fall back to using the first one.
463+
function representative_field(x)
464+
all_fields = _values(x)
465+
isempty(all_fields) && error("Empty FieldVector has no ClimaComms device")
466+
field_index = unrolled_findfirst(Base.Fix2(!isa, PointField), all_fields)
467+
return all_fields[isnothing(field_index) ? 1 : field_index]
471468
end
472469

470+
ClimaComms.array_type(x::FieldVector) =
471+
ClimaComms.array_type(representative_field(x))
472+
ClimaComms.device(x::FieldVector) = ClimaComms.device(representative_field(x))
473+
ClimaComms.context(x::FieldVector) = ClimaComms.context(representative_field(x))
474+
475+
473476
function __rprint_diff(
474477
io::IO,
475478
x::T,

0 commit comments

Comments
 (0)