1- {-# LANGUAGE DerivingStrategies #-}
2- {-# LANGUAGE GADTs #-}
3- {-# LANGUAGE NamedFieldPuns #-}
4- {-# LANGUAGE ScopedTypeVariables #-}
5- {-# LANGUAGE StandaloneDeriving #-}
1+ {-# LANGUAGE BangPatterns #-}
2+ {-# LANGUAGE DerivingStrategies #-}
3+ {-# LANGUAGE DisambiguateRecordFields #-}
4+ {-# LANGUAGE FlexibleContexts #-}
5+ {-# LANGUAGE GADTs #-}
6+ {-# LANGUAGE NamedFieldPuns #-}
7+ {-# LANGUAGE ScopedTypeVariables #-}
8+ {-# LANGUAGE StandaloneDeriving #-}
69
710-- | The module should be imported qualified.
811--
@@ -19,24 +22,22 @@ module Ouroboros.Network.TxSubmission.Mempool.Simple
1922import Prelude hiding (read , seq )
2023
2124import Control.Concurrent.Class.MonadSTM.Strict
22- import Control.Monad (when )
2325import Control.Monad.Class.MonadThrow
24-
26+ import Control.Monad.Trans.Except
2527import Data.Bifunctor (bimap )
26- import Data.Either ( partitionEithers )
28+ import Data.Either
2729import Data.Foldable (toList )
2830import Data.Foldable qualified as Foldable
29- import Data.Function (on )
30- import Data.List (find , nubBy )
31- import Data.Maybe (isJust )
31+ import Data.List (find )
32+ import Data.Maybe (isJust , isNothing , fromJust )
3233import Data.Sequence (Seq )
3334import Data.Sequence qualified as Seq
3435import Data.Set (Set )
3536import Data.Set qualified as Set
3637import Data.Typeable (Typeable )
3738
39+ import Ouroboros.Network.Protocol.LocalTxSubmission.Type (SubmitResult (.. ))
3840import Ouroboros.Network.SizeInBytes
39- import Ouroboros.Network.TxSubmission.Inbound.V2.Types
4041import Ouroboros.Network.TxSubmission.Mempool.Reader
4142
4243
@@ -120,11 +121,35 @@ deriving instance Show InvalidTxsError
120121instance Exception InvalidTxsError
121122
122123
123- -- | A simple mempool writer.
124+ -- | A mempool writer which generalizes the tx submission mempool writer
125+ -- TODO: We could replace TxSubmissionMempoolWriter with this at some point
126+ --
127+ data MempoolWriter txid tx failure idx m =
128+ MempoolWriter {
129+
130+ -- | Compute the transaction id from a transaction.
131+ --
132+ -- This is used in the protocol handler to verify a full transaction
133+ -- matches a previously given transaction id.
134+ --
135+ txId :: tx -> txid ,
136+
137+ -- | Supply a batch of transactions to the mempool. They are either
138+ -- accepted or rejected individually, but in the order supplied.
139+ --
140+ -- The 'txid's of all transactions that were added successfully are
141+ -- returned.
142+ mempoolAddTxs :: [tx ] -> m [SubmitResult failure ]
143+ }
144+
145+
146+ -- | A mempool writer with validation harness
147+ -- PRECONDITION: no duplicates given to mempoolAddTxs
124148--
125149getWriter :: forall tx txid ctx failure m .
126150 ( MonadSTM m
127- , MonadThrow m
151+ , Exception failure
152+ , MonadThrow (STM m )
128153 , Ord txid
129154 , Typeable txid
130155 , Typeable failure
@@ -134,40 +159,38 @@ getWriter :: forall tx txid ctx failure m.
134159 => (tx -> txid )
135160 -- ^ get txid of a tx
136161 -> m ctx
137- -- ^ monadic validation ctx
138- -> (ctx -> tx -> Either failure () )
139- -- ^ validate a tx, any failing `tx` throws an exception.
140- -> ( failure -> Bool )
141- -- ^ return `True` when a failure should throw an exception
162+ -- ^ acquire validation context
163+ -> ([ tx ] -> ctx -> [ Except failure ( Either failure () )] )
164+ -- ^ validation function
165+ -> Maybe failure
166+ -- ^ replace duplicates if Just
142167 -> Mempool m txid tx
143- -> TxSubmissionMempoolWriter txid tx Int m
144- getWriter getTxId getValidationCtx validateTx failureFilterFn (Mempool mempool) =
145- TxSubmissionMempoolWriter {
146- txId = getTxId,
147-
148- mempoolAddTxs = \ txs -> do
149- ctx <- getValidationCtx
150- (invalidTxIds, validTxs) <- atomically $ do
151- MempoolSeq { mempoolSet, mempoolSeq } <- readTVar mempool
152- let (invalidTxIds, validTxs) =
153- bimap (filter (failureFilterFn . snd ))
154- (nubBy (on (==) getTxId))
155- . partitionEithers
156- . map (\ tx -> case validateTx ctx tx of
157- Left e -> Left (getTxId tx, e)
158- Right _ -> Right tx
159- )
160- . filter (\ tx -> getTxId tx `Set.notMember` mempoolSet)
161- $ txs
162- mempoolTxs' = MempoolSeq {
163- mempoolSet = Foldable. foldl' (\ s tx -> getTxId tx `Set.insert` s)
164- mempoolSet
165- validTxs,
166- mempoolSeq = Foldable. foldl' (Seq. |>) mempoolSeq validTxs
167- }
168- writeTVar mempool mempoolTxs'
169- return (invalidTxIds, map getTxId validTxs)
170- when (not (null invalidTxIds)) $
171- throwIO (InvalidTxsError invalidTxIds)
172- return validTxs
173- }
168+ -> MempoolWriter txid tx failure Int m
169+ getWriter getTxId acquireCtx validateTxs mDuplicate (Mempool mempool) =
170+ MempoolWriter {
171+ txId = getTxId,
172+
173+ mempoolAddTxs = \ txs -> do
174+ ctx <- acquireCtx
175+ let ! vTxs = validateTxs txs ctx
176+ atomically $ do
177+ MempoolSeq { mempoolSet, mempoolSeq } <- readTVar mempool
178+ result <- sequence
179+ [if duplicate then
180+ pure . Left . SubmitFail . fromJust $ mDuplicate
181+ else case runExcept vtx of
182+ Left e -> throwSTM e
183+ Right result -> pure $ bimap SubmitFail (const (txid, tx)) result
184+ | (tx, vtx) <- zip txs vTxs
185+ , let txid = getTxId tx
186+ duplicate = txid `Set.member` mempoolSet
187+ , not (duplicate && isNothing mDuplicate)
188+ ]
189+ let (validIds, validTxs) = unzip . rights $ result
190+ mempoolTxs' = MempoolSeq {
191+ mempoolSet = Set. union mempoolSet (Set. fromList validIds),
192+ mempoolSeq = Foldable. foldl' (Seq. |>) mempoolSeq validTxs
193+ }
194+ writeTVar mempool mempoolTxs'
195+ return $ fromLeft SubmitSuccess <$> result
196+ }
0 commit comments