Skip to content

Commit 481f435

Browse files
committed
treewide: replace gorilla/mux with http.ServeMux
Signed-off-by: Sumner Evans <[email protected]>
1 parent 7a9269e commit 481f435

File tree

11 files changed

+175
-188
lines changed

11 files changed

+175
-188
lines changed

appservice/appservice.go

+8-9
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import (
1919
"syscall"
2020
"time"
2121

22-
"github.com/gorilla/mux"
2322
"github.com/gorilla/websocket"
2423
"github.com/rs/zerolog"
2524
"golang.org/x/net/publicsuffix"
@@ -43,7 +42,7 @@ func Create() *AppService {
4342
intents: make(map[id.UserID]*IntentAPI),
4443
HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar},
4544
StateStore: mautrix.NewMemoryStateStore().(StateStore),
46-
Router: mux.NewRouter(),
45+
Router: http.NewServeMux(),
4746
UserAgent: mautrix.DefaultUserAgent,
4847
txnIDC: NewTransactionIDCache(128),
4948
Live: true,
@@ -61,12 +60,12 @@ func Create() *AppService {
6160
DefaultHTTPRetries: 4,
6261
}
6362

64-
as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
65-
as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
66-
as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet)
67-
as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost)
68-
as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet)
69-
as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet)
63+
as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction)
64+
as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom)
65+
as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser)
66+
as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing)
67+
as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive)
68+
as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady)
7069

7170
return as
7271
}
@@ -160,7 +159,7 @@ type AppService struct {
160159
QueryHandler QueryHandler
161160
StateStore StateStore
162161

163-
Router *mux.Router
162+
Router *http.ServeMux
164163
UserAgent string
165164
server *http.Server
166165
HTTPClient *http.Client

appservice/http.go

+3-9
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import (
1717
"syscall"
1818
"time"
1919

20-
"github.com/gorilla/mux"
2120
"github.com/rs/zerolog"
2221

2322
"maunium.net/go/mautrix"
@@ -101,8 +100,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
101100
return
102101
}
103102

104-
vars := mux.Vars(r)
105-
txnID := vars["txnID"]
103+
txnID := r.PathValue("txnID")
106104
if len(txnID) == 0 {
107105
Error{
108106
ErrorCode: ErrNoTransactionID,
@@ -258,9 +256,7 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) {
258256
return
259257
}
260258

261-
vars := mux.Vars(r)
262-
roomAlias := vars["roomAlias"]
263-
ok := as.QueryHandler.QueryAlias(roomAlias)
259+
ok := as.QueryHandler.QueryAlias(r.PathValue("roomAlias"))
264260
if ok {
265261
WriteBlankOK(w)
266262
} else {
@@ -277,9 +273,7 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) {
277273
return
278274
}
279275

280-
vars := mux.Vars(r)
281-
userID := id.UserID(vars["userID"])
282-
ok := as.QueryHandler.QueryUser(userID)
276+
ok := as.QueryHandler.QueryUser(id.UserID(r.PathValue("userID")))
283277
if ok {
284278
WriteBlankOK(w)
285279
} else {

bridgev2/matrix/connector.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"encoding/json"
1414
"errors"
1515
"fmt"
16+
"net/http"
1617
"net/url"
1718
"os"
1819
"regexp"
@@ -21,7 +22,6 @@ import (
2122
"time"
2223
"unsafe"
2324

24-
"github.com/gorilla/mux"
2525
_ "github.com/lib/pq"
2626
"github.com/rs/zerolog"
2727
"go.mau.fi/util/dbutil"
@@ -222,7 +222,8 @@ func (br *Connector) GetPublicAddress() string {
222222
return br.Config.AppService.PublicAddress
223223
}
224224

225-
func (br *Connector) GetRouter() *mux.Router {
225+
// TODO switch to http.ServeMux
226+
func (br *Connector) GetRouter() *http.ServeMux {
226227
if br.GetPublicAddress() != "" {
227228
return br.AS.Router
228229
}

bridgev2/matrix/provisioning.go

+53-48
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ import (
1717
"sync"
1818
"time"
1919

20-
"github.com/gorilla/mux"
2120
"github.com/rs/xid"
2221
"github.com/rs/zerolog"
2322
"github.com/rs/zerolog/hlog"
23+
"go.mau.fi/util/exhttp"
2424
"go.mau.fi/util/jsontime"
2525
"go.mau.fi/util/requestlog"
2626

@@ -38,7 +38,7 @@ type matrixAuthCacheEntry struct {
3838
}
3939

4040
type ProvisioningAPI struct {
41-
Router *mux.Router
41+
Router *http.ServeMux
4242

4343
br *Connector
4444
log zerolog.Logger
@@ -82,12 +82,12 @@ func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User {
8282
return r.Context().Value(provisioningUserKey).(*bridgev2.User)
8383
}
8484

85-
func (prov *ProvisioningAPI) GetRouter() *mux.Router {
85+
func (prov *ProvisioningAPI) GetRouter() *http.ServeMux {
8686
return prov.Router
8787
}
8888

8989
type IProvisioningAPI interface {
90-
GetRouter() *mux.Router
90+
GetRouter() *http.ServeMux
9191
GetUser(r *http.Request) *bridgev2.User
9292
}
9393

@@ -106,50 +106,44 @@ func (prov *ProvisioningAPI) Init() {
106106
tp.Dialer.Timeout = 10 * time.Second
107107
tp.Transport.ResponseHeaderTimeout = 10 * time.Second
108108
tp.Transport.TLSHandshakeTimeout = 10 * time.Second
109-
prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter()
110-
prov.Router.Use(hlog.NewHandler(prov.log))
111-
prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id"))
112-
prov.Router.Use(corsMiddleware)
113-
prov.Router.Use(requestlog.AccessLogger(false))
114-
prov.Router.Use(prov.AuthMiddleware)
115-
prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami)
116-
prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows)
117-
prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart)
118-
prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput)
119-
prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait)
120-
prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout)
121-
prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins)
122-
prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList)
123-
prov.Router.Path("/v3/search_users").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostSearchUsers)
124-
prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier)
125-
prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM)
126-
prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup)
109+
110+
provRouter := http.NewServeMux()
111+
112+
provRouter.HandleFunc("GET /v3/whoami", prov.GetWhoami)
113+
provRouter.HandleFunc("GET /v3/whoami/flows", prov.GetLoginFlows)
114+
115+
provRouter.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart)
116+
provRouter.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLogin)
117+
provRouter.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout)
118+
provRouter.HandleFunc("GET /v3/logins", prov.GetLogins)
119+
provRouter.HandleFunc("GET /v3/contacts", prov.GetContactList)
120+
provRouter.HandleFunc("POST /v3/search_users", prov.PostSearchUsers)
121+
provRouter.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier)
122+
provRouter.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM)
123+
provRouter.HandleFunc("POST /v3/create_group", prov.PostCreateGroup)
124+
125+
var provHandler http.Handler = prov.Router
126+
provHandler = prov.AuthMiddleware(provHandler)
127+
provHandler = requestlog.AccessLogger(false)(provHandler)
128+
provHandler = exhttp.CORSMiddleware(provHandler)
129+
provHandler = hlog.RequestIDHandler("request_id", "Request-Id")(provHandler)
130+
provHandler = hlog.NewHandler(prov.log)(provHandler)
131+
provHandler = http.StripPrefix(prov.br.Config.Provisioning.Prefix, provHandler)
132+
prov.br.AS.Router.Handle(prov.br.Config.Provisioning.Prefix, provHandler)
127133

128134
if prov.br.Config.Provisioning.DebugEndpoints {
129135
prov.log.Debug().Msg("Enabling debug API at /debug")
130-
r := prov.br.AS.Router.PathPrefix("/debug").Subrouter()
131-
r.Use(prov.DebugAuthMiddleware)
132-
r.HandleFunc("/pprof/cmdline", pprof.Cmdline).Methods(http.MethodGet)
133-
r.HandleFunc("/pprof/profile", pprof.Profile).Methods(http.MethodGet)
134-
r.HandleFunc("/pprof/symbol", pprof.Symbol).Methods(http.MethodGet)
135-
r.HandleFunc("/pprof/trace", pprof.Trace).Methods(http.MethodGet)
136-
r.PathPrefix("/pprof/").HandlerFunc(pprof.Index)
136+
debugRouter := http.NewServeMux()
137+
// TODO do we need to strip prefix here?
138+
debugRouter.HandleFunc("/debug/pprof", pprof.Index)
139+
debugRouter.HandleFunc("GET /debug/pprof/trace", pprof.Trace)
140+
debugRouter.HandleFunc("GET /debug/pprof/symbol", pprof.Symbol)
141+
debugRouter.HandleFunc("GET /debug/pprof/profile", pprof.Profile)
142+
debugRouter.HandleFunc("GET /debug/pprof/cmdline", pprof.Cmdline)
143+
prov.br.AS.Router.Handle("/debug", prov.AuthMiddleware(debugRouter))
137144
}
138145
}
139146

140-
func corsMiddleware(handler http.Handler) http.Handler {
141-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
142-
w.Header().Set("Access-Control-Allow-Origin", "*")
143-
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
144-
w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization")
145-
if r.Method == http.MethodOptions {
146-
w.WriteHeader(http.StatusOK)
147-
return
148-
}
149-
handler.ServeHTTP(w, r)
150-
})
151-
}
152-
153147
func jsonResponse(w http.ResponseWriter, status int, response any) {
154148
w.Header().Add("Content-Type", "application/json")
155149
w.WriteHeader(status)
@@ -270,7 +264,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
270264
}
271265

272266
ctx := context.WithValue(r.Context(), provisioningUserKey, user)
273-
if loginID, ok := mux.Vars(r)["loginProcessID"]; ok {
267+
if loginID := r.PathValue("loginProcessID"); loginID != "" {
274268
prov.loginsLock.RLock()
275269
login, ok := prov.logins[loginID]
276270
prov.loginsLock.RUnlock()
@@ -285,7 +279,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
285279
login.Lock.Lock()
286280
// This will only unlock after the handler runs
287281
defer login.Lock.Unlock()
288-
stepID := mux.Vars(r)["stepID"]
282+
stepID := r.PathValue("stepID")
289283
if login.NextStep.StepID != stepID {
290284
zerolog.Ctx(r.Context()).Warn().
291285
Str("request_step_id", stepID).
@@ -297,7 +291,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
297291
})
298292
return
299293
}
300-
stepType := mux.Vars(r)["stepType"]
294+
stepType := r.PathValue("stepType")
301295
if login.NextStep.Type != bridgev2.LoginStepType(stepType) {
302296
zerolog.Ctx(r.Context()).Warn().
303297
Str("request_step_type", stepType).
@@ -401,7 +395,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque
401395
login, err := prov.net.CreateLogin(
402396
r.Context(),
403397
prov.GetUser(r),
404-
mux.Vars(r)["flowID"],
398+
r.PathValue("flowID"),
405399
)
406400
if err != nil {
407401
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process")
@@ -440,6 +434,17 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov
440434
}, bridgev2.DeleteOpts{LogoutRemote: true})
441435
}
442436

437+
func (prov *ProvisioningAPI) PostLogin(w http.ResponseWriter, r *http.Request) {
438+
switch r.PathValue("stepType") {
439+
case "user_input", "cookies":
440+
prov.PostLoginSubmitInput(w, r)
441+
case "display_and_wait":
442+
prov.PostLoginWait(w, r)
443+
default:
444+
panic("Impossible state") // checked by the AuthMiddleware
445+
}
446+
}
447+
443448
func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) {
444449
var params map[string]string
445450
err := json.NewDecoder(r.Body).Decode(&params)
@@ -493,7 +498,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques
493498

494499
func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) {
495500
user := prov.GetUser(r)
496-
userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"])
501+
userLoginID := networkid.UserLoginID(r.PathValue("loginID"))
497502
if userLoginID == "all" {
498503
for {
499504
login := user.GetDefaultLogin()
@@ -596,7 +601,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.
596601
})
597602
return
598603
}
599-
resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat)
604+
resp, err := api.ResolveIdentifier(r.Context(), r.PathValue("identifier"), createChat)
600605
if err != nil {
601606
zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier")
602607
RespondWithError(w, err, "Internal error resolving identifier")

bridgev2/matrix/publicmedia.go

+4-7
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ import (
1616
"net/http"
1717
"time"
1818

19-
"github.com/gorilla/mux"
20-
2119
"maunium.net/go/mautrix/bridgev2"
2220
"maunium.net/go/mautrix/id"
2321
)
@@ -35,7 +33,7 @@ func (br *Connector) initPublicMedia() error {
3533
return fmt.Errorf("public media hash length is negative")
3634
}
3735
br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey)
38-
br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet)
36+
br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia)
3937
return nil
4038
}
4139

@@ -76,16 +74,15 @@ var proxyHeadersToCopy = []string{
7674
}
7775

7876
func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) {
79-
vars := mux.Vars(r)
8077
contentURI := id.ContentURI{
81-
Homeserver: vars["server"],
82-
FileID: vars["mediaID"],
78+
Homeserver: r.PathValue("server"),
79+
FileID: r.PathValue("mediaID"),
8380
}
8481
if !contentURI.IsValid() {
8582
http.Error(w, "invalid content URI", http.StatusBadRequest)
8683
return
8784
}
88-
checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"])
85+
checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum"))
8986
if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) {
9087
http.Error(w, "invalid base64 in checksum", http.StatusBadRequest)
9188
return

bridgev2/matrixinterface.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@ import (
1010
"context"
1111
"fmt"
1212
"io"
13+
"net/http"
1314
"os"
1415
"time"
1516

16-
"github.com/gorilla/mux"
17-
1817
"maunium.net/go/mautrix"
1918
"maunium.net/go/mautrix/bridge/status"
2019
"maunium.net/go/mautrix/bridgev2/database"
@@ -58,7 +57,7 @@ type MatrixConnector interface {
5857

5958
type MatrixConnectorWithServer interface {
6059
GetPublicAddress() string
61-
GetRouter() *mux.Router
60+
GetRouter() *http.ServeMux
6261
}
6362

6463
type MatrixConnectorWithPublicMedia interface {

0 commit comments

Comments
 (0)