Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions src/Database/Redis/ConnectionContext.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LB
import qualified Data.IORef as IOR
import qualified Data.Time as Time
import Data.List (drop, length, take)
import Data.Maybe (fromMaybe)
import Control.Concurrent.MVar(newMVar, readMVar, swapMVar)
import Control.Exception(bracketOnError, Exception, throwIO, try)
Expand All @@ -34,6 +35,7 @@ import System.IO(Handle, hSetBinaryMode, hClose, IOMode(..), hFlush, hIsOpen)
import System.IO.Error(catchIOError)
import Text.Read (readMaybe)
import System.Timeout (timeout)
import System.Random (randomRIO)

data ConnectionContext = NormalHandle Handle | TLSContext TLS.Context

Expand Down Expand Up @@ -109,21 +111,28 @@ errConnectTimeout phase = throwIO $ ConnectTimeout phase

connectSocket :: [NS.AddrInfo] -> IO NS.Socket
connectSocket [] = error "connectSocket: unexpected empty list"
connectSocket (addr:rest) = tryConnect >>= \case
Right sock -> return sock
Left err -> if null rest
then throwIO err
else connectSocket rest
connectSocket addrs = do
(addr, rest) <- pickRandomAddr
sockM <- tryConnect addr
case (sockM, rest) of
(Right sock, _) -> return sock
(Left err, []) -> throwIO err
(Left _err, _) -> connectSocket rest
where
tryConnect :: IO (Either IOError NS.Socket)
tryConnect = bracketOnError createSock NS.close $ \sock ->
tryConnect :: NS.AddrInfo -> IO (Either IOError NS.Socket)
tryConnect addr = bracketOnError createSock NS.close $ \sock ->
try (NS.connect sock $ NS.addrAddress addr) >>= \case
Right () -> return (Right sock)
Left err -> NS.close sock >> return (Left err)
where
createSock = NS.socket (NS.addrFamily addr)
(NS.addrSocketType addr)
(NS.addrProtocol addr)
pickRandomAddr = do
i <- randomRIO (0, length addrs - 1)
let front = take (if i<2 then 0 else i-1) addrs
back = drop (i+1) addrs
pure (addrs !! i, front ++ back)

send :: ConnectionContext -> B.ByteString -> IO ()
send (NormalHandle h) requestData =
Expand Down