@@ -92,36 +92,48 @@ func WriteTCPRequest(addr string, payload []byte) *buf.Buffer {
92
92
// Padding length (QUIC varint)
93
93
// Padding (bytes)
94
94
95
- func ReadTCPResponse (r io.Reader ) (bool , string , error ) {
95
+ func ReadTCPResponse (r io.Reader ) (ok bool , message string , err error ) {
96
96
var status [1 ]byte
97
- if _ , err := io .ReadFull (r , status [:]); err != nil {
98
- return false , "" , err
97
+ _ , err = io .ReadFull (r , status [:])
98
+ if err != nil {
99
+ return
99
100
}
101
+ ok = status [0 ] == 0
100
102
bReader := quicvarint .NewReader (r )
101
- msg , err := ReadVString (bReader )
103
+ messageLen , err := quicvarint .Read (bReader )
104
+ if err != nil {
105
+ return
106
+ }
107
+ if messageLen > MaxMessageLength {
108
+ return false , "" , E .New ("invalid message length" )
109
+ }
110
+ message , err = rw .ReadString (r , int (messageLen ))
102
111
if err != nil {
103
- return false , "" , err
112
+ return
104
113
}
105
114
paddingLen , err := quicvarint .Read (bReader )
106
115
if err != nil {
107
- return false , "" , err
116
+ return
108
117
}
109
118
if paddingLen > MaxPaddingLength {
110
119
return false , "" , E .New ("invalid padding length" )
111
120
}
112
121
if paddingLen > 0 {
113
122
_ , err = io .CopyN (io .Discard , r , int64 (paddingLen ))
114
123
if err != nil {
115
- return false , "" , err
124
+ return
116
125
}
117
126
}
118
- return status [ 0 ] == 0 , msg , nil
127
+ return
119
128
}
120
129
121
130
func WriteTCPResponse (ok bool , msg string , payload []byte ) * buf.Buffer {
122
131
padding := tcpResponsePadding .String ()
123
132
paddingLen := len (padding )
124
133
msgLen := len (msg )
134
+ if msgLen > MaxMessageLength {
135
+ msgLen = MaxMessageLength
136
+ }
125
137
sz := 1 + int (quicvarint .Len (uint64 (msgLen ))) + msgLen +
126
138
int (quicvarint .Len (uint64 (paddingLen ))) + paddingLen
127
139
buffer := buf .NewSize (sz + len (payload ))
@@ -198,7 +210,7 @@ func ParseUDPMessage(msg []byte) (*UDPMessage, error) {
198
210
if err != nil {
199
211
return nil , err
200
212
}
201
- if lAddr == 0 || lAddr > MaxMessageLength {
213
+ if lAddr == 0 || lAddr > MaxAddressLength {
202
214
return nil , E .New ("invalid address length" )
203
215
}
204
216
bs := buf .Bytes ()
@@ -212,6 +224,9 @@ func ReadVString(reader io.Reader) (string, error) {
212
224
if err != nil {
213
225
return "" , err
214
226
}
227
+ if length > MaxAddressLength {
228
+ return "" , E .New ("invalid address length" )
229
+ }
215
230
value , err := rw .ReadBytes (reader , int (length ))
216
231
if err != nil {
217
232
return "" , err
0 commit comments