diff --git a/json/codec.go b/json/codec.go index 381b0e7..908c3f6 100644 --- a/json/codec.go +++ b/json/codec.go @@ -16,13 +16,32 @@ import ( "github.com/segmentio/asm/keyset" ) +const ( + // 1000 is the value used by the standard encoding/json package. + // + // https://cs.opensource.google/go/go/+/refs/tags/go1.17.3:src/encoding/json/encode.go;drc=refs%2Ftags%2Fgo1.17.3;l=300 + startDetectingCyclesAfter = 1000 +) + type codec struct { encode encodeFunc decode decodeFunc } -type encoder struct{ flags AppendFlags } -type decoder struct{ flags ParseFlags } +type encoder struct { + flags AppendFlags + // ptrDepth tracks the depth of pointer cycles, when it reaches the value + // of startDetectingCyclesAfter, the ptrSeen map is allocated and the + // encoder starts tracking pointers it has seen as an attempt to detect + // whether it has entered a pointer cycle and needs to error before the + // goroutine runs out of stack space. + ptrDepth uint32 + ptrSeen map[unsafe.Pointer]struct{} +} + +type decoder struct { + flags ParseFlags +} type encodeFunc func(encoder, []byte, unsafe.Pointer) ([]byte, error) type decodeFunc func(decoder, []byte, unsafe.Pointer) ([]byte, error) diff --git a/json/encode.go b/json/encode.go index 8dd0e4b..acb3b67 100644 --- a/json/encode.go +++ b/json/encode.go @@ -2,6 +2,7 @@ package json import ( "encoding" + "fmt" "math" "reflect" "sort" @@ -815,6 +816,18 @@ func (e encoder) encodeEmbeddedStructPointer(b []byte, p unsafe.Pointer, t refle func (e encoder) encodePointer(b []byte, p unsafe.Pointer, t reflect.Type, encode encodeFunc) ([]byte, error) { if p = *(*unsafe.Pointer)(p); p != nil { + if e.ptrDepth++; e.ptrDepth >= startDetectingCyclesAfter { + if _, seen := e.ptrSeen[p]; seen { + // TODO: reconstruct the reflect.Value from p + t so we can set + // the erorr's Value field? + return b, &UnsupportedValueError{Str: fmt.Sprintf("encountered a cycle via %s", t)} + } + if e.ptrSeen == nil { + e.ptrSeen = make(map[unsafe.Pointer]struct{}) + } + e.ptrSeen[p] = struct{}{} + defer delete(e.ptrSeen, p) + } return encode(e, b, p) } return e.encodeNull(b, nil) diff --git a/json/json.go b/json/json.go index c3e710b..9d259cc 100644 --- a/json/json.go +++ b/json/json.go @@ -55,7 +55,7 @@ type UnsupportedValueError = json.UnsupportedValueError // AppendFlags is a type used to represent configuration options that can be // applied when formatting json output. -type AppendFlags uint +type AppendFlags uint32 const ( // EscapeHTML is a formatting flag used to to escape HTML in json strings. @@ -74,7 +74,7 @@ const ( // ParseFlags is a type used to represent configuration options that can be // applied when parsing json input. -type ParseFlags uint +type ParseFlags uint32 func (flags ParseFlags) has(f ParseFlags) bool { return (flags & f) != 0 diff --git a/json/json_test.go b/json/json_test.go index 5f3f8a2..63ec861 100644 --- a/json/json_test.go +++ b/json/json_test.go @@ -1591,6 +1591,28 @@ func TestGithubIssue44(t *testing.T) { } } +type issue107Foo struct { + Bar *issue107Bar +} + +type issue107Bar struct { + Foo *issue107Foo +} + +func TestGithubIssue107(t *testing.T) { + f := &issue107Foo{} + b := &issue107Bar{} + f.Bar = b + b.Foo = f + + _, err := Marshal(f) // must not crash + switch err.(type) { + case *UnsupportedValueError: + default: + t.Errorf("marshaling a cycling data structure was expected to return an unsupported value error but got %T", err) + } +} + type rawJsonString string func (r *rawJsonString) UnmarshalJSON(b []byte) error {