From abfa29864ddde4c1302917a2da879246e1990610 Mon Sep 17 00:00:00 2001 From: Wenying Dong Date: Mon, 19 Jun 2023 18:59:20 -0700 Subject: [PATCH] Bugfix: unmarshal issue with byte slice in packetIn2 message (#40) Use "copy" function to set the byte slice fields in a message. Signed-off-by: wenyingd --- VERSION | 2 +- openflow13/nx_action.go | 9 +++-- openflow13/nx_match.go | 3 +- openflow13/nxt_message.go | 7 ++-- openflow15/nx_match.go | 3 +- openflow15/nxt_message.go | 6 ++-- protocol/arp.go | 12 ++++--- protocol/icmpv6.go | 9 +++-- protocol/igmp.go | 9 +++-- protocol/ip.go | 6 ++-- protocol/ipv6.go | 6 ++-- stream_test.go | 69 +++++++++++++++++++++++++++++++++++---- 12 files changed, 109 insertions(+), 32 deletions(-) diff --git a/VERSION b/VERSION index 3123ff9..b4e3d2a 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v0.10.2 +v0.10.3 diff --git a/openflow13/nx_action.go b/openflow13/nx_action.go index 7a275be..9b3bf49 100644 --- a/openflow13/nx_action.go +++ b/openflow13/nx_action.go @@ -1188,7 +1188,8 @@ func (s *NXLearnSpec) UnmarshalBinary(data []byte) error { n := s.Header.Len() if s.Header.src { srcDataLength := 2 * ((s.Header.nBits + 15) / 16) - s.SrcValue = data[n : n+srcDataLength] + s.SrcValue = make([]byte, srcDataLength) + copy(s.SrcValue, data[n:n+srcDataLength]) n += srcDataLength } else { s.SrcField = new(NXLearnSpecField) @@ -1350,7 +1351,8 @@ func (a *NXActionNote) UnmarshalBinary(data []byte) error { return errors.New("the []byte is too short to unmarshal a full NXActionNote message") } n := a.NXActionHeader.Len() - a.Note = data[n:a.Length] + a.Note = make([]byte, int(a.Length-n)) + copy(a.Note, data[n:a.Length]) return nil } @@ -1674,7 +1676,8 @@ func (a *NXActionController2PropUserdata) UnmarshalBinary(data []byte) error { } n += int(a.PropHeader.Len()) - a.Userdata = data[n:a.Length] + a.Userdata = make([]byte, int(a.Length)-n) + copy(a.Userdata, data[n:a.Length]) return nil } diff --git a/openflow13/nx_match.go b/openflow13/nx_match.go index 9a65ca7..416bea1 100644 --- a/openflow13/nx_match.go +++ b/openflow13/nx_match.go @@ -80,7 +80,8 @@ func (m *ByteArrayField) UnmarshalBinary(data []byte) error { if len(data) < int(expectLength) { return errors.New("The byte array has wrong size to unmarshal ByteArrayField message") } - m.Data = data[:expectLength] + m.Data = make([]byte, expectLength) + copy(m.Data, data[:expectLength]) return nil } diff --git a/openflow13/nxt_message.go b/openflow13/nxt_message.go index 6e24d58..a290dc4 100644 --- a/openflow13/nxt_message.go +++ b/openflow13/nxt_message.go @@ -1126,7 +1126,8 @@ func (p *PacketIn2PropUserdata) UnmarshalBinary(data []byte) error { } n += int(p.PropHeader.Len()) - p.Userdata = data[n:p.Length] + p.Userdata = make([]byte, int(p.Length)-n) + copy(p.Userdata, data[n:p.Length]) return nil } @@ -1168,8 +1169,8 @@ func (p *PacketIn2PropContinuation) UnmarshalBinary(data []byte) error { return errors.New("the []byte is too short to unmarshal a full PacketIn2PropContinuation message") } n += int(p.PropHeader.Len()) - - p.Continuation = data[n:p.Length] + p.Continuation = make([]byte, int(p.Length)-n) + copy(p.Continuation, data[n:p.Length]) return nil } diff --git a/openflow15/nx_match.go b/openflow15/nx_match.go index bb88986..f4885d8 100644 --- a/openflow15/nx_match.go +++ b/openflow15/nx_match.go @@ -80,7 +80,8 @@ func (m *ByteArrayField) UnmarshalBinary(data []byte) error { if len(data) < int(expectLength) { return errors.New("The byte array has wrong size to unmarshal ByteArrayField message") } - m.Data = data[:expectLength] + m.Data = make([]byte, expectLength) + copy(m.Data, data[:expectLength]) return nil } diff --git a/openflow15/nxt_message.go b/openflow15/nxt_message.go index 86f65ac..f7f188d 100644 --- a/openflow15/nxt_message.go +++ b/openflow15/nxt_message.go @@ -1140,7 +1140,8 @@ func (p *PacketIn2PropUserdata) UnmarshalBinary(data []byte) error { } n += int(p.PropHeader.Len()) - p.Userdata = data[n:p.Length] + p.Userdata = make([]byte, int(p.Length)-n) + copy(p.Userdata, data[n:p.Length]) return nil } @@ -1183,7 +1184,8 @@ func (p *PacketIn2PropContinuation) UnmarshalBinary(data []byte) error { } n += int(p.PropHeader.Len()) - p.Continuation = data[n:p.Length] + p.Continuation = make([]byte, int(p.Length)-n) + copy(p.Continuation, data[n:p.Length]) return nil } diff --git a/protocol/arp.go b/protocol/arp.go index f92bf2e..59f1d61 100755 --- a/protocol/arp.go +++ b/protocol/arp.go @@ -80,12 +80,16 @@ func (a *ARP) UnmarshalBinary(data []byte) error { if len(data[n:]) < (int(a.HWLength)*2 + int(a.ProtoLength)*2) { return errors.New("The []byte is too short to unmarshal a full ARP message.") } - a.HWSrc = data[n : n+int(a.HWLength)] + a.HWSrc = make([]byte, a.HWLength) + copy(a.HWSrc, data[n:n+int(a.HWLength)]) n += int(a.HWLength) - a.IPSrc = data[n : n+int(a.ProtoLength)] + a.IPSrc = make([]byte, a.ProtoLength) + copy(a.IPSrc, data[n:n+int(a.ProtoLength)]) n += int(a.ProtoLength) - a.HWDst = data[n : n+int(a.HWLength)] + a.HWDst = make([]byte, a.HWLength) + copy(a.HWDst, data[n:n+int(a.HWLength)]) n += int(a.HWLength) - a.IPDst = data[n : n+int(a.ProtoLength)] + a.IPDst = make([]byte, a.ProtoLength) + copy(a.IPDst, data[n:n+int(a.ProtoLength)]) return nil } diff --git a/protocol/icmpv6.go b/protocol/icmpv6.go index c87b42e..f2ac7b1 100644 --- a/protocol/icmpv6.go +++ b/protocol/icmpv6.go @@ -220,7 +220,8 @@ func (m *MLD) UnmarshalBinary(data []byte) error { m.MaxResponse = binary.BigEndian.Uint16(data[n:]) n += 2 n += 2 - m.MulticastAddress = data[n : n+16] + m.MulticastAddress = make([]byte, 16) + copy(m.MulticastAddress, data[n:n+16]) return nil } @@ -358,7 +359,8 @@ func (q *MLDQuery) UnmarshalBinary(data []byte) error { q.MaxResponse = binary.BigEndian.Uint16(data[n:]) n += 2 n += 2 - q.MulticastAddress = data[n : n+16] + q.MulticastAddress = make([]byte, 16) + copy(q.MulticastAddress, data[n:n+16]) n += 16 q.Version2 = len(data) > 28 if q.Version2 { @@ -603,7 +605,8 @@ func (r *MLDv2Record) UnmarshalBinary(data []byte) error { n += 1 r.NumberOfSources = binary.BigEndian.Uint16(data[n:]) n += 2 - r.MulticastAddress = data[n : n+16] + r.MulticastAddress = make([]byte, 16) + copy(r.MulticastAddress, data[n:n+16]) n += 16 if len(data) < int(r.Len()) { return fmt.Errorf("The []byte is too short to unmarshal a full MLDv2Record message.") diff --git a/protocol/igmp.go b/protocol/igmp.go index 6209bfa..c9a9c0e 100644 --- a/protocol/igmp.go +++ b/protocol/igmp.go @@ -76,7 +76,8 @@ func (p *IGMPv1or2) UnmarshalBinary(data []byte) error { p.Type = data[0] p.MaxResponseTime = data[1] p.Checksum = binary.BigEndian.Uint16(data[2:4]) - p.GroupAddress = data[4:8] + p.GroupAddress = make([]byte, 4) + copy(p.GroupAddress, data[4:8]) return nil } @@ -190,7 +191,8 @@ func (p *IGMPv3Query) UnmarshalBinary(data []byte) error { n += 1 p.Checksum = binary.BigEndian.Uint16(data[n:]) n += 2 - p.GroupAddress = data[n : n+4] + p.GroupAddress = make([]byte, 4) + copy(p.GroupAddress, data[n:n+4]) n += 4 p.SuppressRouterProcessing = data[n]&0x8 != 0 p.RobustnessValue = data[n] & 0x7 @@ -295,7 +297,8 @@ func (p *IGMPv3GroupRecord) UnmarshalBinary(data []byte) error { n += 1 p.NumberOfSources = binary.BigEndian.Uint16(data[n:]) n += 2 - p.MulticastAddress = data[n : n+4] + p.MulticastAddress = make([]byte, 4) + copy(p.MulticastAddress, data[n:n+4]) n += 4 if len(data) < int(p.Len()) { return fmt.Errorf("The []byte is too short to unmarshal a full IGMPv3GroupRecord message.") diff --git a/protocol/ip.go b/protocol/ip.go index ff3523d..652f61c 100644 --- a/protocol/ip.go +++ b/protocol/ip.go @@ -135,9 +135,11 @@ func (i *IPv4) UnmarshalBinary(data []byte) error { n += 1 i.Checksum = binary.BigEndian.Uint16(data[n:]) n += 2 - i.NWSrc = data[n : n+4] + i.NWSrc = make([]byte, 4) + copy(i.NWSrc, data[n:n+4]) n += 4 - i.NWDst = data[n : n+4] + i.NWDst = make([]byte, 4) + copy(i.NWDst, data[n:n+4]) n += 4 err := i.Options.UnmarshalBinary(data[n:int(i.IHL*4)]) diff --git a/protocol/ipv6.go b/protocol/ipv6.go index e079f33..6fd69ed 100644 --- a/protocol/ipv6.go +++ b/protocol/ipv6.go @@ -128,9 +128,11 @@ func (i *IPv6) UnmarshalBinary(data []byte) error { n += 1 i.HopLimit = data[n] n += 1 - i.NWSrc = data[n : n+16] + i.NWSrc = make([]byte, 16) + copy(i.NWSrc, data[n:n+16]) n += 16 - i.NWDst = data[n : n+16] + i.NWDst = make([]byte, 16) + copy(i.NWDst, data[n:n+16]) n += 16 checkExtHeader := true diff --git a/stream_test.go b/stream_test.go index c382f43..b7f71fb 100644 --- a/stream_test.go +++ b/stream_test.go @@ -8,9 +8,11 @@ import ( "time" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "antrea.io/libOpenflow/common" "antrea.io/libOpenflow/openflow13" + "antrea.io/libOpenflow/openflow15" "antrea.io/libOpenflow/util" ) @@ -18,8 +20,9 @@ var helloMessage *common.Hello var binaryMessage []byte type fakeConn struct { - count int - max int + count int + max int + bytesGenerator func() []byte } func (f *fakeConn) Close() error { @@ -31,7 +34,7 @@ func (f *fakeConn) Read(b []byte) (int, error) { return 0, io.EOF } f.count++ - regenerateMessage() + binaryMessage = f.bytesGenerator() copy(b, binaryMessage) return len(binaryMessage), nil } @@ -60,6 +63,13 @@ func (f *fakeConn) SetWriteDeadline(t time.Time) error { return nil } +func newFakeConn(max int, generator func() []byte) net.Conn { + return &fakeConn{ + max: max, + bytesGenerator: generator, + } +} + type parserIntf struct { } @@ -67,22 +77,25 @@ func (p parserIntf) Parse(b []byte) (message util.Message, err error) { switch b[0] { case openflow13.VERSION: message, err = openflow13.Parse(b) + case openflow15.VERSION: + message, err = openflow15.Parse(b) default: } return } -func regenerateMessage() { +func regenerateMessage() []byte { helloMessage, _ = common.NewHello(4) - binaryMessage, _ = helloMessage.MarshalBinary() - + msgBytes, _ := helloMessage.MarshalBinary() + return msgBytes } func TestMessageStream(t *testing.T) { var ( c = &fakeConn{ - max: 5000000, + max: 5000000, + bytesGenerator: regenerateMessage, } p = parserIntf{} goroutineCountStart = runtime.NumGoroutine() @@ -102,3 +115,45 @@ func TestMessageStream(t *testing.T) { t.Fatalf("found more goroutines: %v before, %v after", goroutineCountStart, goroutineCountEnd) } } + +func TestStreamInbound(t *testing.T) { + msgBytes := [][]byte{ + {6, 4, 1, 32, 0, 0, 0, 0, 0, 0, 35, 32, 0, 0, 0, 30, 0, 0, 0, 146, 18, 140, 235, 64, 244, 97, 250, 225, 185, 29, 98, 76, 8, 0, 69, 0, 0, 128, 81, 197, 0, 0, 64, 17, 165, 78, 192, 168, 1, 5, 192, 168, 1, 4, 74, 57, 20, 82, 0, 108, 39, 22, 38, 140, 4, 111, 143, 183, 249, 172, 140, 17, 90, 252, 24, 153, 45, 23, 130, 161, 238, 104, 89, 18, 12, 49, 241, 43, 100, 179, 102, 188, 140, 42, 221, 93, 185, 100, 143, 105, 135, 253, 204, 36, 247, 68, 5, 239, 57, 213, 97, 86, 73, 13, 73, 247, 250, 181, 202, 140, 158, 63, 190, 231, 49, 20, 242, 192, 121, 129, 5, 81, 253, 104, 171, 241, 45, 46, 189, 211, 37, 123, 31, 187, 181, 253, 60, 109, 192, 144, 230, 234, 108, 149, 104, 131, 163, 221, 165, 41, 249, 138, 0, 0, 0, 0, 0, 0, 0, 3, 0, 5, 28, 0, 0, 0, 0, 4, 0, 16, 0, 0, 0, 0, 0, 35, 2, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0, 0, 0, 0, 6, 0, 76, 128, 0, 0, 4, 0, 0, 0, 6, 128, 1, 0, 8, 2, 64, 0, 3, 0, 0, 0, 5, 128, 1, 3, 16, 0, 0, 0, 25, 0, 0, 0, 0, 255, 255, 255, 255, 0, 0, 0, 0, 128, 1, 4, 8, 0, 1, 0, 0, 0, 0, 0, 3, 128, 1, 7, 16, 0, 0, 0, 2, 0, 0, 0, 0, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 6, 1, 1, 0, 0}, + {6, 4, 0, 144, 0, 0, 0, 2, 0, 0, 35, 32, 0, 0, 0, 30, 0, 0, 0, 50, 1, 0, 94, 20, 50, 173, 34, 101, 235, 44, 251, 123, 8, 0, 70, 192, 0, 32, 0, 0, 64, 0, 1, 2, 15, 169, 192, 168, 0, 5, 225, 20, 50, 173, 148, 4, 0, 0, 18, 0, 218, 61, 225, 20, 50, 173, 0, 0, 0, 0, 0, 0, 0, 3, 0, 5, 33, 0, 0, 0, 0, 4, 0, 16, 0, 0, 0, 0, 0, 3, 5, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0, 0, 0, 0, 6, 0, 32, 128, 0, 0, 4, 0, 0, 0, 6, 128, 1, 1, 16, 0, 0, 0, 3, 0, 0, 0, 0, 255, 255, 255, 255, 0, 0, 0, 0, 0, 7, 0, 5, 3, 0, 0, 0}, + } + expectedMessages := make([]util.Message, 2) + for i := range msgBytes { + venderMsg := new(openflow15.VendorHeader) + err := venderMsg.UnmarshalBinary(msgBytes[i]) + assert.NoError(t, err) + expectedMessages[i] = venderMsg + } + countCh := make(chan int) + msgCount := 10000 + c := newFakeConn(msgCount, func() []byte { + count := <-countCh + if count < 25 { + return msgBytes[0] + } else { + return msgBytes[1] + } + }) + stream := util.NewMessageStream(c, parserIntf{}) + go func() { + _ = <-stream.Error + }() + + msgs := make([]util.Message, msgCount) + for i := 0; i < msgCount; i++ { + countCh <- i + msg := <-stream.Inbound + msgs[i] = msg + } + for i := range msgs { + if i < 25 { + assert.Equal(t, expectedMessages[0], msgs[i]) + } else { + assert.Equal(t, expectedMessages[1], msgs[i]) + } + } +}