Skip to content

Commit a1c8171

Browse files
committedAug 31, 2017
Add websocket
1 parent a58fae9 commit a1c8171

28 files changed

+4489
-0
lines changed
 

‎README.md

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ The library provides packages about network and multiple media processing:
2424
- [x] [flv](flv/example_test.go): The FLV muxer and demuxer, for oryx.
2525
- [x] [errors](errors/example_test.go): Fork from [pkg/errors](https://github.com/pkg/errors), a complex error with message and stack, read [article](https://gocn.io/article/348).
2626
- [x] [aac](aac/example_test.go): The AAC utilities to demux and mux AAC RAW data, for oryx.
27+
- [x] [websocket](https://golang.org/x/net/websocket): Fork from [websocket](https://github.com/gorilla/websocket/tree/v1.2.0).
2728
- [ ] [sip](sip/example_test.go): A [sip](https://en.wikipedia.org/wiki/Session_Initiation_Protocol) [RFC3261](https://www.ietf.org/rfc/rfc3261.txt) library, like [pjsip](http://pjsip.org/), [freeswitch](https://freeswitch.org/) and [kamailio](https://www.kamailio.org/).
2829
- [ ] [avc](avc/example_test.go): The AVC utilities to demux and mux AVC RAW data, for oryx.
2930
- [ ] [rtmp](rtmp/example_test.go): The RTMP protocol stack, for oryx.
@@ -47,5 +48,6 @@ while all the licenses are liberal:
4748
1. [acme](https/acme/LICENSE) uses [MIT License](https://github.com/xenolf/lego/blob/master/LICENSE).
4849
1. [jose](https/jose/LICENSE) uses [Apache License 2.0](https://github.com/square/go-jose/blob/v1.1.0/LICENSE).
4950
1. [letsencrypt](https/letsencrypt/LICENSE) uses [BSD 3-clause "New" or "Revised" License](https://github.com/rsc/letsencrypt/blob/master/LICENSE).
51+
1. [websocket](https://github.com/gorilla/websocket) uses [BSD 2-clause "Simplified" License](https://github.com/gorilla/websocket/blob/master/LICENSE).
5052

5153
Winlin 2016

‎websocket/AUTHORS

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# This is the official list of Gorilla WebSocket authors for copyright
2+
# purposes.
3+
#
4+
# Please keep the list sorted.
5+
6+
Gary Burd <gary@beagledreams.com>
7+
Joachim Bauch <mail@joachim-bauch.de>
8+

‎websocket/LICENSE

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved.
2+
3+
Redistribution and use in source and binary forms, with or without
4+
modification, are permitted provided that the following conditions are met:
5+
6+
Redistributions of source code must retain the above copyright notice, this
7+
list of conditions and the following disclaimer.
8+
9+
Redistributions in binary form must reproduce the above copyright notice,
10+
this list of conditions and the following disclaimer in the documentation
11+
and/or other materials provided with the distribution.
12+
13+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
14+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
15+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
16+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
17+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
18+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
19+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
20+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
21+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

‎websocket/client.go

+393
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,393 @@
1+
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// fork from https://github.com/gorilla/websocket
6+
package websocket
7+
8+
import (
9+
"bufio"
10+
"bytes"
11+
"crypto/tls"
12+
"encoding/base64"
13+
"errors"
14+
"io"
15+
"io/ioutil"
16+
"net"
17+
"net/http"
18+
"net/url"
19+
"strings"
20+
"time"
21+
)
22+
23+
// ErrBadHandshake is returned when the server response to opening handshake is
24+
// invalid.
25+
var ErrBadHandshake = errors.New("websocket: bad handshake")
26+
27+
var errInvalidCompression = errors.New("websocket: invalid compression negotiation")
28+
29+
// NewClient creates a new client connection using the given net connection.
30+
// The URL u specifies the host and request URI. Use requestHeader to specify
31+
// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
32+
// (Cookie). Use the response.Header to get the selected subprotocol
33+
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
34+
//
35+
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
36+
// non-nil *http.Response so that callers can handle redirects, authentication,
37+
// etc.
38+
//
39+
// Deprecated: Use Dialer instead.
40+
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
41+
d := Dialer{
42+
ReadBufferSize: readBufSize,
43+
WriteBufferSize: writeBufSize,
44+
NetDial: func(net, addr string) (net.Conn, error) {
45+
return netConn, nil
46+
},
47+
}
48+
return d.Dial(u.String(), requestHeader)
49+
}
50+
51+
// A Dialer contains options for connecting to WebSocket server.
52+
type Dialer struct {
53+
// NetDial specifies the dial function for creating TCP connections. If
54+
// NetDial is nil, net.Dial is used.
55+
NetDial func(network, addr string) (net.Conn, error)
56+
57+
// Proxy specifies a function to return a proxy for a given
58+
// Request. If the function returns a non-nil error, the
59+
// request is aborted with the provided error.
60+
// If Proxy is nil or returns a nil *URL, no proxy is used.
61+
Proxy func(*http.Request) (*url.URL, error)
62+
63+
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
64+
// If nil, the default configuration is used.
65+
TLSClientConfig *tls.Config
66+
67+
// HandshakeTimeout specifies the duration for the handshake to complete.
68+
HandshakeTimeout time.Duration
69+
70+
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
71+
// size is zero, then a useful default size is used. The I/O buffer sizes
72+
// do not limit the size of the messages that can be sent or received.
73+
ReadBufferSize, WriteBufferSize int
74+
75+
// Subprotocols specifies the client's requested subprotocols.
76+
Subprotocols []string
77+
78+
// EnableCompression specifies if the client should attempt to negotiate
79+
// per message compression (RFC 7692). Setting this value to true does not
80+
// guarantee that compression will be supported. Currently only "no context
81+
// takeover" modes are supported.
82+
EnableCompression bool
83+
84+
// Jar specifies the cookie jar.
85+
// If Jar is nil, cookies are not sent in requests and ignored
86+
// in responses.
87+
Jar http.CookieJar
88+
}
89+
90+
var errMalformedURL = errors.New("malformed ws or wss URL")
91+
92+
// parseURL parses the URL.
93+
//
94+
// This function is a replacement for the standard library url.Parse function.
95+
// In Go 1.4 and earlier, url.Parse loses information from the path.
96+
func parseURL(s string) (*url.URL, error) {
97+
// From the RFC:
98+
//
99+
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
100+
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
101+
var u url.URL
102+
switch {
103+
case strings.HasPrefix(s, "ws://"):
104+
u.Scheme = "ws"
105+
s = s[len("ws://"):]
106+
case strings.HasPrefix(s, "wss://"):
107+
u.Scheme = "wss"
108+
s = s[len("wss://"):]
109+
default:
110+
return nil, errMalformedURL
111+
}
112+
113+
if i := strings.Index(s, "?"); i >= 0 {
114+
u.RawQuery = s[i+1:]
115+
s = s[:i]
116+
}
117+
118+
if i := strings.Index(s, "/"); i >= 0 {
119+
u.Opaque = s[i:]
120+
s = s[:i]
121+
} else {
122+
u.Opaque = "/"
123+
}
124+
125+
u.Host = s
126+
127+
if strings.Contains(u.Host, "@") {
128+
// Don't bother parsing user information because user information is
129+
// not allowed in websocket URIs.
130+
return nil, errMalformedURL
131+
}
132+
133+
return &u, nil
134+
}
135+
136+
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
137+
hostPort = u.Host
138+
hostNoPort = u.Host
139+
if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
140+
hostNoPort = hostNoPort[:i]
141+
} else {
142+
switch u.Scheme {
143+
case "wss":
144+
hostPort += ":443"
145+
case "https":
146+
hostPort += ":443"
147+
default:
148+
hostPort += ":80"
149+
}
150+
}
151+
return hostPort, hostNoPort
152+
}
153+
154+
// DefaultDialer is a dialer with all fields set to the default zero values.
155+
var DefaultDialer = &Dialer{
156+
Proxy: http.ProxyFromEnvironment,
157+
}
158+
159+
// Dial creates a new client connection. Use requestHeader to specify the
160+
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
161+
// Use the response.Header to get the selected subprotocol
162+
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
163+
//
164+
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
165+
// non-nil *http.Response so that callers can handle redirects, authentication,
166+
// etcetera. The response body may not contain the entire response and does not
167+
// need to be closed by the application.
168+
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
169+
170+
if d == nil {
171+
d = &Dialer{
172+
Proxy: http.ProxyFromEnvironment,
173+
}
174+
}
175+
176+
challengeKey, err := generateChallengeKey()
177+
if err != nil {
178+
return nil, nil, err
179+
}
180+
181+
u, err := parseURL(urlStr)
182+
if err != nil {
183+
return nil, nil, err
184+
}
185+
186+
switch u.Scheme {
187+
case "ws":
188+
u.Scheme = "http"
189+
case "wss":
190+
u.Scheme = "https"
191+
default:
192+
return nil, nil, errMalformedURL
193+
}
194+
195+
if u.User != nil {
196+
// User name and password are not allowed in websocket URIs.
197+
return nil, nil, errMalformedURL
198+
}
199+
200+
req := &http.Request{
201+
Method: "GET",
202+
URL: u,
203+
Proto: "HTTP/1.1",
204+
ProtoMajor: 1,
205+
ProtoMinor: 1,
206+
Header: make(http.Header),
207+
Host: u.Host,
208+
}
209+
210+
// Set the cookies present in the cookie jar of the dialer
211+
if d.Jar != nil {
212+
for _, cookie := range d.Jar.Cookies(u) {
213+
req.AddCookie(cookie)
214+
}
215+
}
216+
217+
// Set the request headers using the capitalization for names and values in
218+
// RFC examples. Although the capitalization shouldn't matter, there are
219+
// servers that depend on it. The Header.Set method is not used because the
220+
// method canonicalizes the header names.
221+
req.Header["Upgrade"] = []string{"websocket"}
222+
req.Header["Connection"] = []string{"Upgrade"}
223+
req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
224+
req.Header["Sec-WebSocket-Version"] = []string{"13"}
225+
if len(d.Subprotocols) > 0 {
226+
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
227+
}
228+
for k, vs := range requestHeader {
229+
switch {
230+
case k == "Host":
231+
if len(vs) > 0 {
232+
req.Host = vs[0]
233+
}
234+
case k == "Upgrade" ||
235+
k == "Connection" ||
236+
k == "Sec-Websocket-Key" ||
237+
k == "Sec-Websocket-Version" ||
238+
k == "Sec-Websocket-Extensions" ||
239+
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
240+
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
241+
default:
242+
req.Header[k] = vs
243+
}
244+
}
245+
246+
if d.EnableCompression {
247+
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
248+
}
249+
250+
hostPort, hostNoPort := hostPortNoPort(u)
251+
252+
var proxyURL *url.URL
253+
// Check wether the proxy method has been configured
254+
if d.Proxy != nil {
255+
proxyURL, err = d.Proxy(req)
256+
}
257+
if err != nil {
258+
return nil, nil, err
259+
}
260+
261+
var targetHostPort string
262+
if proxyURL != nil {
263+
targetHostPort, _ = hostPortNoPort(proxyURL)
264+
} else {
265+
targetHostPort = hostPort
266+
}
267+
268+
var deadline time.Time
269+
if d.HandshakeTimeout != 0 {
270+
deadline = time.Now().Add(d.HandshakeTimeout)
271+
}
272+
273+
netDial := d.NetDial
274+
if netDial == nil {
275+
netDialer := &net.Dialer{Deadline: deadline}
276+
netDial = netDialer.Dial
277+
}
278+
279+
netConn, err := netDial("tcp", targetHostPort)
280+
if err != nil {
281+
return nil, nil, err
282+
}
283+
284+
defer func() {
285+
if netConn != nil {
286+
netConn.Close()
287+
}
288+
}()
289+
290+
if err := netConn.SetDeadline(deadline); err != nil {
291+
return nil, nil, err
292+
}
293+
294+
if proxyURL != nil {
295+
connectHeader := make(http.Header)
296+
if user := proxyURL.User; user != nil {
297+
proxyUser := user.Username()
298+
if proxyPassword, passwordSet := user.Password(); passwordSet {
299+
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
300+
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
301+
}
302+
}
303+
connectReq := &http.Request{
304+
Method: "CONNECT",
305+
URL: &url.URL{Opaque: hostPort},
306+
Host: hostPort,
307+
Header: connectHeader,
308+
}
309+
310+
connectReq.Write(netConn)
311+
312+
// Read response.
313+
// Okay to use and discard buffered reader here, because
314+
// TLS server will not speak until spoken to.
315+
br := bufio.NewReader(netConn)
316+
resp, err := http.ReadResponse(br, connectReq)
317+
if err != nil {
318+
return nil, nil, err
319+
}
320+
if resp.StatusCode != 200 {
321+
f := strings.SplitN(resp.Status, " ", 2)
322+
return nil, nil, errors.New(f[1])
323+
}
324+
}
325+
326+
if u.Scheme == "https" {
327+
cfg := cloneTLSConfig(d.TLSClientConfig)
328+
if cfg.ServerName == "" {
329+
cfg.ServerName = hostNoPort
330+
}
331+
tlsConn := tls.Client(netConn, cfg)
332+
netConn = tlsConn
333+
if err := tlsConn.Handshake(); err != nil {
334+
return nil, nil, err
335+
}
336+
if !cfg.InsecureSkipVerify {
337+
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
338+
return nil, nil, err
339+
}
340+
}
341+
}
342+
343+
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
344+
345+
if err := req.Write(netConn); err != nil {
346+
return nil, nil, err
347+
}
348+
349+
resp, err := http.ReadResponse(conn.br, req)
350+
if err != nil {
351+
return nil, nil, err
352+
}
353+
354+
if d.Jar != nil {
355+
if rc := resp.Cookies(); len(rc) > 0 {
356+
d.Jar.SetCookies(u, rc)
357+
}
358+
}
359+
360+
if resp.StatusCode != 101 ||
361+
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
362+
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
363+
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
364+
// Before closing the network connection on return from this
365+
// function, slurp up some of the response to aid application
366+
// debugging.
367+
buf := make([]byte, 1024)
368+
n, _ := io.ReadFull(resp.Body, buf)
369+
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
370+
return nil, resp, ErrBadHandshake
371+
}
372+
373+
for _, ext := range parseExtensions(resp.Header) {
374+
if ext[""] != "permessage-deflate" {
375+
continue
376+
}
377+
_, snct := ext["server_no_context_takeover"]
378+
_, cnct := ext["client_no_context_takeover"]
379+
if !snct || !cnct {
380+
return nil, resp, errInvalidCompression
381+
}
382+
conn.newCompressionWriter = compressNoContextTakeover
383+
conn.newDecompressionReader = decompressNoContextTakeover
384+
break
385+
}
386+
387+
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
388+
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
389+
390+
netConn.SetDeadline(time.Time{})
391+
netConn = nil // to avoid close in defer.
392+
return conn, resp, nil
393+
}

‎websocket/client_clone.go

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// +build go1.8
6+
7+
// fork from https://github.com/gorilla/websocket
8+
package websocket
9+
10+
import "crypto/tls"
11+
12+
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
13+
if cfg == nil {
14+
return &tls.Config{}
15+
}
16+
return cfg.Clone()
17+
}

‎websocket/client_clone_legacy.go

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// +build !go1.8
6+
7+
// fork from https://github.com/gorilla/websocket
8+
package websocket
9+
10+
import "crypto/tls"
11+
12+
// cloneTLSConfig clones all public fields except the fields
13+
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
14+
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
15+
// config in active use.
16+
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
17+
if cfg == nil {
18+
return &tls.Config{}
19+
}
20+
return &tls.Config{
21+
Rand: cfg.Rand,
22+
Time: cfg.Time,
23+
Certificates: cfg.Certificates,
24+
NameToCertificate: cfg.NameToCertificate,
25+
GetCertificate: cfg.GetCertificate,
26+
RootCAs: cfg.RootCAs,
27+
NextProtos: cfg.NextProtos,
28+
ServerName: cfg.ServerName,
29+
ClientAuth: cfg.ClientAuth,
30+
ClientCAs: cfg.ClientCAs,
31+
InsecureSkipVerify: cfg.InsecureSkipVerify,
32+
CipherSuites: cfg.CipherSuites,
33+
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
34+
ClientSessionCache: cfg.ClientSessionCache,
35+
MinVersion: cfg.MinVersion,
36+
MaxVersion: cfg.MaxVersion,
37+
CurvePreferences: cfg.CurvePreferences,
38+
}
39+
}

‎websocket/client_server_test.go

+513
Large diffs are not rendered by default.

‎websocket/client_test.go

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// fork from https://github.com/gorilla/websocket
6+
package websocket
7+
8+
import (
9+
"net/url"
10+
"reflect"
11+
"testing"
12+
)
13+
14+
var parseURLTests = []struct {
15+
s string
16+
u *url.URL
17+
rui string
18+
}{
19+
{"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"},
20+
{"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"},
21+
{"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}, "/"},
22+
{"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}, "/"},
23+
{"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}, "/a/b"},
24+
{"ss://example.com/a/b", nil, ""},
25+
{"ws://webmaster@example.com/", nil, ""},
26+
{"wss://example.com/a/b?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b", RawQuery: "x=y"}, "/a/b?x=y"},
27+
{"wss://example.com?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/", RawQuery: "x=y"}, "/?x=y"},
28+
}
29+
30+
func TestParseURL(t *testing.T) {
31+
for _, tt := range parseURLTests {
32+
u, err := parseURL(tt.s)
33+
if tt.u != nil && err != nil {
34+
t.Errorf("parseURL(%q) returned error %v", tt.s, err)
35+
continue
36+
}
37+
if tt.u == nil {
38+
if err == nil {
39+
t.Errorf("parseURL(%q) did not return error", tt.s)
40+
}
41+
continue
42+
}
43+
if !reflect.DeepEqual(u, tt.u) {
44+
t.Errorf("parseURL(%q) = %v, want %v", tt.s, u, tt.u)
45+
continue
46+
}
47+
if u.RequestURI() != tt.rui {
48+
t.Errorf("parseURL(%q).RequestURI() = %v, want %v", tt.s, u.RequestURI(), tt.rui)
49+
}
50+
}
51+
}
52+
53+
var hostPortNoPortTests = []struct {
54+
u *url.URL
55+
hostPort, hostNoPort string
56+
}{
57+
{&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"},
58+
{&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"},
59+
{&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"},
60+
{&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"},
61+
}
62+
63+
func TestHostPortNoPort(t *testing.T) {
64+
for _, tt := range hostPortNoPortTests {
65+
hostPort, hostNoPort := hostPortNoPort(tt.u)
66+
if hostPort != tt.hostPort {
67+
t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort)
68+
}
69+
if hostNoPort != tt.hostNoPort {
70+
t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort)
71+
}
72+
}
73+
}

‎websocket/compression.go

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// fork from https://github.com/gorilla/websocket
6+
package websocket
7+
8+
import (
9+
"compress/flate"
10+
"errors"
11+
"io"
12+
"strings"
13+
"sync"
14+
)
15+
16+
const (
17+
minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6
18+
maxCompressionLevel = flate.BestCompression
19+
defaultCompressionLevel = 1
20+
)
21+
22+
var (
23+
flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
24+
flateReaderPool = sync.Pool{New: func() interface{} {
25+
return flate.NewReader(nil)
26+
}}
27+
)
28+
29+
func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
30+
const tail =
31+
// Add four bytes as specified in RFC
32+
"\x00\x00\xff\xff" +
33+
// Add final block to squelch unexpected EOF error from flate reader.
34+
"\x01\x00\x00\xff\xff"
35+
36+
fr, _ := flateReaderPool.Get().(io.ReadCloser)
37+
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
38+
return &flateReadWrapper{fr}
39+
}
40+
41+
func isValidCompressionLevel(level int) bool {
42+
return minCompressionLevel <= level && level <= maxCompressionLevel
43+
}
44+
45+
func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
46+
p := &flateWriterPools[level-minCompressionLevel]
47+
tw := &truncWriter{w: w}
48+
fw, _ := p.Get().(*flate.Writer)
49+
if fw == nil {
50+
fw, _ = flate.NewWriter(tw, level)
51+
} else {
52+
fw.Reset(tw)
53+
}
54+
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
55+
}
56+
57+
// truncWriter is an io.Writer that writes all but the last four bytes of the
58+
// stream to another io.Writer.
59+
type truncWriter struct {
60+
w io.WriteCloser
61+
n int
62+
p [4]byte
63+
}
64+
65+
func (w *truncWriter) Write(p []byte) (int, error) {
66+
n := 0
67+
68+
// fill buffer first for simplicity.
69+
if w.n < len(w.p) {
70+
n = copy(w.p[w.n:], p)
71+
p = p[n:]
72+
w.n += n
73+
if len(p) == 0 {
74+
return n, nil
75+
}
76+
}
77+
78+
m := len(p)
79+
if m > len(w.p) {
80+
m = len(w.p)
81+
}
82+
83+
if nn, err := w.w.Write(w.p[:m]); err != nil {
84+
return n + nn, err
85+
}
86+
87+
copy(w.p[:], w.p[m:])
88+
copy(w.p[len(w.p)-m:], p[len(p)-m:])
89+
nn, err := w.w.Write(p[:len(p)-m])
90+
return n + nn, err
91+
}
92+
93+
type flateWriteWrapper struct {
94+
fw *flate.Writer
95+
tw *truncWriter
96+
p *sync.Pool
97+
}
98+
99+
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
100+
if w.fw == nil {
101+
return 0, errWriteClosed
102+
}
103+
return w.fw.Write(p)
104+
}
105+
106+
func (w *flateWriteWrapper) Close() error {
107+
if w.fw == nil {
108+
return errWriteClosed
109+
}
110+
err1 := w.fw.Flush()
111+
w.p.Put(w.fw)
112+
w.fw = nil
113+
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
114+
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
115+
}
116+
err2 := w.tw.w.Close()
117+
if err1 != nil {
118+
return err1
119+
}
120+
return err2
121+
}
122+
123+
type flateReadWrapper struct {
124+
fr io.ReadCloser
125+
}
126+
127+
func (r *flateReadWrapper) Read(p []byte) (int, error) {
128+
if r.fr == nil {
129+
return 0, io.ErrClosedPipe
130+
}
131+
n, err := r.fr.Read(p)
132+
if err == io.EOF {
133+
// Preemptively place the reader back in the pool. This helps with
134+
// scenarios where the application does not call NextReader() soon after
135+
// this final read.
136+
r.Close()
137+
}
138+
return n, err
139+
}
140+
141+
func (r *flateReadWrapper) Close() error {
142+
if r.fr == nil {
143+
return io.ErrClosedPipe
144+
}
145+
err := r.fr.Close()
146+
flateReaderPool.Put(r.fr)
147+
r.fr = nil
148+
return err
149+
}

‎websocket/compression_test.go

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// fork from https://github.com/gorilla/websocket
2+
package websocket
3+
4+
import (
5+
"bytes"
6+
"fmt"
7+
"io"
8+
"io/ioutil"
9+
"testing"
10+
)
11+
12+
type nopCloser struct{ io.Writer }
13+
14+
func (nopCloser) Close() error { return nil }
15+
16+
func TestTruncWriter(t *testing.T) {
17+
const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
18+
for n := 1; n <= 10; n++ {
19+
var b bytes.Buffer
20+
w := &truncWriter{w: nopCloser{&b}}
21+
p := []byte(data)
22+
for len(p) > 0 {
23+
m := len(p)
24+
if m > n {
25+
m = n
26+
}
27+
w.Write(p[:m])
28+
p = p[m:]
29+
}
30+
if b.String() != data[:len(data)-len(w.p)] {
31+
t.Errorf("%d: %q", n, b.String())
32+
}
33+
}
34+
}
35+
36+
func textMessages(num int) [][]byte {
37+
messages := make([][]byte, num)
38+
for i := 0; i < num; i++ {
39+
msg := fmt.Sprintf("planet: %d, country: %d, city: %d, street: %d", i, i, i, i)
40+
messages[i] = []byte(msg)
41+
}
42+
return messages
43+
}
44+
45+
func BenchmarkWriteNoCompression(b *testing.B) {
46+
w := ioutil.Discard
47+
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
48+
messages := textMessages(100)
49+
b.ResetTimer()
50+
for i := 0; i < b.N; i++ {
51+
c.WriteMessage(TextMessage, messages[i%len(messages)])
52+
}
53+
b.ReportAllocs()
54+
}
55+
56+
func BenchmarkWriteWithCompression(b *testing.B) {
57+
w := ioutil.Discard
58+
c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
59+
messages := textMessages(100)
60+
c.enableWriteCompression = true
61+
c.newCompressionWriter = compressNoContextTakeover
62+
b.ResetTimer()
63+
for i := 0; i < b.N; i++ {
64+
c.WriteMessage(TextMessage, messages[i%len(messages)])
65+
}
66+
b.ReportAllocs()
67+
}
68+
69+
func TestValidCompressionLevel(t *testing.T) {
70+
c := newConn(fakeNetConn{}, false, 1024, 1024)
71+
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
72+
if err := c.SetCompressionLevel(level); err == nil {
73+
t.Errorf("no error for level %d", level)
74+
}
75+
}
76+
for _, level := range []int{minCompressionLevel, maxCompressionLevel} {
77+
if err := c.SetCompressionLevel(level); err != nil {
78+
t.Errorf("error for level %d", level)
79+
}
80+
}
81+
}

‎websocket/conn.go

+1,150
Large diffs are not rendered by default.

‎websocket/conn_broadcast_test.go

+135
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// +build go1.7
6+
7+
// fork from https://github.com/gorilla/websocket
8+
package websocket
9+
10+
import (
11+
"io"
12+
"io/ioutil"
13+
"sync/atomic"
14+
"testing"
15+
)
16+
17+
// broadcastBench allows to run broadcast benchmarks.
18+
// In every broadcast benchmark we create many connections, then send the same
19+
// message into every connection and wait for all writes complete. This emulates
20+
// an application where many connections listen to the same data - i.e. PUB/SUB
21+
// scenarios with many subscribers in one channel.
22+
type broadcastBench struct {
23+
w io.Writer
24+
message *broadcastMessage
25+
closeCh chan struct{}
26+
doneCh chan struct{}
27+
count int32
28+
conns []*broadcastConn
29+
compression bool
30+
usePrepared bool
31+
}
32+
33+
type broadcastMessage struct {
34+
payload []byte
35+
prepared *PreparedMessage
36+
}
37+
38+
type broadcastConn struct {
39+
conn *Conn
40+
msgCh chan *broadcastMessage
41+
}
42+
43+
func newBroadcastConn(c *Conn) *broadcastConn {
44+
return &broadcastConn{
45+
conn: c,
46+
msgCh: make(chan *broadcastMessage, 1),
47+
}
48+
}
49+
50+
func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
51+
bench := &broadcastBench{
52+
w: ioutil.Discard,
53+
doneCh: make(chan struct{}),
54+
closeCh: make(chan struct{}),
55+
usePrepared: usePrepared,
56+
compression: compression,
57+
}
58+
msg := &broadcastMessage{
59+
payload: textMessages(1)[0],
60+
}
61+
if usePrepared {
62+
pm, _ := NewPreparedMessage(TextMessage, msg.payload)
63+
msg.prepared = pm
64+
}
65+
bench.message = msg
66+
bench.makeConns(10000)
67+
return bench
68+
}
69+
70+
func (b *broadcastBench) makeConns(numConns int) {
71+
conns := make([]*broadcastConn, numConns)
72+
73+
for i := 0; i < numConns; i++ {
74+
c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024)
75+
if b.compression {
76+
c.enableWriteCompression = true
77+
c.newCompressionWriter = compressNoContextTakeover
78+
}
79+
conns[i] = newBroadcastConn(c)
80+
go func(c *broadcastConn) {
81+
for {
82+
select {
83+
case msg := <-c.msgCh:
84+
if b.usePrepared {
85+
c.conn.WritePreparedMessage(msg.prepared)
86+
} else {
87+
c.conn.WriteMessage(TextMessage, msg.payload)
88+
}
89+
val := atomic.AddInt32(&b.count, 1)
90+
if val%int32(numConns) == 0 {
91+
b.doneCh <- struct{}{}
92+
}
93+
case <-b.closeCh:
94+
return
95+
}
96+
}
97+
}(conns[i])
98+
}
99+
b.conns = conns
100+
}
101+
102+
func (b *broadcastBench) close() {
103+
close(b.closeCh)
104+
}
105+
106+
func (b *broadcastBench) runOnce() {
107+
for _, c := range b.conns {
108+
c.msgCh <- b.message
109+
}
110+
<-b.doneCh
111+
}
112+
113+
func BenchmarkBroadcast(b *testing.B) {
114+
benchmarks := []struct {
115+
name string
116+
usePrepared bool
117+
compression bool
118+
}{
119+
{"NoCompression", false, false},
120+
{"WithCompression", false, true},
121+
{"NoCompressionPrepared", true, false},
122+
{"WithCompressionPrepared", true, true},
123+
}
124+
for _, bm := range benchmarks {
125+
b.Run(bm.name, func(b *testing.B) {
126+
bench := newBroadcastBench(bm.usePrepared, bm.compression)
127+
defer bench.close()
128+
b.ResetTimer()
129+
for i := 0; i < b.N; i++ {
130+
bench.runOnce()
131+
}
132+
b.ReportAllocs()
133+
})
134+
}
135+
}

‎websocket/conn_read.go

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// +build go1.5
6+
7+
// fork from https://github.com/gorilla/websocket
8+
package websocket
9+
10+
import "io"
11+
12+
func (c *Conn) read(n int) ([]byte, error) {
13+
p, err := c.br.Peek(n)
14+
if err == io.EOF {
15+
err = errUnexpectedEOF
16+
}
17+
c.br.Discard(len(p))
18+
return p, err
19+
}

‎websocket/conn_read_legacy.go

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// +build !go1.5
6+
7+
// fork from https://github.com/gorilla/websocket
8+
package websocket
9+
10+
import "io"
11+
12+
func (c *Conn) read(n int) ([]byte, error) {
13+
p, err := c.br.Peek(n)
14+
if err == io.EOF {
15+
err = errUnexpectedEOF
16+
}
17+
if len(p) > 0 {
18+
// advance over the bytes just read
19+
io.ReadFull(c.br, p)
20+
}
21+
return p, err
22+
}

‎websocket/conn_test.go

+498
Large diffs are not rendered by default.

‎websocket/doc.go

+180
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// Package websocket implements the WebSocket protocol defined in RFC 6455.
6+
//
7+
// Overview
8+
//
9+
// The Conn type represents a WebSocket connection. A server application calls
10+
// the Upgrader.Upgrade method from an HTTP request handler to get a *Conn:
11+
//
12+
// var upgrader = websocket.Upgrader{
13+
// ReadBufferSize: 1024,
14+
// WriteBufferSize: 1024,
15+
// }
16+
//
17+
// func handler(w http.ResponseWriter, r *http.Request) {
18+
// conn, err := upgrader.Upgrade(w, r, nil)
19+
// if err != nil {
20+
// log.Println(err)
21+
// return
22+
// }
23+
// ... Use conn to send and receive messages.
24+
// }
25+
//
26+
// Call the connection's WriteMessage and ReadMessage methods to send and
27+
// receive messages as a slice of bytes. This snippet of code shows how to echo
28+
// messages using these methods:
29+
//
30+
// for {
31+
// messageType, p, err := conn.ReadMessage()
32+
// if err != nil {
33+
// return
34+
// }
35+
// if err := conn.WriteMessage(messageType, p); err != nil {
36+
// return err
37+
// }
38+
// }
39+
//
40+
// In above snippet of code, p is a []byte and messageType is an int with value
41+
// websocket.BinaryMessage or websocket.TextMessage.
42+
//
43+
// An application can also send and receive messages using the io.WriteCloser
44+
// and io.Reader interfaces. To send a message, call the connection NextWriter
45+
// method to get an io.WriteCloser, write the message to the writer and close
46+
// the writer when done. To receive a message, call the connection NextReader
47+
// method to get an io.Reader and read until io.EOF is returned. This snippet
48+
// shows how to echo messages using the NextWriter and NextReader methods:
49+
//
50+
// for {
51+
// messageType, r, err := conn.NextReader()
52+
// if err != nil {
53+
// return
54+
// }
55+
// w, err := conn.NextWriter(messageType)
56+
// if err != nil {
57+
// return err
58+
// }
59+
// if _, err := io.Copy(w, r); err != nil {
60+
// return err
61+
// }
62+
// if err := w.Close(); err != nil {
63+
// return err
64+
// }
65+
// }
66+
//
67+
// Data Messages
68+
//
69+
// The WebSocket protocol distinguishes between text and binary data messages.
70+
// Text messages are interpreted as UTF-8 encoded text. The interpretation of
71+
// binary messages is left to the application.
72+
//
73+
// This package uses the TextMessage and BinaryMessage integer constants to
74+
// identify the two data message types. The ReadMessage and NextReader methods
75+
// return the type of the received message. The messageType argument to the
76+
// WriteMessage and NextWriter methods specifies the type of a sent message.
77+
//
78+
// It is the application's responsibility to ensure that text messages are
79+
// valid UTF-8 encoded text.
80+
//
81+
// Control Messages
82+
//
83+
// The WebSocket protocol defines three types of control messages: close, ping
84+
// and pong. Call the connection WriteControl, WriteMessage or NextWriter
85+
// methods to send a control message to the peer.
86+
//
87+
// Connections handle received close messages by sending a close message to the
88+
// peer and returning a *CloseError from the the NextReader, ReadMessage or the
89+
// message Read method.
90+
//
91+
// Connections handle received ping and pong messages by invoking callback
92+
// functions set with SetPingHandler and SetPongHandler methods. The callback
93+
// functions are called from the NextReader, ReadMessage and the message Read
94+
// methods.
95+
//
96+
// The default ping handler sends a pong to the peer. The application's reading
97+
// goroutine can block for a short time while the handler writes the pong data
98+
// to the connection.
99+
//
100+
// The application must read the connection to process ping, pong and close
101+
// messages sent from the peer. If the application is not otherwise interested
102+
// in messages from the peer, then the application should start a goroutine to
103+
// read and discard messages from the peer. A simple example is:
104+
//
105+
// func readLoop(c *websocket.Conn) {
106+
// for {
107+
// if _, _, err := c.NextReader(); err != nil {
108+
// c.Close()
109+
// break
110+
// }
111+
// }
112+
// }
113+
//
114+
// Concurrency
115+
//
116+
// Connections support one concurrent reader and one concurrent writer.
117+
//
118+
// Applications are responsible for ensuring that no more than one goroutine
119+
// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage,
120+
// WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and
121+
// that no more than one goroutine calls the read methods (NextReader,
122+
// SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler)
123+
// concurrently.
124+
//
125+
// The Close and WriteControl methods can be called concurrently with all other
126+
// methods.
127+
//
128+
// Origin Considerations
129+
//
130+
// Web browsers allow Javascript applications to open a WebSocket connection to
131+
// any host. It's up to the server to enforce an origin policy using the Origin
132+
// request header sent by the browser.
133+
//
134+
// The Upgrader calls the function specified in the CheckOrigin field to check
135+
// the origin. If the CheckOrigin function returns false, then the Upgrade
136+
// method fails the WebSocket handshake with HTTP status 403.
137+
//
138+
// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail
139+
// the handshake if the Origin request header is present and not equal to the
140+
// Host request header.
141+
//
142+
// An application can allow connections from any origin by specifying a
143+
// function that always returns true:
144+
//
145+
// var upgrader = websocket.Upgrader{
146+
// CheckOrigin: func(r *http.Request) bool { return true },
147+
// }
148+
//
149+
// The deprecated package-level Upgrade function does not perform origin
150+
// checking. The application is responsible for checking the Origin header
151+
// before calling the Upgrade function.
152+
//
153+
// Compression EXPERIMENTAL
154+
//
155+
// Per message compression extensions (RFC 7692) are experimentally supported
156+
// by this package in a limited capacity. Setting the EnableCompression option
157+
// to true in Dialer or Upgrader will attempt to negotiate per message deflate
158+
// support.
159+
//
160+
// var upgrader = websocket.Upgrader{
161+
// EnableCompression: true,
162+
// }
163+
//
164+
// If compression was successfully negotiated with the connection's peer, any
165+
// message received in compressed form will be automatically decompressed.
166+
// All Read methods will return uncompressed bytes.
167+
//
168+
// Per message compression of messages written to a connection can be enabled
169+
// or disabled by calling the corresponding Conn method:
170+
//
171+
// conn.EnableWriteCompression(false)
172+
//
173+
// Currently this package does not support compression with "context takeover".
174+
// This means that messages must be compressed and decompressed in isolation,
175+
// without retaining sliding window or dictionary state across messages. For
176+
// more details refer to RFC 7692.
177+
//
178+
// Use of compression is experimental and may result in decreased performance.
179+
// fork from https://github.com/gorilla/websocket
180+
package websocket

‎websocket/example_test.go

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright 2015 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// fork from https://github.com/gorilla/websocket
6+
package websocket_test
7+
8+
import (
9+
"log"
10+
"net/http"
11+
"testing"
12+
13+
"github.com/ossrs/go-oryx-lib/websocket"
14+
)
15+
16+
var (
17+
c *websocket.Conn
18+
req *http.Request
19+
)
20+
21+
// The websocket.IsUnexpectedCloseError function is useful for identifying
22+
// application and protocol errors.
23+
//
24+
// This server application works with a client application running in the
25+
// browser. The client application does not explicitly close the websocket. The
26+
// only expected close message from the client has the code
27+
// websocket.CloseGoingAway. All other other close messages are likely the
28+
// result of an application or protocol error and are logged to aid debugging.
29+
func ExampleIsUnexpectedCloseError() {
30+
31+
for {
32+
messageType, p, err := c.ReadMessage()
33+
if err != nil {
34+
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
35+
log.Printf("error: %v, user-agent: %v", err, req.Header.Get("User-Agent"))
36+
}
37+
return
38+
}
39+
processMesage(messageType, p)
40+
}
41+
}
42+
43+
func processMesage(mt int, p []byte) {}
44+
45+
// TestX prevents godoc from showing this entire file in the example. Remove
46+
// this function when a second example is added.
47+
func TestX(t *testing.T) {}

‎websocket/json.go

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// fork from https://github.com/gorilla/websocket
6+
package websocket
7+
8+
import (
9+
"encoding/json"
10+
"io"
11+
)
12+
13+
// WriteJSON writes the JSON encoding of v as a message.
14+
//
15+
// Deprecated: Use c.WriteJSON instead.
16+
func WriteJSON(c *Conn, v interface{}) error {
17+
return c.WriteJSON(v)
18+
}
19+
20+
// WriteJSON writes the JSON encoding of v as a message.
21+
//
22+
// See the documentation for encoding/json Marshal for details about the
23+
// conversion of Go values to JSON.
24+
func (c *Conn) WriteJSON(v interface{}) error {
25+
w, err := c.NextWriter(TextMessage)
26+
if err != nil {
27+
return err
28+
}
29+
err1 := json.NewEncoder(w).Encode(v)
30+
err2 := w.Close()
31+
if err1 != nil {
32+
return err1
33+
}
34+
return err2
35+
}
36+
37+
// ReadJSON reads the next JSON-encoded message from the connection and stores
38+
// it in the value pointed to by v.
39+
//
40+
// Deprecated: Use c.ReadJSON instead.
41+
func ReadJSON(c *Conn, v interface{}) error {
42+
return c.ReadJSON(v)
43+
}
44+
45+
// ReadJSON reads the next JSON-encoded message from the connection and stores
46+
// it in the value pointed to by v.
47+
//
48+
// See the documentation for the encoding/json Unmarshal function for details
49+
// about the conversion of JSON to a Go value.
50+
func (c *Conn) ReadJSON(v interface{}) error {
51+
_, r, err := c.NextReader()
52+
if err != nil {
53+
return err
54+
}
55+
err = json.NewDecoder(r).Decode(v)
56+
if err == io.EOF {
57+
// One value is expected in the message.
58+
err = io.ErrUnexpectedEOF
59+
}
60+
return err
61+
}

‎websocket/json_test.go

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// fork from https://github.com/gorilla/websocket
6+
package websocket
7+
8+
import (
9+
"bytes"
10+
"encoding/json"
11+
"io"
12+
"reflect"
13+
"testing"
14+
)
15+
16+
func TestJSON(t *testing.T) {
17+
var buf bytes.Buffer
18+
c := fakeNetConn{&buf, &buf}
19+
wc := newConn(c, true, 1024, 1024)
20+
rc := newConn(c, false, 1024, 1024)
21+
22+
var actual, expect struct {
23+
A int
24+
B string
25+
}
26+
expect.A = 1
27+
expect.B = "hello"
28+
29+
if err := wc.WriteJSON(&expect); err != nil {
30+
t.Fatal("write", err)
31+
}
32+
33+
if err := rc.ReadJSON(&actual); err != nil {
34+
t.Fatal("read", err)
35+
}
36+
37+
if !reflect.DeepEqual(&actual, &expect) {
38+
t.Fatal("equal", actual, expect)
39+
}
40+
}
41+
42+
func TestPartialJSONRead(t *testing.T) {
43+
var buf bytes.Buffer
44+
c := fakeNetConn{&buf, &buf}
45+
wc := newConn(c, true, 1024, 1024)
46+
rc := newConn(c, false, 1024, 1024)
47+
48+
var v struct {
49+
A int
50+
B string
51+
}
52+
v.A = 1
53+
v.B = "hello"
54+
55+
messageCount := 0
56+
57+
// Partial JSON values.
58+
59+
data, err := json.Marshal(v)
60+
if err != nil {
61+
t.Fatal(err)
62+
}
63+
for i := len(data) - 1; i >= 0; i-- {
64+
if err := wc.WriteMessage(TextMessage, data[:i]); err != nil {
65+
t.Fatal(err)
66+
}
67+
messageCount++
68+
}
69+
70+
// Whitespace.
71+
72+
if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil {
73+
t.Fatal(err)
74+
}
75+
messageCount++
76+
77+
// Close.
78+
79+
if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil {
80+
t.Fatal(err)
81+
}
82+
83+
for i := 0; i < messageCount; i++ {
84+
err := rc.ReadJSON(&v)
85+
if err != io.ErrUnexpectedEOF {
86+
t.Error("read", i, err)
87+
}
88+
}
89+
90+
err = rc.ReadJSON(&v)
91+
if _, ok := err.(*CloseError); !ok {
92+
t.Error("final", err)
93+
}
94+
}
95+
96+
func TestDeprecatedJSON(t *testing.T) {
97+
var buf bytes.Buffer
98+
c := fakeNetConn{&buf, &buf}
99+
wc := newConn(c, true, 1024, 1024)
100+
rc := newConn(c, false, 1024, 1024)
101+
102+
var actual, expect struct {
103+
A int
104+
B string
105+
}
106+
expect.A = 1
107+
expect.B = "hello"
108+
109+
if err := WriteJSON(wc, &expect); err != nil {
110+
t.Fatal("write", err)
111+
}
112+
113+
if err := ReadJSON(rc, &actual); err != nil {
114+
t.Fatal("read", err)
115+
}
116+
117+
if !reflect.DeepEqual(&actual, &expect) {
118+
t.Fatal("equal", actual, expect)
119+
}
120+
}

‎websocket/mask.go

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
2+
// this source code is governed by a BSD-style license that can be found in the
3+
// LICENSE file.
4+
5+
// +build !appengine
6+
7+
// fork from https://github.com/gorilla/websocket
8+
package websocket
9+
10+
import "unsafe"
11+
12+
const wordSize = int(unsafe.Sizeof(uintptr(0)))
13+
14+
func maskBytes(key [4]byte, pos int, b []byte) int {
15+
16+
// Mask one byte at a time for small buffers.
17+
if len(b) < 2*wordSize {
18+
for i := range b {
19+
b[i] ^= key[pos&3]
20+
pos++
21+
}
22+
return pos & 3
23+
}
24+
25+
// Mask one byte at a time to word boundary.
26+
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
27+
n = wordSize - n
28+
for i := range b[:n] {
29+
b[i] ^= key[pos&3]
30+
pos++
31+
}
32+
b = b[n:]
33+
}
34+
35+
// Create aligned word size key.
36+
var k [wordSize]byte
37+
for i := range k {
38+
k[i] = key[(pos+i)&3]
39+
}
40+
kw := *(*uintptr)(unsafe.Pointer(&k))
41+
42+
// Mask one word at a time.
43+
n := (len(b) / wordSize) * wordSize
44+
for i := 0; i < n; i += wordSize {
45+
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
46+
}
47+
48+
// Mask one byte at a time for remaining bytes.
49+
b = b[n:]
50+
for i := range b {
51+
b[i] ^= key[pos&3]
52+
pos++
53+
}
54+
55+
return pos & 3
56+
}

‎websocket/mask_safe.go

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
2+
// this source code is governed by a BSD-style license that can be found in the
3+
// LICENSE file.
4+
5+
// +build appengine
6+
7+
// fork from https://github.com/gorilla/websocket
8+
package websocket
9+
10+
func maskBytes(key [4]byte, pos int, b []byte) int {
11+
for i := range b {
12+
b[i] ^= key[pos&3]
13+
pos++
14+
}
15+
return pos & 3
16+
}

‎websocket/mask_test.go

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
2+
// this source code is governed by a BSD-style license that can be found in the
3+
// LICENSE file.
4+
5+
// Require 1.7 for sub-bencmarks
6+
// +build go1.7,!appengine
7+
8+
// fork from https://github.com/gorilla/websocket
9+
package websocket
10+
11+
import (
12+
"fmt"
13+
"testing"
14+
)
15+
16+
func maskBytesByByte(key [4]byte, pos int, b []byte) int {
17+
for i := range b {
18+
b[i] ^= key[pos&3]
19+
pos++
20+
}
21+
return pos & 3
22+
}
23+
24+
func notzero(b []byte) int {
25+
for i := range b {
26+
if b[i] != 0 {
27+
return i
28+
}
29+
}
30+
return -1
31+
}
32+
33+
func TestMaskBytes(t *testing.T) {
34+
key := [4]byte{1, 2, 3, 4}
35+
for size := 1; size <= 1024; size++ {
36+
for align := 0; align < wordSize; align++ {
37+
for pos := 0; pos < 4; pos++ {
38+
b := make([]byte, size+align)[align:]
39+
maskBytes(key, pos, b)
40+
maskBytesByByte(key, pos, b)
41+
if i := notzero(b); i >= 0 {
42+
t.Errorf("size:%d, align:%d, pos:%d, offset:%d", size, align, pos, i)
43+
}
44+
}
45+
}
46+
}
47+
}
48+
49+
func BenchmarkMaskBytes(b *testing.B) {
50+
for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} {
51+
b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) {
52+
for _, align := range []int{wordSize / 2} {
53+
b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) {
54+
for _, fn := range []struct {
55+
name string
56+
fn func(key [4]byte, pos int, b []byte) int
57+
}{
58+
{"byte", maskBytesByByte},
59+
{"word", maskBytes},
60+
} {
61+
b.Run(fn.name, func(b *testing.B) {
62+
key := newMaskKey()
63+
data := make([]byte, size+align)[align:]
64+
for i := 0; i < b.N; i++ {
65+
fn.fn(key, 0, data)
66+
}
67+
b.SetBytes(int64(len(data)))
68+
})
69+
}
70+
})
71+
}
72+
})
73+
}
74+
}

‎websocket/prepared.go

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// fork from https://github.com/gorilla/websocket
6+
package websocket
7+
8+
import (
9+
"bytes"
10+
"net"
11+
"sync"
12+
"time"
13+
)
14+
15+
// PreparedMessage caches on the wire representations of a message payload.
16+
// Use PreparedMessage to efficiently send a message payload to multiple
17+
// connections. PreparedMessage is especially useful when compression is used
18+
// because the CPU and memory expensive compression operation can be executed
19+
// once for a given set of compression options.
20+
type PreparedMessage struct {
21+
messageType int
22+
data []byte
23+
err error
24+
mu sync.Mutex
25+
frames map[prepareKey]*preparedFrame
26+
}
27+
28+
// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage.
29+
type prepareKey struct {
30+
isServer bool
31+
compress bool
32+
compressionLevel int
33+
}
34+
35+
// preparedFrame contains data in wire representation.
36+
type preparedFrame struct {
37+
once sync.Once
38+
data []byte
39+
}
40+
41+
// NewPreparedMessage returns an initialized PreparedMessage. You can then send
42+
// it to connection using WritePreparedMessage method. Valid wire
43+
// representation will be calculated lazily only once for a set of current
44+
// connection options.
45+
func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) {
46+
pm := &PreparedMessage{
47+
messageType: messageType,
48+
frames: make(map[prepareKey]*preparedFrame),
49+
data: data,
50+
}
51+
52+
// Prepare a plain server frame.
53+
_, frameData, err := pm.frame(prepareKey{isServer: true, compress: false})
54+
if err != nil {
55+
return nil, err
56+
}
57+
58+
// To protect against caller modifying the data argument, remember the data
59+
// copied to the plain server frame.
60+
pm.data = frameData[len(frameData)-len(data):]
61+
return pm, nil
62+
}
63+
64+
func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
65+
pm.mu.Lock()
66+
frame, ok := pm.frames[key]
67+
if !ok {
68+
frame = &preparedFrame{}
69+
pm.frames[key] = frame
70+
}
71+
pm.mu.Unlock()
72+
73+
var err error
74+
frame.once.Do(func() {
75+
// Prepare a frame using a 'fake' connection.
76+
// TODO: Refactor code in conn.go to allow more direct construction of
77+
// the frame.
78+
mu := make(chan bool, 1)
79+
mu <- true
80+
var nc prepareConn
81+
c := &Conn{
82+
conn: &nc,
83+
mu: mu,
84+
isServer: key.isServer,
85+
compressionLevel: key.compressionLevel,
86+
enableWriteCompression: true,
87+
writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize),
88+
}
89+
if key.compress {
90+
c.newCompressionWriter = compressNoContextTakeover
91+
}
92+
err = c.WriteMessage(pm.messageType, pm.data)
93+
frame.data = nc.buf.Bytes()
94+
})
95+
return pm.messageType, frame.data, err
96+
}
97+
98+
type prepareConn struct {
99+
buf bytes.Buffer
100+
net.Conn
101+
}
102+
103+
func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) }
104+
func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil }

‎websocket/prepared_test.go

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// fork from https://github.com/gorilla/websocket
6+
package websocket
7+
8+
import (
9+
"bytes"
10+
"compress/flate"
11+
"math/rand"
12+
"testing"
13+
)
14+
15+
var preparedMessageTests = []struct {
16+
messageType int
17+
isServer bool
18+
enableWriteCompression bool
19+
compressionLevel int
20+
}{
21+
// Server
22+
{TextMessage, true, false, flate.BestSpeed},
23+
{TextMessage, true, true, flate.BestSpeed},
24+
{TextMessage, true, true, flate.BestCompression},
25+
{PingMessage, true, false, flate.BestSpeed},
26+
{PingMessage, true, true, flate.BestSpeed},
27+
28+
// Client
29+
{TextMessage, false, false, flate.BestSpeed},
30+
{TextMessage, false, true, flate.BestSpeed},
31+
{TextMessage, false, true, flate.BestCompression},
32+
{PingMessage, false, false, flate.BestSpeed},
33+
{PingMessage, false, true, flate.BestSpeed},
34+
}
35+
36+
func TestPreparedMessage(t *testing.T) {
37+
for _, tt := range preparedMessageTests {
38+
var data = []byte("this is a test")
39+
var buf bytes.Buffer
40+
c := newConn(fakeNetConn{Reader: nil, Writer: &buf}, tt.isServer, 1024, 1024)
41+
if tt.enableWriteCompression {
42+
c.newCompressionWriter = compressNoContextTakeover
43+
}
44+
c.SetCompressionLevel(tt.compressionLevel)
45+
46+
// Seed random number generator for consistent frame mask.
47+
rand.Seed(1234)
48+
49+
if err := c.WriteMessage(tt.messageType, data); err != nil {
50+
t.Fatal(err)
51+
}
52+
want := buf.String()
53+
54+
pm, err := NewPreparedMessage(tt.messageType, data)
55+
if err != nil {
56+
t.Fatal(err)
57+
}
58+
59+
// Scribble on data to ensure that NewPreparedMessage takes a snapshot.
60+
copy(data, "hello world")
61+
62+
// Seed random number generator for consistent frame mask.
63+
rand.Seed(1234)
64+
65+
buf.Reset()
66+
if err := c.WritePreparedMessage(pm); err != nil {
67+
t.Fatal(err)
68+
}
69+
got := buf.String()
70+
71+
if got != want {
72+
t.Errorf("write message != prepared message for %+v", tt)
73+
}
74+
}
75+
}

‎websocket/server.go

+293
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// fork from https://github.com/gorilla/websocket
6+
package websocket
7+
8+
import (
9+
"bufio"
10+
"errors"
11+
"net"
12+
"net/http"
13+
"net/url"
14+
"strings"
15+
"time"
16+
)
17+
18+
// HandshakeError describes an error with the handshake from the peer.
19+
type HandshakeError struct {
20+
message string
21+
}
22+
23+
func (e HandshakeError) Error() string { return e.message }
24+
25+
// Upgrader specifies parameters for upgrading an HTTP connection to a
26+
// WebSocket connection.
27+
type Upgrader struct {
28+
// HandshakeTimeout specifies the duration for the handshake to complete.
29+
HandshakeTimeout time.Duration
30+
31+
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
32+
// size is zero, then buffers allocated by the HTTP server are used. The
33+
// I/O buffer sizes do not limit the size of the messages that can be sent
34+
// or received.
35+
ReadBufferSize, WriteBufferSize int
36+
37+
// Subprotocols specifies the server's supported protocols in order of
38+
// preference. If this field is set, then the Upgrade method negotiates a
39+
// subprotocol by selecting the first match in this list with a protocol
40+
// requested by the client.
41+
Subprotocols []string
42+
43+
// Error specifies the function for generating HTTP error responses. If Error
44+
// is nil, then http.Error is used to generate the HTTP response.
45+
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
46+
47+
// CheckOrigin returns true if the request Origin header is acceptable. If
48+
// CheckOrigin is nil, the host in the Origin header must not be set or
49+
// must match the host of the request.
50+
CheckOrigin func(r *http.Request) bool
51+
52+
// EnableCompression specify if the server should attempt to negotiate per
53+
// message compression (RFC 7692). Setting this value to true does not
54+
// guarantee that compression will be supported. Currently only "no context
55+
// takeover" modes are supported.
56+
EnableCompression bool
57+
}
58+
59+
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
60+
err := HandshakeError{reason}
61+
if u.Error != nil {
62+
u.Error(w, r, status, err)
63+
} else {
64+
w.Header().Set("Sec-Websocket-Version", "13")
65+
http.Error(w, http.StatusText(status), status)
66+
}
67+
return nil, err
68+
}
69+
70+
// checkSameOrigin returns true if the origin is not set or is equal to the request host.
71+
func checkSameOrigin(r *http.Request) bool {
72+
origin := r.Header["Origin"]
73+
if len(origin) == 0 {
74+
return true
75+
}
76+
u, err := url.Parse(origin[0])
77+
if err != nil {
78+
return false
79+
}
80+
return u.Host == r.Host
81+
}
82+
83+
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
84+
if u.Subprotocols != nil {
85+
clientProtocols := Subprotocols(r)
86+
for _, serverProtocol := range u.Subprotocols {
87+
for _, clientProtocol := range clientProtocols {
88+
if clientProtocol == serverProtocol {
89+
return clientProtocol
90+
}
91+
}
92+
}
93+
} else if responseHeader != nil {
94+
return responseHeader.Get("Sec-Websocket-Protocol")
95+
}
96+
return ""
97+
}
98+
99+
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
100+
//
101+
// The responseHeader is included in the response to the client's upgrade
102+
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
103+
// application negotiated subprotocol (Sec-Websocket-Protocol).
104+
//
105+
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
106+
// response.
107+
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
108+
if r.Method != "GET" {
109+
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: not a websocket handshake: request method is not GET")
110+
}
111+
112+
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
113+
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported")
114+
}
115+
116+
if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
117+
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'upgrade' token not found in 'Connection' header")
118+
}
119+
120+
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
121+
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'websocket' token not found in 'Upgrade' header")
122+
}
123+
124+
if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
125+
return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
126+
}
127+
128+
checkOrigin := u.CheckOrigin
129+
if checkOrigin == nil {
130+
checkOrigin = checkSameOrigin
131+
}
132+
if !checkOrigin(r) {
133+
return u.returnError(w, r, http.StatusForbidden, "websocket: 'Origin' header value not allowed")
134+
}
135+
136+
challengeKey := r.Header.Get("Sec-Websocket-Key")
137+
if challengeKey == "" {
138+
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank")
139+
}
140+
141+
subprotocol := u.selectSubprotocol(r, responseHeader)
142+
143+
// Negotiate PMCE
144+
var compress bool
145+
if u.EnableCompression {
146+
for _, ext := range parseExtensions(r.Header) {
147+
if ext[""] != "permessage-deflate" {
148+
continue
149+
}
150+
compress = true
151+
break
152+
}
153+
}
154+
155+
var (
156+
netConn net.Conn
157+
err error
158+
)
159+
160+
h, ok := w.(http.Hijacker)
161+
if !ok {
162+
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
163+
}
164+
var brw *bufio.ReadWriter
165+
netConn, brw, err = h.Hijack()
166+
if err != nil {
167+
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
168+
}
169+
170+
if brw.Reader.Buffered() > 0 {
171+
netConn.Close()
172+
return nil, errors.New("websocket: client sent data before handshake is complete")
173+
}
174+
175+
c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw)
176+
c.subprotocol = subprotocol
177+
178+
if compress {
179+
c.newCompressionWriter = compressNoContextTakeover
180+
c.newDecompressionReader = decompressNoContextTakeover
181+
}
182+
183+
p := c.writeBuf[:0]
184+
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
185+
p = append(p, computeAcceptKey(challengeKey)...)
186+
p = append(p, "\r\n"...)
187+
if c.subprotocol != "" {
188+
p = append(p, "Sec-Websocket-Protocol: "...)
189+
p = append(p, c.subprotocol...)
190+
p = append(p, "\r\n"...)
191+
}
192+
if compress {
193+
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
194+
}
195+
for k, vs := range responseHeader {
196+
if k == "Sec-Websocket-Protocol" {
197+
continue
198+
}
199+
for _, v := range vs {
200+
p = append(p, k...)
201+
p = append(p, ": "...)
202+
for i := 0; i < len(v); i++ {
203+
b := v[i]
204+
if b <= 31 {
205+
// prevent response splitting.
206+
b = ' '
207+
}
208+
p = append(p, b)
209+
}
210+
p = append(p, "\r\n"...)
211+
}
212+
}
213+
p = append(p, "\r\n"...)
214+
215+
// Clear deadlines set by HTTP server.
216+
netConn.SetDeadline(time.Time{})
217+
218+
if u.HandshakeTimeout > 0 {
219+
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
220+
}
221+
if _, err = netConn.Write(p); err != nil {
222+
netConn.Close()
223+
return nil, err
224+
}
225+
if u.HandshakeTimeout > 0 {
226+
netConn.SetWriteDeadline(time.Time{})
227+
}
228+
229+
return c, nil
230+
}
231+
232+
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
233+
//
234+
// Deprecated: Use websocket.Upgrader instead.
235+
//
236+
// Upgrade does not perform origin checking. The application is responsible for
237+
// checking the Origin header before calling Upgrade. An example implementation
238+
// of the same origin policy check is:
239+
//
240+
// if req.Header.Get("Origin") != "http://"+req.Host {
241+
// http.Error(w, "Origin not allowed", 403)
242+
// return
243+
// }
244+
//
245+
// If the endpoint supports subprotocols, then the application is responsible
246+
// for negotiating the protocol used on the connection. Use the Subprotocols()
247+
// function to get the subprotocols requested by the client. Use the
248+
// Sec-Websocket-Protocol response header to specify the subprotocol selected
249+
// by the application.
250+
//
251+
// The responseHeader is included in the response to the client's upgrade
252+
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
253+
// negotiated subprotocol (Sec-Websocket-Protocol).
254+
//
255+
// The connection buffers IO to the underlying network connection. The
256+
// readBufSize and writeBufSize parameters specify the size of the buffers to
257+
// use. Messages can be larger than the buffers.
258+
//
259+
// If the request is not a valid WebSocket handshake, then Upgrade returns an
260+
// error of type HandshakeError. Applications should handle this error by
261+
// replying to the client with an HTTP error response.
262+
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
263+
u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
264+
u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
265+
// don't return errors to maintain backwards compatibility
266+
}
267+
u.CheckOrigin = func(r *http.Request) bool {
268+
// allow all connections by default
269+
return true
270+
}
271+
return u.Upgrade(w, r, responseHeader)
272+
}
273+
274+
// Subprotocols returns the subprotocols requested by the client in the
275+
// Sec-Websocket-Protocol header.
276+
func Subprotocols(r *http.Request) []string {
277+
h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
278+
if h == "" {
279+
return nil
280+
}
281+
protocols := strings.Split(h, ",")
282+
for i := range protocols {
283+
protocols[i] = strings.TrimSpace(protocols[i])
284+
}
285+
return protocols
286+
}
287+
288+
// IsWebSocketUpgrade returns true if the client requested upgrade to the
289+
// WebSocket protocol.
290+
func IsWebSocketUpgrade(r *http.Request) bool {
291+
return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
292+
tokenListContainsValue(r.Header, "Upgrade", "websocket")
293+
}

‎websocket/server_test.go

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// fork from https://github.com/gorilla/websocket
6+
package websocket
7+
8+
import (
9+
"net/http"
10+
"reflect"
11+
"testing"
12+
)
13+
14+
var subprotocolTests = []struct {
15+
h string
16+
protocols []string
17+
}{
18+
{"", nil},
19+
{"foo", []string{"foo"}},
20+
{"foo,bar", []string{"foo", "bar"}},
21+
{"foo, bar", []string{"foo", "bar"}},
22+
{" foo, bar", []string{"foo", "bar"}},
23+
{" foo, bar ", []string{"foo", "bar"}},
24+
}
25+
26+
func TestSubprotocols(t *testing.T) {
27+
for _, st := range subprotocolTests {
28+
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}}
29+
protocols := Subprotocols(&r)
30+
if !reflect.DeepEqual(st.protocols, protocols) {
31+
t.Errorf("SubProtocols(%q) returned %#v, want %#v", st.h, protocols, st.protocols)
32+
}
33+
}
34+
}
35+
36+
var isWebSocketUpgradeTests = []struct {
37+
ok bool
38+
h http.Header
39+
}{
40+
{false, http.Header{"Upgrade": {"websocket"}}},
41+
{false, http.Header{"Connection": {"upgrade"}}},
42+
{true, http.Header{"Connection": {"upgRade"}, "Upgrade": {"WebSocket"}}},
43+
}
44+
45+
func TestIsWebSocketUpgrade(t *testing.T) {
46+
for _, tt := range isWebSocketUpgradeTests {
47+
ok := IsWebSocketUpgrade(&http.Request{Header: tt.h})
48+
if tt.ok != ok {
49+
t.Errorf("IsWebSocketUpgrade(%v) returned %v, want %v", tt.h, ok, tt.ok)
50+
}
51+
}
52+
}

‎websocket/util.go

+215
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// fork from https://github.com/gorilla/websocket
6+
package websocket
7+
8+
import (
9+
"crypto/rand"
10+
"crypto/sha1"
11+
"encoding/base64"
12+
"io"
13+
"net/http"
14+
"strings"
15+
)
16+
17+
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
18+
19+
func computeAcceptKey(challengeKey string) string {
20+
h := sha1.New()
21+
h.Write([]byte(challengeKey))
22+
h.Write(keyGUID)
23+
return base64.StdEncoding.EncodeToString(h.Sum(nil))
24+
}
25+
26+
func generateChallengeKey() (string, error) {
27+
p := make([]byte, 16)
28+
if _, err := io.ReadFull(rand.Reader, p); err != nil {
29+
return "", err
30+
}
31+
return base64.StdEncoding.EncodeToString(p), nil
32+
}
33+
34+
// Octet types from RFC 2616.
35+
var octetTypes [256]byte
36+
37+
const (
38+
isTokenOctet = 1 << iota
39+
isSpaceOctet
40+
)
41+
42+
func init() {
43+
// From RFC 2616
44+
//
45+
// OCTET = <any 8-bit sequence of data>
46+
// CHAR = <any US-ASCII character (octets 0 - 127)>
47+
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
48+
// CR = <US-ASCII CR, carriage return (13)>
49+
// LF = <US-ASCII LF, linefeed (10)>
50+
// SP = <US-ASCII SP, space (32)>
51+
// HT = <US-ASCII HT, horizontal-tab (9)>
52+
// <"> = <US-ASCII double-quote mark (34)>
53+
// CRLF = CR LF
54+
// LWS = [CRLF] 1*( SP | HT )
55+
// TEXT = <any OCTET except CTLs, but including LWS>
56+
// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
57+
// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
58+
// token = 1*<any CHAR except CTLs or separators>
59+
// qdtext = <any TEXT except <">>
60+
61+
for c := 0; c < 256; c++ {
62+
var t byte
63+
isCtl := c <= 31 || c == 127
64+
isChar := 0 <= c && c <= 127
65+
isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0
66+
if strings.IndexRune(" \t\r\n", rune(c)) >= 0 {
67+
t |= isSpaceOctet
68+
}
69+
if isChar && !isCtl && !isSeparator {
70+
t |= isTokenOctet
71+
}
72+
octetTypes[c] = t
73+
}
74+
}
75+
76+
func skipSpace(s string) (rest string) {
77+
i := 0
78+
for ; i < len(s); i++ {
79+
if octetTypes[s[i]]&isSpaceOctet == 0 {
80+
break
81+
}
82+
}
83+
return s[i:]
84+
}
85+
86+
func nextToken(s string) (token, rest string) {
87+
i := 0
88+
for ; i < len(s); i++ {
89+
if octetTypes[s[i]]&isTokenOctet == 0 {
90+
break
91+
}
92+
}
93+
return s[:i], s[i:]
94+
}
95+
96+
func nextTokenOrQuoted(s string) (value string, rest string) {
97+
if !strings.HasPrefix(s, "\"") {
98+
return nextToken(s)
99+
}
100+
s = s[1:]
101+
for i := 0; i < len(s); i++ {
102+
switch s[i] {
103+
case '"':
104+
return s[:i], s[i+1:]
105+
case '\\':
106+
p := make([]byte, len(s)-1)
107+
j := copy(p, s[:i])
108+
escape := true
109+
for i = i + 1; i < len(s); i++ {
110+
b := s[i]
111+
switch {
112+
case escape:
113+
escape = false
114+
p[j] = b
115+
j++
116+
case b == '\\':
117+
escape = true
118+
case b == '"':
119+
return string(p[:j]), s[i+1:]
120+
default:
121+
p[j] = b
122+
j++
123+
}
124+
}
125+
return "", ""
126+
}
127+
}
128+
return "", ""
129+
}
130+
131+
// tokenListContainsValue returns true if the 1#token header with the given
132+
// name contains token.
133+
func tokenListContainsValue(header http.Header, name string, value string) bool {
134+
headers:
135+
for _, s := range header[name] {
136+
for {
137+
var t string
138+
t, s = nextToken(skipSpace(s))
139+
if t == "" {
140+
continue headers
141+
}
142+
s = skipSpace(s)
143+
if s != "" && s[0] != ',' {
144+
continue headers
145+
}
146+
if strings.EqualFold(t, value) {
147+
return true
148+
}
149+
if s == "" {
150+
continue headers
151+
}
152+
s = s[1:]
153+
}
154+
}
155+
return false
156+
}
157+
158+
// parseExtensiosn parses WebSocket extensions from a header.
159+
func parseExtensions(header http.Header) []map[string]string {
160+
161+
// From RFC 6455:
162+
//
163+
// Sec-WebSocket-Extensions = extension-list
164+
// extension-list = 1#extension
165+
// extension = extension-token *( ";" extension-param )
166+
// extension-token = registered-token
167+
// registered-token = token
168+
// extension-param = token [ "=" (token | quoted-string) ]
169+
// ;When using the quoted-string syntax variant, the value
170+
// ;after quoted-string unescaping MUST conform to the
171+
// ;'token' ABNF.
172+
173+
var result []map[string]string
174+
headers:
175+
for _, s := range header["Sec-Websocket-Extensions"] {
176+
for {
177+
var t string
178+
t, s = nextToken(skipSpace(s))
179+
if t == "" {
180+
continue headers
181+
}
182+
ext := map[string]string{"": t}
183+
for {
184+
s = skipSpace(s)
185+
if !strings.HasPrefix(s, ";") {
186+
break
187+
}
188+
var k string
189+
k, s = nextToken(skipSpace(s[1:]))
190+
if k == "" {
191+
continue headers
192+
}
193+
s = skipSpace(s)
194+
var v string
195+
if strings.HasPrefix(s, "=") {
196+
v, s = nextTokenOrQuoted(skipSpace(s[1:]))
197+
s = skipSpace(s)
198+
}
199+
if s != "" && s[0] != ',' && s[0] != ';' {
200+
continue headers
201+
}
202+
ext[k] = v
203+
}
204+
if s != "" && s[0] != ',' {
205+
continue headers
206+
}
207+
result = append(result, ext)
208+
if s == "" {
209+
continue headers
210+
}
211+
s = s[1:]
212+
}
213+
}
214+
return result
215+
}

‎websocket/util_test.go

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// fork from https://github.com/gorilla/websocket
6+
package websocket
7+
8+
import (
9+
"net/http"
10+
"reflect"
11+
"testing"
12+
)
13+
14+
var tokenListContainsValueTests = []struct {
15+
value string
16+
ok bool
17+
}{
18+
{"WebSocket", true},
19+
{"WEBSOCKET", true},
20+
{"websocket", true},
21+
{"websockets", false},
22+
{"x websocket", false},
23+
{"websocket x", false},
24+
{"other,websocket,more", true},
25+
{"other, websocket, more", true},
26+
}
27+
28+
func TestTokenListContainsValue(t *testing.T) {
29+
for _, tt := range tokenListContainsValueTests {
30+
h := http.Header{"Upgrade": {tt.value}}
31+
ok := tokenListContainsValue(h, "Upgrade", "websocket")
32+
if ok != tt.ok {
33+
t.Errorf("tokenListContainsValue(h, n, %q) = %v, want %v", tt.value, ok, tt.ok)
34+
}
35+
}
36+
}
37+
38+
var parseExtensionTests = []struct {
39+
value string
40+
extensions []map[string]string
41+
}{
42+
{`foo`, []map[string]string{map[string]string{"": "foo"}}},
43+
{`foo, bar; baz=2`, []map[string]string{
44+
map[string]string{"": "foo"},
45+
map[string]string{"": "bar", "baz": "2"}}},
46+
{`foo; bar="b,a;z"`, []map[string]string{
47+
map[string]string{"": "foo", "bar": "b,a;z"}}},
48+
{`foo , bar; baz = 2`, []map[string]string{
49+
map[string]string{"": "foo"},
50+
map[string]string{"": "bar", "baz": "2"}}},
51+
{`foo, bar; baz=2 junk`, []map[string]string{
52+
map[string]string{"": "foo"}}},
53+
{`foo junk, bar; baz=2 junk`, nil},
54+
{`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{
55+
map[string]string{"": "mux", "max-channels": "4", "flow-control": ""},
56+
map[string]string{"": "deflate-stream"}}},
57+
{`permessage-foo; x="10"`, []map[string]string{
58+
map[string]string{"": "permessage-foo", "x": "10"}}},
59+
{`permessage-foo; use_y, permessage-foo`, []map[string]string{
60+
map[string]string{"": "permessage-foo", "use_y": ""},
61+
map[string]string{"": "permessage-foo"}}},
62+
{`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{
63+
map[string]string{"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"},
64+
map[string]string{"": "permessage-deflate", "client_max_window_bits": ""}}},
65+
}
66+
67+
func TestParseExtensions(t *testing.T) {
68+
for _, tt := range parseExtensionTests {
69+
h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}}
70+
extensions := parseExtensions(h)
71+
if !reflect.DeepEqual(extensions, tt.extensions) {
72+
t.Errorf("parseExtensions(%q)\n = %v,\nwant %v", tt.value, extensions, tt.extensions)
73+
}
74+
}
75+
}

0 commit comments

Comments
 (0)
Please sign in to comment.