Skip to content

Commit

Permalink
Add gomega matcher to compare and diff protoreflect.Value objects
Browse files Browse the repository at this point in the history
  • Loading branch information
kralicky committed Nov 21, 2023
1 parent 5b58282 commit 4f2374e
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions pkg/test/testutil/equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package testutil

import (
"fmt"
"slices"
"strings"

"github.com/google/go-cmp/cmp"
Expand All @@ -11,6 +12,7 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/testing/protocmp"
)

Expand Down Expand Up @@ -65,6 +67,84 @@ func (matcher *ProtoMatcher) Matches(x interface{}) bool {
return success
}

func ProtoValueEqual(expected protoreflect.Value) *ProtoValueMatcher {
return &ProtoValueMatcher{
Expected: expected,
}
}

type ProtoValueMatcher struct {
Expected protoreflect.Value
}

func (matcher *ProtoValueMatcher) Match(actual any) (success bool, err error) {
if actual == nil && !matcher.Expected.IsValid() {
return false, fmt.Errorf("Refusing to compare <nil> to <nil>.\nBe explicit and use BeNil() instead. This is to avoid mistakes where both sides of an assertion are erroneously uninitialized")
}
if actual, ok := actual.(protoreflect.Value); !ok {
return false, fmt.Errorf("ProtoValueMatcher expects a protoreflect.Value. Got:\n%s", format.Object(actual, 1))
} else {
return matcher.Expected.Equal(actual), nil
}
}

func (matcher *ProtoValueMatcher) FailureMessage(actual any) (message string) {
return fmt.Sprintf("Expected\n%s\n%s\n%s",
format.IndentString(humanizeProtoreflectValue(actual.(protoreflect.Value)), 1),
"to equal",
format.IndentString(humanizeProtoreflectValue(matcher.Expected), 1),
)
}

func (matcher *ProtoValueMatcher) NegatedFailureMessage(actual any) (message string) {
return fmt.Sprintf("Expected\n%s\n%s\n%s",
format.IndentString(humanizeProtoreflectValue(actual.(protoreflect.Value)), 1),
"not to equal",
format.IndentString(humanizeProtoreflectValue(matcher.Expected), 1),
)
}

func humanizeProtoreflectValue(v protoreflect.Value) string {
switch v := v.Interface().(type) {
case nil:
return "nil"
case bool, int32, int64, uint32, uint64, float32, float64, string:
return fmt.Sprintf("%v", v)
case []byte:
return string(v)
case protoreflect.EnumNumber:
return fmt.Sprintf("%v", v)
case protoreflect.Message:
return prototext.Format(v.Interface())
case protoreflect.List:
builder := strings.Builder{}
builder.WriteString("[\n")
for i := 0; i < v.Len(); i++ {
fmt.Fprintf(&builder, " %s\n", humanizeProtoreflectValue(v.Get(i)))
}
builder.WriteString("]")
return builder.String()
case protoreflect.Map:
builder := strings.Builder{}
builder.WriteString("{\n")
kvs := []string{}
v.Range(func(mk protoreflect.MapKey, v protoreflect.Value) bool {
kvs = append(kvs, fmt.Sprintf("%s: %s", mk.String(), humanizeProtoreflectValue(v)))
return true
})
slices.Sort(kvs)
for _, kv := range kvs {
builder.WriteString(kv)
}
builder.WriteString("}")
return builder.String()
case protoreflect.ProtoMessage:
panic(fmt.Sprintf("invalid proto.Message(%T) type, expected a protoreflect.Message type", v))
default:
panic(fmt.Sprintf("invalid type: %T", v))
}
}

type StatusCodeMatcher struct {
Expected any
matchMsg types.GomegaMatcher
Expand Down

0 comments on commit 4f2374e

Please sign in to comment.