From e96084501ac2eebd6503d8c870e1b15a29169075 Mon Sep 17 00:00:00 2001 From: Kevin Gillette Date: Tue, 29 Jul 2025 11:14:35 -0600 Subject: [PATCH 1/3] json: consistent use of encodeNull (should inline well) --- json/encode.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/json/encode.go b/json/encode.go index 2a6da07..8fe2867 100644 --- a/json/encode.go +++ b/json/encode.go @@ -241,7 +241,7 @@ func (e encoder) encodeToString(b []byte, p unsafe.Pointer, encode encodeFunc) ( func (e encoder) encodeBytes(b []byte, p unsafe.Pointer) ([]byte, error) { v := *(*[]byte)(p) if v == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } n := base64.StdEncoding.EncodedLen(len(v)) + 2 @@ -299,7 +299,7 @@ func (e encoder) encodeSlice(b []byte, p unsafe.Pointer, size uintptr, t reflect s := (*slice)(p) if s.data == nil && s.len == 0 && s.cap == 0 { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } return e.encodeArray(b, s.data, s.len, size, t, encode) @@ -308,7 +308,7 @@ func (e encoder) encodeSlice(b []byte, p unsafe.Pointer, size uintptr, t reflect func (e encoder) encodeMap(b []byte, p unsafe.Pointer, t reflect.Type, encodeKey, encodeValue encodeFunc, sortKeys sortFunc) ([]byte, error) { m := reflect.NewAt(t, p).Elem() if m.IsNil() { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } keys := m.MapKeys() @@ -363,7 +363,7 @@ var mapslicePool = sync.Pool{ func (e encoder) encodeMapStringInterface(b []byte, p unsafe.Pointer) ([]byte, error) { m := *(*map[string]any)(p) if m == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } if (e.flags & SortMapKeys) == 0 { @@ -441,7 +441,7 @@ func (e encoder) encodeMapStringInterface(b []byte, p unsafe.Pointer) ([]byte, e func (e encoder) encodeMapStringRawMessage(b []byte, p unsafe.Pointer) ([]byte, error) { m := *(*map[string]RawMessage)(p) if m == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } if (e.flags & SortMapKeys) == 0 { @@ -520,7 +520,7 @@ func (e encoder) encodeMapStringRawMessage(b []byte, p unsafe.Pointer) ([]byte, func (e encoder) encodeMapStringString(b []byte, p unsafe.Pointer) ([]byte, error) { m := *(*map[string]string)(p) if m == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } if (e.flags & SortMapKeys) == 0 { @@ -586,7 +586,7 @@ func (e encoder) encodeMapStringString(b []byte, p unsafe.Pointer) ([]byte, erro func (e encoder) encodeMapStringStringSlice(b []byte, p unsafe.Pointer) ([]byte, error) { m := *(*map[string][]string)(p) if m == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } stringSize := unsafe.Sizeof("") @@ -667,7 +667,7 @@ func (e encoder) encodeMapStringStringSlice(b []byte, p unsafe.Pointer) ([]byte, func (e encoder) encodeMapStringBool(b []byte, p unsafe.Pointer) ([]byte, error) { m := *(*map[string]bool)(p) if m == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } if (e.flags & SortMapKeys) == 0 { @@ -828,7 +828,7 @@ func (e encoder) encodeRawMessage(b []byte, p unsafe.Pointer) ([]byte, error) { v := *(*RawMessage)(p) if v == nil { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } var s []byte @@ -862,7 +862,7 @@ func (e encoder) encodeJSONMarshaler(b []byte, p unsafe.Pointer, t reflect.Type, switch v.Kind() { case reflect.Ptr, reflect.Interface: if v.IsNil() { - return append(b, "null"...), nil + return e.encodeNull(b, nil) } } From e7e3ee19442e6f9c35c4b43329df13fdae6a63a4 Mon Sep 17 00:00:00 2001 From: Kevin Gillette Date: Sun, 27 Jul 2025 22:33:13 -0600 Subject: [PATCH 2/3] json: update cycle checking aspects from stdlib golang_encode_test.go --- json/golang_encode_test.go | 86 +++++++++++++++++++++++++++++++++----- json/golang_shim_test.go | 30 +++++++++++++ 2 files changed, 105 insertions(+), 11 deletions(-) diff --git a/json/golang_encode_test.go b/json/golang_encode_test.go index 5e334a6..86f4f8d 100644 --- a/json/golang_encode_test.go +++ b/json/golang_encode_test.go @@ -136,21 +136,85 @@ func TestEncodeRenamedByteSlice(t *testing.T) { } } -var unsupportedValues = []any{ - math.NaN(), - math.Inf(-1), - math.Inf(1), +type SamePointerNoCycle struct { + Ptr1, Ptr2 *SamePointerNoCycle +} + +var samePointerNoCycle = &SamePointerNoCycle{} + +type PointerCycle struct { + Ptr *PointerCycle +} + +var pointerCycle = &PointerCycle{} + +type PointerCycleIndirect struct { + Ptrs []any +} + +type RecursiveSlice []RecursiveSlice + +var ( + pointerCycleIndirect = &PointerCycleIndirect{} + mapCycle = make(map[string]any) + sliceCycle = []any{nil} + sliceNoCycle = []any{nil, nil} + recursiveSliceCycle = []RecursiveSlice{nil} +) + +func init() { + ptr := &SamePointerNoCycle{} + samePointerNoCycle.Ptr1 = ptr + samePointerNoCycle.Ptr2 = ptr + + pointerCycle.Ptr = pointerCycle + pointerCycleIndirect.Ptrs = []any{pointerCycleIndirect} + + mapCycle["x"] = mapCycle + sliceCycle[0] = sliceCycle + sliceNoCycle[1] = sliceNoCycle[:1] + for i := startDetectingCyclesAfter; i > 0; i-- { + sliceNoCycle = []any{sliceNoCycle} + } + recursiveSliceCycle[0] = recursiveSliceCycle +} + +func TestSamePointerNoCycle(t *testing.T) { + if _, err := Marshal(samePointerNoCycle); err != nil { + t.Fatalf("Marshal error: %v", err) + } +} + +func TestSliceNoCycle(t *testing.T) { + if _, err := Marshal(sliceNoCycle); err != nil { + t.Fatalf("Marshal error: %v", err) + } } func TestUnsupportedValues(t *testing.T) { - for _, v := range unsupportedValues { - if _, err := Marshal(v); err != nil { - if _, ok := err.(*UnsupportedValueError); !ok { - t.Errorf("for %v, got %T want UnsupportedValueError", v, err) + tests := []struct { + CaseName + in any + }{ + {Name(""), math.NaN()}, + {Name(""), math.Inf(-1)}, + {Name(""), math.Inf(1)}, + {Name(""), pointerCycle}, + {Name(""), pointerCycleIndirect}, + {Name(""), mapCycle}, + {Name(""), sliceCycle}, + {Name(""), recursiveSliceCycle}, + } + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + if _, err := Marshal(tt.in); err != nil { + if _, ok := err.(*UnsupportedValueError); !ok { + t.Errorf("%s: Marshal error:\n\tgot: %T\n\twant: %T", tt.Where, err, new(UnsupportedValueError)) + } + } else { + t.Errorf("%s: Marshal error: got nil, want non-nil", tt.Where) } - } else { - t.Errorf("for %v, expected error", v) - } + }) } } diff --git a/json/golang_shim_test.go b/json/golang_shim_test.go index 5a19b7f..90e4fa9 100644 --- a/json/golang_shim_test.go +++ b/json/golang_shim_test.go @@ -4,7 +4,10 @@ package json import ( "bytes" + "fmt" + "path" "reflect" + "runtime" "sync" "testing" ) @@ -68,3 +71,30 @@ func errorWithPrefixes(t *testing.T, prefixes []any, format string, elements ... } t.Errorf(fullFormat, allElements...) } + +// ============================================================================= +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// CaseName is a case name annotated with a file and line. +type CaseName struct { + Name string + Where CasePos +} + +// Name annotates a case name with the file and line of the caller. +func Name(s string) (c CaseName) { + c.Name = s + runtime.Callers(2, c.Where.pc[:]) + return c +} + +// CasePos represents a file and line number. +type CasePos struct{ pc [1]uintptr } + +func (pos CasePos) String() string { + frames := runtime.CallersFrames(pos.pc[:]) + frame, _ := frames.Next() + return fmt.Sprintf("%s:%d", path.Base(frame.File), frame.Line) +} From 54f54bd6d45327813f1cdd52bf1f042b70f6cde8 Mon Sep 17 00:00:00 2001 From: Kevin Gillette Date: Sun, 27 Jul 2025 22:01:09 -0600 Subject: [PATCH 3/3] json: refactor ref-cycle handling Also set UnsupportedValueError.Value (better stdlib compat). --- json/codec.go | 18 +++++++++-- json/encode.go | 85 +++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 86 insertions(+), 17 deletions(-) diff --git a/json/codec.go b/json/codec.go index 77fe264..eb73f71 100644 --- a/json/codec.go +++ b/json/codec.go @@ -10,6 +10,7 @@ import ( "sort" "strconv" "strings" + "sync" "sync/atomic" "time" "unicode" @@ -32,13 +33,24 @@ type codec struct { type encoder struct { flags AppendFlags - // ptrDepth tracks the depth of pointer cycles, when it reaches the value + // refDepth tracks the depth of pointer cycles, when it reaches the value // of startDetectingCyclesAfter, the ptrSeen map is allocated and the // encoder starts tracking pointers it has seen as an attempt to detect // whether it has entered a pointer cycle and needs to error before the // goroutine runs out of stack space. - ptrDepth uint32 - ptrSeen map[unsafe.Pointer]struct{} + // + // This relies on encoder being passed as a value, + // and encoder methods calling each other in a traditional stack + // (not using trampoline techniques), + // since refDepth is never decremented. + refDepth uint32 + refSeen cycleMap +} + +type cycleMap map[unsafe.Pointer]struct{} + +var cycleMapPool = sync.Pool{ + New: func() any { return make(cycleMap) }, } type decoder struct { diff --git a/json/encode.go b/json/encode.go index 8fe2867..576271c 100644 --- a/json/encode.go +++ b/json/encode.go @@ -794,22 +794,23 @@ func (e encoder) encodeEmbeddedStructPointer(b []byte, p unsafe.Pointer, t refle } func (e encoder) encodePointer(b []byte, p unsafe.Pointer, t reflect.Type, encode encodeFunc) ([]byte, error) { - if p = *(*unsafe.Pointer)(p); p != nil { - if e.ptrDepth++; e.ptrDepth >= startDetectingCyclesAfter { - if _, seen := e.ptrSeen[p]; seen { - // TODO: reconstruct the reflect.Value from p + t so we can set - // the erorr's Value field? - return b, &UnsupportedValueError{Str: fmt.Sprintf("encountered a cycle via %s", t)} - } - if e.ptrSeen == nil { - e.ptrSeen = make(map[unsafe.Pointer]struct{}) - } - e.ptrSeen[p] = struct{}{} - defer delete(e.ptrSeen, p) + // p was a pointer to the actual user data pointer: + // dereference it to operate on the user data pointer. + p = *(*unsafe.Pointer)(p) + if p == nil { + return e.encodeNull(b, nil) + } + + if shouldCheckForRefCycle(&e) { + err := checkRefCycle(&e, t, p) + if err != nil { + return b, err } - return encode(e, b, p) + + defer freeRefCycleInfo(&e, p) } - return e.encodeNull(b, nil) + + return encode(e, b, p) } func (e encoder) encodeInterface(b []byte, p unsafe.Pointer) ([]byte, error) { @@ -968,3 +969,59 @@ func appendCompactEscapeHTML(dst []byte, src []byte) []byte { return dst } + +// shouldCheckForRefCycle determines whether checking for reference cycles +// is reasonable to do at this time. +// +// When true, checkRefCycle should be called and any error handled, +// and then a deferred call to freeRefCycleInfo should be made. +// +// This should only be called from encoder methods that are possible points +// that could directly contribute to a reference cycle. +func shouldCheckForRefCycle(e *encoder) bool { + // Note: do not combine this with checkRefCycle, + // because checkRefCycle is too large to be inlined, + // and a non-inlined depth check leads to ~5%+ benchmark degradation. + e.refDepth++ + return e.refDepth >= startDetectingCyclesAfter +} + +// checkRefCycle returns an error if a reference cycle was detected. +// The data pointer passed in should be equivalent to one of: +// +// - A normal Go pointer, e.g. `unsafe.Pointer(&T)` +// - The pointer to a map header, e.g. `*(*unsafe.Pointer)(&map[K]V)` +// +// Many [encoder] methods accept a pointer-to-a-pointer, +// and so those may need to be derenced in order to safely pass them here. +func checkRefCycle(e *encoder, t reflect.Type, p unsafe.Pointer) error { + _, seen := e.refSeen[p] + if seen { + v := reflect.NewAt(t, p) + return &UnsupportedValueError{ + Value: v, + Str: fmt.Sprintf("encountered a cycle via %s", t), + } + } + + if e.refSeen == nil { + e.refSeen = cycleMapPool.Get().(cycleMap) + } + + e.refSeen[p] = struct{}{} + + return nil +} + +// freeRefCycle performs the cleanup operation for [checkRefCycle]. +// p must be the same value passed into a prior call to checkRefCycle. +func freeRefCycleInfo(e *encoder, p unsafe.Pointer) { + delete(e.refSeen, p) + if len(e.refSeen) == 0 { + // There are no remaining elements, + // so we can release this map for later reuse. + m := e.refSeen + e.refSeen = nil + cycleMapPool.Put(m) + } +}