Skip to content

json: fix cycle detection #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions json/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"unicode"
Expand All @@ -32,13 +33,24 @@ type codec struct {

type encoder struct {
flags AppendFlags
// ptrDepth tracks the depth of pointer cycles, when it reaches the value
// refDepth tracks the depth of pointer cycles, when it reaches the value
// of startDetectingCyclesAfter, the ptrSeen map is allocated and the
// encoder starts tracking pointers it has seen as an attempt to detect
// whether it has entered a pointer cycle and needs to error before the
// goroutine runs out of stack space.
ptrDepth uint32
ptrSeen map[unsafe.Pointer]struct{}
//
// This relies on encoder being passed as a value,
// and encoder methods calling each other in a traditional stack
// (not using trampoline techniques),
// since refDepth is never decremented.
refDepth uint32
refSeen cycleMap
}

type cycleMap map[unsafe.Pointer]struct{}

var cycleMapPool = sync.Pool{
New: func() any { return make(cycleMap) },
}

type decoder struct {
Expand Down
105 changes: 81 additions & 24 deletions json/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func (e encoder) encodeToString(b []byte, p unsafe.Pointer, encode encodeFunc) (
func (e encoder) encodeBytes(b []byte, p unsafe.Pointer) ([]byte, error) {
v := *(*[]byte)(p)
if v == nil {
return append(b, "null"...), nil
return e.encodeNull(b, nil)
}

n := base64.StdEncoding.EncodedLen(len(v)) + 2
Expand Down Expand Up @@ -299,7 +299,7 @@ func (e encoder) encodeSlice(b []byte, p unsafe.Pointer, size uintptr, t reflect
s := (*slice)(p)

if s.data == nil && s.len == 0 && s.cap == 0 {
return append(b, "null"...), nil
return e.encodeNull(b, nil)
}

return e.encodeArray(b, s.data, s.len, size, t, encode)
Expand All @@ -308,7 +308,7 @@ func (e encoder) encodeSlice(b []byte, p unsafe.Pointer, size uintptr, t reflect
func (e encoder) encodeMap(b []byte, p unsafe.Pointer, t reflect.Type, encodeKey, encodeValue encodeFunc, sortKeys sortFunc) ([]byte, error) {
m := reflect.NewAt(t, p).Elem()
if m.IsNil() {
return append(b, "null"...), nil
return e.encodeNull(b, nil)
}

keys := m.MapKeys()
Expand Down Expand Up @@ -363,7 +363,7 @@ var mapslicePool = sync.Pool{
func (e encoder) encodeMapStringInterface(b []byte, p unsafe.Pointer) ([]byte, error) {
m := *(*map[string]any)(p)
if m == nil {
return append(b, "null"...), nil
return e.encodeNull(b, nil)
}

if (e.flags & SortMapKeys) == 0 {
Expand Down Expand Up @@ -441,7 +441,7 @@ func (e encoder) encodeMapStringInterface(b []byte, p unsafe.Pointer) ([]byte, e
func (e encoder) encodeMapStringRawMessage(b []byte, p unsafe.Pointer) ([]byte, error) {
m := *(*map[string]RawMessage)(p)
if m == nil {
return append(b, "null"...), nil
return e.encodeNull(b, nil)
}

if (e.flags & SortMapKeys) == 0 {
Expand Down Expand Up @@ -520,7 +520,7 @@ func (e encoder) encodeMapStringRawMessage(b []byte, p unsafe.Pointer) ([]byte,
func (e encoder) encodeMapStringString(b []byte, p unsafe.Pointer) ([]byte, error) {
m := *(*map[string]string)(p)
if m == nil {
return append(b, "null"...), nil
return e.encodeNull(b, nil)
}

if (e.flags & SortMapKeys) == 0 {
Expand Down Expand Up @@ -586,7 +586,7 @@ func (e encoder) encodeMapStringString(b []byte, p unsafe.Pointer) ([]byte, erro
func (e encoder) encodeMapStringStringSlice(b []byte, p unsafe.Pointer) ([]byte, error) {
m := *(*map[string][]string)(p)
if m == nil {
return append(b, "null"...), nil
return e.encodeNull(b, nil)
}

stringSize := unsafe.Sizeof("")
Expand Down Expand Up @@ -667,7 +667,7 @@ func (e encoder) encodeMapStringStringSlice(b []byte, p unsafe.Pointer) ([]byte,
func (e encoder) encodeMapStringBool(b []byte, p unsafe.Pointer) ([]byte, error) {
m := *(*map[string]bool)(p)
if m == nil {
return append(b, "null"...), nil
return e.encodeNull(b, nil)
}

if (e.flags & SortMapKeys) == 0 {
Expand Down Expand Up @@ -794,22 +794,23 @@ func (e encoder) encodeEmbeddedStructPointer(b []byte, p unsafe.Pointer, t refle
}

func (e encoder) encodePointer(b []byte, p unsafe.Pointer, t reflect.Type, encode encodeFunc) ([]byte, error) {
if p = *(*unsafe.Pointer)(p); p != nil {
if e.ptrDepth++; e.ptrDepth >= startDetectingCyclesAfter {
if _, seen := e.ptrSeen[p]; seen {
// TODO: reconstruct the reflect.Value from p + t so we can set
// the erorr's Value field?
return b, &UnsupportedValueError{Str: fmt.Sprintf("encountered a cycle via %s", t)}
}
if e.ptrSeen == nil {
e.ptrSeen = make(map[unsafe.Pointer]struct{})
}
e.ptrSeen[p] = struct{}{}
defer delete(e.ptrSeen, p)
// p was a pointer to the actual user data pointer:
// dereference it to operate on the user data pointer.
p = *(*unsafe.Pointer)(p)
if p == nil {
return e.encodeNull(b, nil)
}

if shouldCheckForRefCycle(&e) {
err := checkRefCycle(&e, t, p)
if err != nil {
return b, err
}
return encode(e, b, p)

defer freeRefCycleInfo(&e, p)
}
return e.encodeNull(b, nil)

return encode(e, b, p)
}

func (e encoder) encodeInterface(b []byte, p unsafe.Pointer) ([]byte, error) {
Expand All @@ -828,7 +829,7 @@ func (e encoder) encodeRawMessage(b []byte, p unsafe.Pointer) ([]byte, error) {
v := *(*RawMessage)(p)

if v == nil {
return append(b, "null"...), nil
return e.encodeNull(b, nil)
}

var s []byte
Expand Down Expand Up @@ -862,7 +863,7 @@ func (e encoder) encodeJSONMarshaler(b []byte, p unsafe.Pointer, t reflect.Type,
switch v.Kind() {
case reflect.Ptr, reflect.Interface:
if v.IsNil() {
return append(b, "null"...), nil
return e.encodeNull(b, nil)
}
}

Expand Down Expand Up @@ -968,3 +969,59 @@ func appendCompactEscapeHTML(dst []byte, src []byte) []byte {

return dst
}

// shouldCheckForRefCycle determines whether checking for reference cycles
// is reasonable to do at this time.
//
// When true, checkRefCycle should be called and any error handled,
// and then a deferred call to freeRefCycleInfo should be made.
//
// This should only be called from encoder methods that are possible points
// that could directly contribute to a reference cycle.
func shouldCheckForRefCycle(e *encoder) bool {
// Note: do not combine this with checkRefCycle,
// because checkRefCycle is too large to be inlined,
// and a non-inlined depth check leads to ~5%+ benchmark degradation.
e.refDepth++
return e.refDepth >= startDetectingCyclesAfter
}

// checkRefCycle returns an error if a reference cycle was detected.
// The data pointer passed in should be equivalent to one of:
//
// - A normal Go pointer, e.g. `unsafe.Pointer(&T)`
// - The pointer to a map header, e.g. `*(*unsafe.Pointer)(&map[K]V)`
//
// Many [encoder] methods accept a pointer-to-a-pointer,
// and so those may need to be derenced in order to safely pass them here.
func checkRefCycle(e *encoder, t reflect.Type, p unsafe.Pointer) error {
_, seen := e.refSeen[p]
if seen {
v := reflect.NewAt(t, p)
return &UnsupportedValueError{
Value: v,
Str: fmt.Sprintf("encountered a cycle via %s", t),
}
}

if e.refSeen == nil {
e.refSeen = cycleMapPool.Get().(cycleMap)
}

e.refSeen[p] = struct{}{}

return nil
}

// freeRefCycle performs the cleanup operation for [checkRefCycle].
// p must be the same value passed into a prior call to checkRefCycle.
func freeRefCycleInfo(e *encoder, p unsafe.Pointer) {
delete(e.refSeen, p)
if len(e.refSeen) == 0 {
// There are no remaining elements,
// so we can release this map for later reuse.
m := e.refSeen
e.refSeen = nil
cycleMapPool.Put(m)
}
}
86 changes: 75 additions & 11 deletions json/golang_encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,21 +136,85 @@ func TestEncodeRenamedByteSlice(t *testing.T) {
}
}

var unsupportedValues = []any{
math.NaN(),
math.Inf(-1),
math.Inf(1),
type SamePointerNoCycle struct {
Ptr1, Ptr2 *SamePointerNoCycle
}

var samePointerNoCycle = &SamePointerNoCycle{}

type PointerCycle struct {
Ptr *PointerCycle
}

var pointerCycle = &PointerCycle{}

type PointerCycleIndirect struct {
Ptrs []any
}

type RecursiveSlice []RecursiveSlice

var (
pointerCycleIndirect = &PointerCycleIndirect{}
mapCycle = make(map[string]any)
sliceCycle = []any{nil}
sliceNoCycle = []any{nil, nil}
recursiveSliceCycle = []RecursiveSlice{nil}
)

func init() {
ptr := &SamePointerNoCycle{}
samePointerNoCycle.Ptr1 = ptr
samePointerNoCycle.Ptr2 = ptr

pointerCycle.Ptr = pointerCycle
pointerCycleIndirect.Ptrs = []any{pointerCycleIndirect}

mapCycle["x"] = mapCycle
sliceCycle[0] = sliceCycle
sliceNoCycle[1] = sliceNoCycle[:1]
for i := startDetectingCyclesAfter; i > 0; i-- {
sliceNoCycle = []any{sliceNoCycle}
}
recursiveSliceCycle[0] = recursiveSliceCycle
}

func TestSamePointerNoCycle(t *testing.T) {
if _, err := Marshal(samePointerNoCycle); err != nil {
t.Fatalf("Marshal error: %v", err)
}
}

func TestSliceNoCycle(t *testing.T) {
if _, err := Marshal(sliceNoCycle); err != nil {
t.Fatalf("Marshal error: %v", err)
}
}

func TestUnsupportedValues(t *testing.T) {
for _, v := range unsupportedValues {
if _, err := Marshal(v); err != nil {
if _, ok := err.(*UnsupportedValueError); !ok {
t.Errorf("for %v, got %T want UnsupportedValueError", v, err)
tests := []struct {
CaseName
in any
}{
{Name(""), math.NaN()},
{Name(""), math.Inf(-1)},
{Name(""), math.Inf(1)},
{Name(""), pointerCycle},
{Name(""), pointerCycleIndirect},
{Name(""), mapCycle},
{Name(""), sliceCycle},
{Name(""), recursiveSliceCycle},
}
for _, tt := range tests {
t.Run(tt.Name, func(t *testing.T) {
if _, err := Marshal(tt.in); err != nil {
if _, ok := err.(*UnsupportedValueError); !ok {
t.Errorf("%s: Marshal error:\n\tgot: %T\n\twant: %T", tt.Where, err, new(UnsupportedValueError))
}
} else {
t.Errorf("%s: Marshal error: got nil, want non-nil", tt.Where)
}
} else {
t.Errorf("for %v, expected error", v)
}
})
}
}

Expand Down
30 changes: 30 additions & 0 deletions json/golang_shim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ package json

import (
"bytes"
"fmt"
"path"
"reflect"
"runtime"
"sync"
"testing"
)
Expand Down Expand Up @@ -68,3 +71,30 @@ func errorWithPrefixes(t *testing.T, prefixes []any, format string, elements ...
}
t.Errorf(fullFormat, allElements...)
}

// =============================================================================
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// CaseName is a case name annotated with a file and line.
type CaseName struct {
Name string
Where CasePos
}

// Name annotates a case name with the file and line of the caller.
func Name(s string) (c CaseName) {
c.Name = s
runtime.Callers(2, c.Where.pc[:])
return c
}

// CasePos represents a file and line number.
type CasePos struct{ pc [1]uintptr }

func (pos CasePos) String() string {
frames := runtime.CallersFrames(pos.pc[:])
frame, _ := frames.Next()
return fmt.Sprintf("%s:%d", path.Base(frame.File), frame.Line)
}
Loading