diff --git a/internal/libyaml/loader.go b/internal/libyaml/loader.go index ab150ce7..f59eed84 100644 --- a/internal/libyaml/loader.go +++ b/internal/libyaml/loader.go @@ -139,6 +139,10 @@ func (l *Loader) Load(v any) (err error) { // Stage 2: Resolve - determine implicit types for untagged scalars l.resolver.Resolve(node) + // Propagate a snapshot of loader options onto every node so that Node.Decode called + // inside custom UnmarshalYAML implementations inherits settings like KnownFields + propagateLoadOptions(node, l.options) + // Stage 3: Construct - convert node tree to Go values out := reflect.ValueOf(v) if out.Kind() == reflect.Pointer && !out.IsNil() { @@ -153,6 +157,24 @@ func (l *Loader) Load(v any) (err error) { return nil } +// propagateLoadOptions stamps n and every node reachable through Content with opts. +// Alias pointers are not followed: in valid YAML, anchors are defined before +// their aliases, so the anchor node is always reachable through n.Content traversal. +func propagateLoadOptions(n *Node, opts *Options) { + if n == nil || opts == nil { + return + } + snapshot := *opts // stamp a private copy so later mutation of opts (e.g. SetKnownFields) cannot race with Node.Decode + propagateLoadOptionsRecursion(n, &snapshot) +} + +func propagateLoadOptionsRecursion(n *Node, opts *Options) { + n.options = opts + for _, child := range n.Content { + propagateLoadOptionsRecursion(child, opts) + } +} + // loadAll loads all documents from the input into a slice. // The out parameter must be a non-nil pointer to a slice. // Each document is appended to the slice as an element. @@ -262,6 +284,7 @@ func loadSingle(in []byte, out any, opts *Options) error { // This is used by the legacy Decoder.KnownFields() method. 
func (l *Loader) SetKnownFields(enable bool) { l.constructor.KnownFields = enable + l.options.KnownFields = enable } // ComposeAndResolve composes and resolves the next document from the input @@ -282,6 +305,10 @@ func (l *Loader) ComposeAndResolve() *Node { // Stage 2: Resolve - determine implicit types for untagged scalars l.resolver.Resolve(node) + // Propagate a snapshot of loader options onto every node so that Node.Decode called + // inside custom UnmarshalYAML implementations inherits settings like KnownFields + propagateLoadOptions(node, l.options) + return node } diff --git a/internal/libyaml/loader_test.go b/internal/libyaml/loader_test.go index 1ebe1351..80c00d20 100644 --- a/internal/libyaml/loader_test.go +++ b/internal/libyaml/loader_test.go @@ -275,3 +275,26 @@ func TestLoad_MultipleDocuments(t *testing.T) { assert.NotNil(t, err) assert.ErrorMatches(t, ".*expected single document, found multiple.*", err) } + +// TestComposeAndResolvePropagatesOptions tests that ComposeAndResolve propagates +// a snapshot of the loader options onto the returned node tree so that +// Node.Decode inside custom UnmarshalYAML implementations respects settings +// like KnownFields. +func TestComposeAndResolvePropagatesOptions(t *testing.T) { + type target struct { + Name string `yaml:"name"` + } + + input := []byte("name: Alice\nunknown_field: oops\n") + loader, err := NewLoader(bytes.NewReader(input), WithKnownFields()) + assert.NoError(t, err) + + node := loader.ComposeAndResolve() + assert.NotNil(t, node) + assert.NotNil(t, node.options) + + var v target + err = node.Decode(&v) + assert.NotNil(t, err) + assert.ErrorMatches(t, ".*unknown_field.*", err) +} diff --git a/internal/libyaml/node.go b/internal/libyaml/node.go index 4fee73b3..36763110 100644 --- a/internal/libyaml/node.go +++ b/internal/libyaml/node.go @@ -189,9 +189,16 @@ type Node struct { // Stream holds stream metadata (non-nil only when Kind == StreamNode). 
Stream *Stream + + // options is set by propagateLoadOptions when a Loader produces this node. It carries + // the loader options so that Decode can inherit them in custom UnmarshalYAML functions. + // It is typically nil for user-constructed nodes. + options *Options } -// IsZero returns whether the node has all of its fields unset. +// IsZero returns whether the node has all of its user-visible fields unset. +// The unexported options field is intentionally excluded: it is set by loader +// infrastructure and does not represent user-visible content. func (n *Node) IsZero() bool { return n.Kind == 0 && n.Style == 0 && n.Tag == "" && n.Value == "" && n.Anchor == "" && n.Alias == nil && n.Content == nil && n.HeadComment == "" && n.LineComment == "" && n.FootComment == "" && n.Line == 0 && n.Column == 0 && @@ -280,8 +287,12 @@ func (n *Node) SetString(s string) { // See the documentation for Unmarshal for details about the // conversion of YAML into a Go value. func (n *Node) Decode(v any) (err error) { - d := NewConstructor(DefaultOptions) defer handleErr(&err) + opts := DefaultOptions + if n.options != nil { + opts = n.options + } + d := NewConstructor(opts) out := reflect.ValueOf(v) if out.Kind() == reflect.Pointer && !out.IsNil() { out = out.Elem() @@ -296,8 +307,10 @@ func (n *Node) Decode(v any) (err error) { // Load decodes the node and stores its data into the value pointed to by v, // applying the given options. // -// This method is useful when you need to preserve options like WithKnownFields() -// inside custom UnmarshalYAML implementations. +// Unlike Decode, Load does not inherit options from the loader that produced +// this node; the caller must supply all required options explicitly. +// This method is useful when you need explicit control over options like +// WithKnownFields() inside custom UnmarshalYAML implementations. // // Maps and pointers (to a struct, string, int, etc) are accepted as v // values. 
If an internal pointer within a struct is not initialized, diff --git a/node_test.go b/node_test.go index 36c9af11..4a93f7e9 100644 --- a/node_test.go +++ b/node_test.go @@ -11,6 +11,7 @@ import ( "bytes" "fmt" "reflect" + "sync" "testing" "go.yaml.in/yaml/v4" @@ -784,3 +785,157 @@ func TestNodeDumpInvalidOptions(t *testing.T) { assert.NotNil(t, err) assert.ErrorMatches(t, ".*indent must be.*", err) } + +type nodeDecodeTarget struct { + Name string `yaml:"name"` +} + +func (t *nodeDecodeTarget) UnmarshalYAML(node *yaml.Node) error { + type plain nodeDecodeTarget + return node.Decode((*plain)(t)) +} + +type nodeDecodeChildInner struct { + Name string `yaml:"name"` +} + +type nodeDecodeChildOuter struct { + Inner nodeDecodeChildInner +} + +func (o *nodeDecodeChildOuter) UnmarshalYAML(node *yaml.Node) error { + for i := 0; i+1 < len(node.Content); i += 2 { + if node.Content[i].Value == "inner" { + return node.Content[i+1].Decode(&o.Inner) + } + } + return nil +} + +func TestNodeDecodeInheritsKnownFields(t *testing.T) { + t.Run("known fields rejected", func(t *testing.T) { + input := "name: Alice\nunknown_field: oops\n" + var v nodeDecodeTarget + err := yaml.Load([]byte(input), &v, yaml.WithKnownFields()) + assert.NotNil(t, err) + assert.ErrorMatches(t, ".*unknown_field.*", err) + }) + + t.Run("unknown fields ignored without option", func(t *testing.T) { + input := "name: Alice\nunknown_field: oops\n" + var v nodeDecodeTarget + err := yaml.Unmarshal([]byte(input), &v) + assert.NoError(t, err) + assert.Equal(t, "Alice", v.Name) + }) + + t.Run("user-constructed node uses default options", func(t *testing.T) { + node := &yaml.Node{ + Kind: yaml.MappingNode, + Tag: "!!map", + Content: []*yaml.Node{ + {Kind: yaml.ScalarNode, Tag: "!!str", Value: "name"}, + {Kind: yaml.ScalarNode, Tag: "!!str", Value: "Bob"}, + {Kind: yaml.ScalarNode, Tag: "!!str", Value: "extra"}, + {Kind: yaml.ScalarNode, Tag: "!!str", Value: "ignored"}, + }, + } + var v nodeDecodeChildInner + err := 
node.Decode(&v) + assert.NoError(t, err) + assert.Equal(t, "Bob", v.Name) + }) + + t.Run("child node inherits known fields", func(t *testing.T) { + input := "inner:\n name: Carol\n unknown_field: oops\n" + var v nodeDecodeChildOuter + err := yaml.Load([]byte(input), &v, yaml.WithKnownFields()) + assert.NotNil(t, err) + assert.ErrorMatches(t, ".*unknown_field.*", err) + }) + + t.Run("known fields via decoder api", func(t *testing.T) { + input := "name: Alice\nunknown_field: oops\n" + var v nodeDecodeTarget + dec := yaml.NewDecoder(bytes.NewReader([]byte(input))) + dec.KnownFields(true) + err := dec.Decode(&v) + assert.NotNil(t, err) + assert.ErrorMatches(t, ".*unknown_field.*", err) + }) + + t.Run("known fields enforced on all documents", func(t *testing.T) { + input := "name: Alice\n---\nname: Bob\nunknown_field: oops\n" + dec := yaml.NewDecoder(bytes.NewReader([]byte(input))) + dec.KnownFields(true) + + var v1 nodeDecodeTarget + err := dec.Decode(&v1) + assert.NoError(t, err) + assert.Equal(t, "Alice", v1.Name) + + var v2 nodeDecodeTarget + err = dec.Decode(&v2) + assert.NotNil(t, err) + assert.ErrorMatches(t, ".*unknown_field.*", err) + }) + + t.Run("known fields can be disabled between documents", func(t *testing.T) { + input := "name: Alice\n---\nname: Bob\nunknown_field: ok\n" + dec := yaml.NewDecoder(bytes.NewReader([]byte(input))) + dec.KnownFields(true) + + var v1 nodeDecodeTarget + err := dec.Decode(&v1) + assert.NoError(t, err) + assert.Equal(t, "Alice", v1.Name) + + dec.KnownFields(false) + + var v2 nodeDecodeTarget + err = dec.Decode(&v2) + assert.NoError(t, err) + assert.Equal(t, "Bob", v2.Name) + }) +} + +type raceDecodeTarget struct { + Name string `yaml:"name"` + onDecode func() +} + +func (t *raceDecodeTarget) UnmarshalYAML(node *yaml.Node) error { + if t.onDecode != nil { + t.onDecode() + } + type plain struct { + Name string `yaml:"name"` + } + var p plain + if err := node.Decode(&p); err != nil { + return err + } + t.Name = p.Name + return nil 
+} + +// TestSetKnownFieldsRaceWithNodeDecode checks for a data race between Decoder.KnownFields() +// and Node.Decode() inside UnmarshalYAML (run with '-race') +func TestSetKnownFieldsRaceWithNodeDecode(t *testing.T) { + input := "name: Alice\n" + dec := yaml.NewDecoder(bytes.NewReader([]byte(input))) + dec.KnownFields(true) + + var wg sync.WaitGroup + wg.Add(1) + v := &raceDecodeTarget{ + onDecode: func() { + go func() { + defer wg.Done() + dec.KnownFields(false) + }() + }, + } + _ = dec.Decode(v) + wg.Wait() +} diff --git a/yaml_bench_test.go b/yaml_bench_test.go new file mode 100644 index 00000000..8131c5bf --- /dev/null +++ b/yaml_bench_test.go @@ -0,0 +1,103 @@ +// Copyright 2025 The go-yaml Project Contributors +// SPDX-License-Identifier: Apache-2.0 + +package yaml_test + +import ( + "bytes" + "fmt" + "strings" + "testing" + + yaml "go.yaml.in/yaml/v4" +) + +// benchPlain is a decode target with no custom UnmarshalYAML. +type benchPlain struct { + Fields map[string]string `yaml:",inline"` +} + +// benchCustom is a decode target whose UnmarshalYAML calls node.Decode — +// the path that inherits the loader options stamped onto the node tree. 
+type benchCustom struct { + Fields map[string]string `yaml:",inline"` +} + +func (b *benchCustom) UnmarshalYAML(node *yaml.Node) error { + type plain benchCustom + return node.Decode((*plain)(b)) +} + +func makeKVDoc(n int) []byte { + var sb strings.Builder + for i := 0; i < n; i++ { + fmt.Fprintf(&sb, "key%d: value%d\n", i, i) + } + return []byte(sb.String()) +} + +var ( + benchSmallDoc = makeKVDoc(10) + benchMediumDoc = makeKVDoc(100) + benchLargeDoc = makeKVDoc(1000) +) + +func BenchmarkDecode(b *testing.B) { + targets := []struct { + name string + decode func(data []byte, known bool) error + }{ + { + name: "plain", + decode: func(data []byte, known bool) error { + var v benchPlain + dec := yaml.NewDecoder(bytes.NewReader(data)) + dec.KnownFields(known) + return dec.Decode(&v) + }, + }, + { + name: "custom", + decode: func(data []byte, known bool) error { + var v benchCustom + dec := yaml.NewDecoder(bytes.NewReader(data)) + dec.KnownFields(known) + return dec.Decode(&v) + }, + }, + } + + options := []struct { + name string + knownFields bool + }{ + {"default", false}, + {"known-fields", true}, + } + + sizes := []struct { + name string + data []byte + }{ + {"small", benchSmallDoc}, + {"medium", benchMediumDoc}, + {"large", benchLargeDoc}, + } + + for _, size := range sizes { + for _, target := range targets { + for _, opt := range options { + size, target, opt := size, target, opt + name := fmt.Sprintf("target=%s/option=%s/size=%s", target.name, opt.name, size.name) + b.Run(name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if err := target.decode(size.data, opt.knownFields); err != nil { + b.Fatal(err) + } + } + }) + } + } + } +}