@@ -2,13 +2,73 @@ package tunneld_test
22
33import (
44 "context"
5+ "encoding/hex"
6+ "encoding/json"
7+ "net/http"
8+ "strings"
59 "testing"
610
711 "github.com/stretchr/testify/require"
812
13+ "github.com/coder/wgtunnel/tunneld"
914 "github.com/coder/wgtunnel/tunnelsdk"
1015)
1116
17+ // Test for the compatibility endpoint which allows old tunnels to connect to
18+ // the new server.
19+ func Test_postTun (t * testing.T ) {
20+ t .Parallel ()
21+
22+ td , client := createTestTunneld (t , nil )
23+
24+ key , err := tunnelsdk .GeneratePrivateKey ()
25+ require .NoError (t , err )
26+
27+ expectedIP , expectedURLs := td .WireguardPublicKeyToIPAndURLs (key .NoisePublicKey (), tunnelsdk .TunnelVersion1 )
28+ require .Len (t , expectedURLs , 2 )
29+ require .Len (t , strings .Split (expectedURLs [0 ].Host , "." )[0 ], 32 )
30+ expectedHostname := expectedURLs [0 ].Host
31+
32+ // First request should return a 201.
33+ resp , err := client .Request (context .Background (), http .MethodPost , "/tun" , tunneld.LegacyPostTunRequest {
34+ PublicKey : key .NoisePublicKey (),
35+ })
36+ require .NoError (t , err )
37+ defer resp .Body .Close ()
38+ require .Equal (t , http .StatusCreated , resp .StatusCode )
39+
40+ var legacyRes tunneld.LegacyPostTunResponse
41+ require .NoError (t , json .NewDecoder (resp .Body ).Decode (& legacyRes ))
42+ require .Equal (t , expectedIP , legacyRes .ClientIP )
43+ require .Equal (t , expectedHostname , legacyRes .Hostname )
44+
45+ // Register on the new endpoint so we can compare the values to the legacy
46+ // endpoint.
47+ newRes , err := client .ClientRegister (context .Background (), tunnelsdk.ClientRegisterRequest {
48+ Version : tunnelsdk .TunnelVersion1 ,
49+ PublicKey : key .NoisePublicKey (),
50+ })
51+ require .NoError (t , err )
52+ require .Equal (t , tunnelsdk .TunnelVersion1 , newRes .Version )
53+
54+ require .Equal (t , legacyRes .ServerEndpoint , newRes .ServerEndpoint )
55+ require .Equal (t , legacyRes .ServerIP , newRes .ServerIP )
56+ require .Equal (t , legacyRes .ServerPublicKey , hex .EncodeToString (newRes .ServerPublicKey [:]))
57+ require .Equal (t , legacyRes .ClientIP , newRes .ClientIP )
58+
59+ // Second request should return a 200.
60+ resp , err = client .Request (context .Background (), http .MethodPost , "/tun" , tunneld.LegacyPostTunRequest {
61+ PublicKey : key .NoisePublicKey (),
62+ })
63+ require .NoError (t , err )
64+ defer resp .Body .Close ()
65+ require .Equal (t , http .StatusOK , resp .StatusCode )
66+
67+ var legacyRes2 tunneld.LegacyPostTunResponse
68+ require .NoError (t , json .NewDecoder (resp .Body ).Decode (& legacyRes2 ))
69+ require .Equal (t , legacyRes , legacyRes2 )
70+ }
71+
1272func Test_postClients (t * testing.T ) {
1373 t .Parallel ()
1474
@@ -17,16 +77,22 @@ func Test_postClients(t *testing.T) {
1777 key , err := tunnelsdk .GeneratePrivateKey ()
1878 require .NoError (t , err )
1979
20- expectedIP := td .WireguardPublicKeyToIP (key .NoisePublicKey ())
21- expectedURL := td .WireguardIPToTunnelURL (expectedIP )
80+ expectedIP , expectedURLs := td .WireguardPublicKeyToIPAndURLs (key .NoisePublicKey (), tunnelsdk .TunnelVersion2 )
81+
82+ expectedURLsStr := make ([]string , len (expectedURLs ))
83+ for i , u := range expectedURLs {
84+ expectedURLsStr [i ] = u .String ()
85+ }
2286
2387 // Register a client.
2488 res , err := client .ClientRegister (context .Background (), tunnelsdk.ClientRegisterRequest {
89+ // No version should default to 2.
2590 PublicKey : key .NoisePublicKey (),
2691 })
2792 require .NoError (t , err )
2893
29- require .Equal (t , expectedURL .String (), res .TunnelURL )
94+ require .Equal (t , tunnelsdk .TunnelVersion2 , res .Version )
95+ require .Equal (t , expectedURLsStr , res .TunnelURLs )
3096 require .Equal (t , expectedIP , res .ClientIP )
3197 require .Equal (t , td .WireguardEndpoint , res .ServerEndpoint )
3298 require .Equal (t , td .WireguardServerIP , res .ServerIP )
@@ -35,8 +101,22 @@ func Test_postClients(t *testing.T) {
35101
36102 // Register the same client again.
37103 res2 , err := client .ClientRegister (context .Background (), tunnelsdk.ClientRegisterRequest {
104+ Version : tunnelsdk .TunnelVersion2 ,
38105 PublicKey : key .NoisePublicKey (),
39106 })
40107 require .NoError (t , err )
41108 require .Equal (t , res , res2 )
109+
110+ // Register the same client with the old version.
111+ res3 , err := client .ClientRegister (context .Background (), tunnelsdk.ClientRegisterRequest {
112+ Version : tunnelsdk .TunnelVersion1 ,
113+ PublicKey : key .NoisePublicKey (),
114+ })
115+ require .NoError (t , err )
116+
117+ // Should be equal after reversing the URL list.
118+ require .Equal (t , tunnelsdk .TunnelVersion1 , res3 .Version )
119+ res3 .TunnelURLs [0 ], res3 .TunnelURLs [1 ] = res3 .TunnelURLs [1 ], res3 .TunnelURLs [0 ]
120+ res3 .Version = tunnelsdk .TunnelVersion2
121+ require .Equal (t , res , res3 )
42122}
0 commit comments