diff --git a/src/Database/Redis/ConnectionContext.hs b/src/Database/Redis/ConnectionContext.hs index e82796a..a21672b 100644 --- a/src/Database/Redis/ConnectionContext.hs +++ b/src/Database/Redis/ConnectionContext.hs @@ -22,6 +22,7 @@ import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as LB import qualified Data.IORef as IOR import qualified Data.Time as Time +import Data.List (drop, length, take) import Data.Maybe (fromMaybe) import Control.Concurrent.MVar(newMVar, readMVar, swapMVar) import Control.Exception(bracketOnError, Exception, throwIO, try) @@ -34,6 +35,7 @@ import System.IO(Handle, hSetBinaryMode, hClose, IOMode(..), hFlush, hIsOpen) import System.IO.Error(catchIOError) import Text.Read (readMaybe) import System.Timeout (timeout) +import System.Random (randomRIO) data ConnectionContext = NormalHandle Handle | TLSContext TLS.Context @@ -109,14 +111,16 @@ errConnectTimeout phase = throwIO $ ConnectTimeout phase connectSocket :: [NS.AddrInfo] -> IO NS.Socket connectSocket [] = error "connectSocket: unexpected empty list" -connectSocket (addr:rest) = tryConnect >>= \case - Right sock -> return sock - Left err -> if null rest - then throwIO err - else connectSocket rest +connectSocket addrs = do + (addr, rest) <- pickRandomAddr + sockM <- tryConnect addr + case (sockM, rest) of + (Right sock, _) -> return sock + (Left err, []) -> throwIO err + (Left _err, _) -> connectSocket rest where - tryConnect :: IO (Either IOError NS.Socket) - tryConnect = bracketOnError createSock NS.close $ \sock -> + tryConnect :: NS.AddrInfo -> IO (Either IOError NS.Socket) + tryConnect addr = bracketOnError createSock NS.close $ \sock -> try (NS.connect sock $ NS.addrAddress addr) >>= \case Right () -> return (Right sock) Left err -> NS.close sock >> return (Left err) @@ -124,6 +128,11 @@ connectSocket (addr:rest) = tryConnect >>= \case createSock = NS.socket (NS.addrFamily addr) (NS.addrSocketType addr) (NS.addrProtocol addr) + pickRandomAddr = do + i <- randomRIO (0, length addrs - 1) + let front = take (if i<2 then 0 else i-1) addrs + back = drop (i+1) addrs + pure (addrs !! i, front ++ back) send :: ConnectionContext -> B.ByteString -> IO () send (NormalHandle h) requestData =