@@ -17,10 +17,10 @@ import (
17
17
"sync"
18
18
"time"
19
19
20
- "github.com/gorilla/mux"
21
20
"github.com/rs/xid"
22
21
"github.com/rs/zerolog"
23
22
"github.com/rs/zerolog/hlog"
23
+ "go.mau.fi/util/exhttp"
24
24
"go.mau.fi/util/jsontime"
25
25
"go.mau.fi/util/requestlog"
26
26
@@ -38,7 +38,7 @@ type matrixAuthCacheEntry struct {
38
38
}
39
39
40
40
type ProvisioningAPI struct {
41
- Router * mux. Router
41
+ Router * http. ServeMux
42
42
43
43
br * Connector
44
44
log zerolog.Logger
@@ -82,12 +82,12 @@ func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
82
82
return r .Context ().Value (provisioningUserKey ).(* bridgev2.User )
83
83
}
84
84
85
- func (prov * ProvisioningAPI ) GetRouter () * mux. Router {
85
+ func (prov * ProvisioningAPI ) GetRouter () * http. ServeMux {
86
86
return prov .Router
87
87
}
88
88
89
89
type IProvisioningAPI interface {
90
- GetRouter () * mux. Router
90
+ GetRouter () * http. ServeMux
91
91
GetUser (r * http.Request ) * bridgev2.User
92
92
}
93
93
@@ -106,50 +106,44 @@ func (prov *ProvisioningAPI) Init() {
106
106
tp .Dialer .Timeout = 10 * time .Second
107
107
tp .Transport .ResponseHeaderTimeout = 10 * time .Second
108
108
tp .Transport .TLSHandshakeTimeout = 10 * time .Second
109
- prov .Router = prov .br .AS .Router .PathPrefix (prov .br .Config .Provisioning .Prefix ).Subrouter ()
110
- prov .Router .Use (hlog .NewHandler (prov .log ))
111
- prov .Router .Use (hlog .RequestIDHandler ("request_id" , "Request-Id" ))
112
- prov .Router .Use (corsMiddleware )
113
- prov .Router .Use (requestlog .AccessLogger (false ))
114
- prov .Router .Use (prov .AuthMiddleware )
115
- prov .Router .Path ("/v3/whoami" ).Methods (http .MethodGet , http .MethodOptions ).HandlerFunc (prov .GetWhoami )
116
- prov .Router .Path ("/v3/login/flows" ).Methods (http .MethodGet , http .MethodOptions ).HandlerFunc (prov .GetLoginFlows )
117
- prov .Router .Path ("/v3/login/start/{flowID}" ).Methods (http .MethodPost , http .MethodOptions ).HandlerFunc (prov .PostLoginStart )
118
- prov .Router .Path ("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}" ).Methods (http .MethodPost , http .MethodOptions ).HandlerFunc (prov .PostLoginSubmitInput )
119
- prov .Router .Path ("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}" ).Methods (http .MethodPost , http .MethodOptions ).HandlerFunc (prov .PostLoginWait )
120
- prov .Router .Path ("/v3/logout/{loginID}" ).Methods (http .MethodPost , http .MethodOptions ).HandlerFunc (prov .PostLogout )
121
- prov .Router .Path ("/v3/logins" ).Methods (http .MethodGet , http .MethodOptions ).HandlerFunc (prov .GetLogins )
122
- prov .Router .Path ("/v3/contacts" ).Methods (http .MethodGet , http .MethodOptions ).HandlerFunc (prov .GetContactList )
123
- prov .Router .Path ("/v3/search_users" ).Methods (http .MethodPost , http .MethodOptions ).HandlerFunc (prov .PostSearchUsers )
124
- prov .Router .Path ("/v3/resolve_identifier/{identifier}" ).Methods (http .MethodGet , http .MethodOptions ).HandlerFunc (prov .GetResolveIdentifier )
125
- prov .Router .Path ("/v3/create_dm/{identifier}" ).Methods (http .MethodPost , http .MethodOptions ).HandlerFunc (prov .PostCreateDM )
126
- prov .Router .Path ("/v3/create_group" ).Methods (http .MethodPost , http .MethodOptions ).HandlerFunc (prov .PostCreateGroup )
109
+
110
+ provRouter := http .NewServeMux ()
111
+
112
+ provRouter .HandleFunc ("GET /v3/whoami" , prov .GetWhoami )
113
+ provRouter .HandleFunc ("GET /v3/whoami/flows" , prov .GetLoginFlows )
114
+
115
+ provRouter .HandleFunc ("POST /v3/login/start/{flowID}" , prov .PostLoginStart )
116
+ provRouter .HandleFunc ("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}" , prov .PostLogin )
117
+ provRouter .HandleFunc ("POST /v3/logout/{loginID}" , prov .PostLogout )
118
+ provRouter .HandleFunc ("GET /v3/logins" , prov .GetLogins )
119
+ provRouter .HandleFunc ("GET /v3/contacts" , prov .GetContactList )
120
+ provRouter .HandleFunc ("POST /v3/search_users" , prov .PostSearchUsers )
121
+ provRouter .HandleFunc ("GET /v3/resolve_identifier/{identifier}" , prov .GetResolveIdentifier )
122
+ provRouter .HandleFunc ("POST /v3/create_dm/{identifier}" , prov .PostCreateDM )
123
+ provRouter .HandleFunc ("POST /v3/create_group" , prov .PostCreateGroup )
124
+
125
+ var provHandler http.Handler = prov .Router
126
+ provHandler = prov .AuthMiddleware (provHandler )
127
+ provHandler = requestlog .AccessLogger (false )(provHandler )
128
+ provHandler = exhttp .CORSMiddleware (provHandler )
129
+ provHandler = hlog .RequestIDHandler ("request_id" , "Request-Id" )(provHandler )
130
+ provHandler = hlog .NewHandler (prov .log )(provHandler )
131
+ provHandler = http .StripPrefix (prov .br .Config .Provisioning .Prefix , provHandler )
132
+ prov .br .AS .Router .Handle (prov .br .Config .Provisioning .Prefix , provHandler )
127
133
128
134
if prov .br .Config .Provisioning .DebugEndpoints {
129
135
prov .log .Debug ().Msg ("Enabling debug API at /debug" )
130
- r := prov .br .AS .Router .PathPrefix ("/debug" ).Subrouter ()
131
- r .Use (prov .DebugAuthMiddleware )
132
- r .HandleFunc ("/pprof/cmdline" , pprof .Cmdline ).Methods (http .MethodGet )
133
- r .HandleFunc ("/pprof/profile" , pprof .Profile ).Methods (http .MethodGet )
134
- r .HandleFunc ("/pprof/symbol" , pprof .Symbol ).Methods (http .MethodGet )
135
- r .HandleFunc ("/pprof/trace" , pprof .Trace ).Methods (http .MethodGet )
136
- r .PathPrefix ("/pprof/" ).HandlerFunc (pprof .Index )
136
+ debugRouter := http .NewServeMux ()
137
+ // TODO do we need to strip prefix here?
138
+ debugRouter .HandleFunc ("/debug/pprof" , pprof .Index )
139
+ debugRouter .HandleFunc ("GET /debug/pprof/trace" , pprof .Trace )
140
+ debugRouter .HandleFunc ("GET /debug/pprof/symbol" , pprof .Symbol )
141
+ debugRouter .HandleFunc ("GET /debug/pprof/profile" , pprof .Profile )
142
+ debugRouter .HandleFunc ("GET /debug/pprof/cmdline" , pprof .Cmdline )
143
+ prov .br .AS .Router .Handle ("/debug" , prov .AuthMiddleware (debugRouter ))
137
144
}
138
145
}
139
146
140
- func corsMiddleware (handler http.Handler ) http.Handler {
141
- return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
142
- w .Header ().Set ("Access-Control-Allow-Origin" , "*" )
143
- w .Header ().Set ("Access-Control-Allow-Methods" , "GET, POST, PUT, DELETE, OPTIONS" )
144
- w .Header ().Set ("Access-Control-Allow-Headers" , "X-Requested-With, Content-Type, Authorization" )
145
- if r .Method == http .MethodOptions {
146
- w .WriteHeader (http .StatusOK )
147
- return
148
- }
149
- handler .ServeHTTP (w , r )
150
- })
151
- }
152
-
153
147
func jsonResponse (w http.ResponseWriter , status int , response any ) {
154
148
w .Header ().Add ("Content-Type" , "application/json" )
155
149
w .WriteHeader (status )
@@ -270,7 +264,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
270
264
}
271
265
272
266
ctx := context .WithValue (r .Context (), provisioningUserKey , user )
273
- if loginID , ok := mux . Vars ( r )[ "loginProcessID" ]; ok {
267
+ if loginID := r . PathValue ( "loginProcessID" ); loginID != "" {
274
268
prov .loginsLock .RLock ()
275
269
login , ok := prov .logins [loginID ]
276
270
prov .loginsLock .RUnlock ()
@@ -285,7 +279,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
285
279
login .Lock .Lock ()
286
280
// This will only unlock after the handler runs
287
281
defer login .Lock .Unlock ()
288
- stepID := mux . Vars ( r )[ "stepID" ]
282
+ stepID := r . PathValue ( "stepID" )
289
283
if login .NextStep .StepID != stepID {
290
284
zerolog .Ctx (r .Context ()).Warn ().
291
285
Str ("request_step_id" , stepID ).
@@ -297,7 +291,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
297
291
})
298
292
return
299
293
}
300
- stepType := mux . Vars ( r )[ "stepType" ]
294
+ stepType := r . PathValue ( "stepType" )
301
295
if login .NextStep .Type != bridgev2 .LoginStepType (stepType ) {
302
296
zerolog .Ctx (r .Context ()).Warn ().
303
297
Str ("request_step_type" , stepType ).
@@ -401,7 +395,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
401
395
login , err := prov .net .CreateLogin (
402
396
r .Context (),
403
397
prov .GetUser (r ),
404
- mux . Vars ( r )[ "flowID" ] ,
398
+ r . PathValue ( "flowID" ) ,
405
399
)
406
400
if err != nil {
407
401
zerolog .Ctx (r .Context ()).Err (err ).Msg ("Failed to create login process" )
@@ -440,6 +434,17 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov
440
434
}, bridgev2.DeleteOpts {LogoutRemote : true })
441
435
}
442
436
437
+ func (prov * ProvisioningAPI ) PostLogin (w http.ResponseWriter , r * http.Request ) {
438
+ switch r .PathValue ("stepType" ) {
439
+ case "user_input" , "cookies" :
440
+ prov .PostLoginSubmitInput (w , r )
441
+ case "display_and_wait" :
442
+ prov .PostLoginWait (w , r )
443
+ default :
444
+ panic ("Impossible state" ) // checked by the AuthMiddleware
445
+ }
446
+ }
447
+
443
448
func (prov * ProvisioningAPI ) PostLoginSubmitInput (w http.ResponseWriter , r * http.Request ) {
444
449
var params map [string ]string
445
450
err := json .NewDecoder (r .Body ).Decode (& params )
@@ -493,7 +498,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques
493
498
494
499
func (prov * ProvisioningAPI ) PostLogout (w http.ResponseWriter , r * http.Request ) {
495
500
user := prov .GetUser (r )
496
- userLoginID := networkid .UserLoginID (mux . Vars ( r )[ "loginID" ] )
501
+ userLoginID := networkid .UserLoginID (r . PathValue ( "loginID" ) )
497
502
if userLoginID == "all" {
498
503
for {
499
504
login := user .GetDefaultLogin ()
@@ -596,7 +601,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.
596
601
})
597
602
return
598
603
}
599
- resp , err := api .ResolveIdentifier (r .Context (), mux . Vars ( r )[ "identifier" ] , createChat )
604
+ resp , err := api .ResolveIdentifier (r .Context (), r . PathValue ( "identifier" ) , createChat )
600
605
if err != nil {
601
606
zerolog .Ctx (r .Context ()).Err (err ).Msg ("Failed to resolve identifier" )
602
607
RespondWithError (w , err , "Internal error resolving identifier" )
0 commit comments