@@ -22,7 +22,9 @@ import qualified Data.ByteString.Char8 as BS
22
22
import qualified Data.ByteString.Lazy as BL
23
23
import qualified Data.HashMap.Strict as M
24
24
import qualified Data.Text.Encoding.Error as T
25
- import Data.Time.Clock.POSIX (getPOSIXTime )
25
+ import Data.Time.Clock (UTCTime )
26
+ import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds , posixSecondsToUTCTime )
27
+ import Control.Concurrent.AlarmClock (newAlarmClock , setAlarm )
26
28
import PostgresWebsockets.Broadcast (Multiplexer , onMessage )
27
29
import qualified PostgresWebsockets.Broadcast as B
28
30
import PostgresWebsockets.Claims
@@ -38,19 +40,21 @@ data Message = Message
38
40
instance A. ToJSON Message
39
41
40
42
-- | Given a secret, a function to fetch the system time, a Hasql Pool and a Multiplexer this will give you a WAI middleware.
41
- postgresWsMiddleware :: Text -> ByteString -> H. Pool -> Multiplexer -> Wai. Application -> Wai. Application
43
+ postgresWsMiddleware :: IO UTCTime -> Text -> ByteString -> H. Pool -> Multiplexer -> Wai. Application -> Wai. Application
42
44
postgresWsMiddleware =
43
45
WS. websocketsOr WS. defaultConnectionOptions `compose` wsApp
44
46
where
45
- compose = (.) . (.) . (.) . (.)
47
+ compose = (.) . (.) . (.) . (.) . (.)
46
48
47
49
-- private functions
50
+ jwtExpirationStatusCode :: Word16
51
+ jwtExpirationStatusCode = 3001
48
52
49
53
-- when the websocket is closed a ConnectionClosed Exception is triggered
50
54
-- this kills all children and frees resources for us
51
- wsApp :: Text -> ByteString -> H. Pool -> Multiplexer -> WS. ServerApp
52
- wsApp dbChannel secret pool multi pendingConn =
53
- validateClaims requestChannel secret (toS jwtToken) >>= either rejectRequest forkSessions
55
+ wsApp :: IO UTCTime -> Text -> ByteString -> H. Pool -> Multiplexer -> WS. ServerApp
56
+ wsApp getTime dbChannel secret pool multi pendingConn =
57
+ getTime >>= validateClaims requestChannel secret (toS jwtToken) >>= either rejectRequest forkSessions
54
58
where
55
59
hasRead m = m == (" r" :: ByteString ) || m == (" rw" :: ByteString )
56
60
hasWrite m = m == (" w" :: ByteString ) || m == (" rw" :: ByteString )
@@ -68,27 +72,34 @@ wsApp dbChannel secret pool multi pendingConn =
68
72
-- We should accept only after verifying JWT
69
73
conn <- WS. acceptRequest pendingConn
70
74
-- Fork a pinging thread to ensure browser connections stay alive
71
- WS. forkPingThread conn 30
75
+ WS. withPingThread conn 30 (pure () ) $ do
76
+ case M. lookup " exp" validClaims of
77
+ Just (A. Number expClaim) -> do
78
+ connectionExpirer <- newAlarmClock $ const (WS. sendCloseCode conn jwtExpirationStatusCode (" JWT expired" :: ByteString ))
79
+ setAlarm connectionExpirer (posixSecondsToUTCTime $ realToFrac expClaim)
80
+ Just _ -> pure ()
81
+ Nothing -> pure ()
72
82
73
- when (hasRead mode) $
74
- onMessage multi ch $ WS. sendTextData conn . B. payload
83
+ when (hasRead mode) $
84
+ onMessage multi ch $ WS. sendTextData conn . B. payload
75
85
76
- when (hasWrite mode) $
77
- let sendNotifications = void . H. notifyPool pool dbChannel . toS
78
- in notifySession validClaims (toS ch) conn sendNotifications
86
+ when (hasWrite mode) $
87
+ let sendNotifications = void . H. notifyPool pool dbChannel . toS
88
+ in notifySession validClaims (toS ch) conn getTime sendNotifications
79
89
80
- waitForever <- newEmptyMVar
81
- void $ takeMVar waitForever
90
+ waitForever <- newEmptyMVar
91
+ void $ takeMVar waitForever
82
92
83
93
-- Having both channel and claims as parameters seem redundant
84
94
-- But it allows the function to ignore the claims structure and the source
85
95
-- of the channel, so all claims decoding can be coded in the caller
86
96
notifySession :: A. Object
87
97
-> Text
88
98
-> WS. Connection
99
+ -> IO UTCTime
89
100
-> (ByteString -> IO () )
90
101
-> IO ()
91
- notifySession claimsToSend ch wsCon send =
102
+ notifySession claimsToSend ch wsCon getTime send =
92
103
withAsync (forever relayData) wait
93
104
where
94
105
relayData = jsonMsgWithTime >>= send
@@ -102,5 +113,5 @@ notifySession claimsToSend ch wsCon send =
102
113
claimsWithChannel = M. insert " channel" (A. String ch) claimsToSend
103
114
claimsWithTime :: IO (M. HashMap Text A. Value )
104
115
claimsWithTime = do
105
- time <- getPOSIXTime
106
- return $ M. insert " message_delivered_at" (A. Number $ fromRational $ toRational time) claimsWithChannel
116
+ time <- utcTimeToPOSIXSeconds <$> getTime
117
+ return $ M. insert " message_delivered_at" (A. Number $ realToFrac time) claimsWithChannel
0 commit comments