Skip to content

Commit

Permalink
Define == and isapprox for transforms (#14)
Browse files Browse the repository at this point in the history
* Define '==' and 'isapprox' for transforms

* Update CI.yml

---------

Co-authored-by: Júlio Hoffimann <[email protected]>
  • Loading branch information
eliascarv and juliohm authored Aug 15, 2024
1 parent c7ddf28 commit 6a923af
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.9'
- '1'
os:
- ubuntu-latest
Expand Down
13 changes: 13 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,19 @@ reapply(transform::Transform, object, cache) = apply(transform, object) |> first

(transform::Transform)(object) = apply(transform, object) |> first

Base.:(==)(t₁::Transform, t₂::Transform) = nameof(typeof(t₁)) == nameof(typeof(t₂)) && parameters(t₁) == parameters(t₂)

Base.isapprox(t₁::Transform, t₂::Transform; kwargs...) =
nameof(typeof(t₁)) == nameof(typeof(t₂)) && _isapprox(parameters(t₁), parameters(t₂); kwargs...)

_isapprox(tup₁::NamedTuple, tup₂::NamedTuple; kwargs...) =
propertynames(tup₁) == propertynames(tup₂) && _isapprox(Tuple(tup₁), Tuple(tup₂); kwargs...)

_isapprox(tup₁::Tuple, tup₂::Tuple; kwargs...) =
length(tup₁) == length(tup₂) && all(_isapprox(x₁, x₂; kwargs...) for (x₁, x₂) in zip(tup₁, tup₂))

_isapprox(x₁, x₂; kwargs...) = isapprox(x₁, x₂; kwargs...)

# -----------
# IO METHODS
# -----------
Expand Down
14 changes: 14 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,20 @@ using Test
@test T[begin] == TestTransform()
@test T[end] == Identity()

# equality and approximation
struct TestParamTransform <: TransformsBase.Transform
param::Float64
end
TransformsBase.apply(t::TestParamTransform, x) = x * t.param, nothing
TransformsBase.parameters(t::TestParamTransform) = (; param=t.param)
T1 = TestParamTransform(1.0)
T2 = TestParamTransform(1.0f0)
T3 = TestTransform()
@test T1 == T2
@test T1 T3
@test T1 T2
@test T1 T3

T1 = Identity()
T2 = TestTransform()
T3 = TestTransform() TestTransform()
Expand Down

0 comments on commit 6a923af

Please sign in to comment.