Skip to content

Commit 7bbd681

Browse files
mempool: adapt for generalized validation
1 parent d6a9748 commit 7bbd681

File tree

1 file changed

+73
-50
lines changed
  • ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool

1 file changed

+73
-50
lines changed

ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool/Simple.hs

Lines changed: 73 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
{-# LANGUAGE DerivingStrategies #-}
2-
{-# LANGUAGE GADTs #-}
3-
{-# LANGUAGE NamedFieldPuns #-}
4-
{-# LANGUAGE ScopedTypeVariables #-}
5-
{-# LANGUAGE StandaloneDeriving #-}
1+
{-# LANGUAGE BangPatterns #-}
2+
{-# LANGUAGE DerivingStrategies #-}
3+
{-# LANGUAGE DisambiguateRecordFields #-}
4+
{-# LANGUAGE FlexibleContexts #-}
5+
{-# LANGUAGE GADTs #-}
6+
{-# LANGUAGE NamedFieldPuns #-}
7+
{-# LANGUAGE ScopedTypeVariables #-}
8+
{-# LANGUAGE StandaloneDeriving #-}
69

710
-- | The module should be imported qualified.
811
--
@@ -19,24 +22,22 @@ module Ouroboros.Network.TxSubmission.Mempool.Simple
1922
import Prelude hiding (read, seq)
2023

2124
import Control.Concurrent.Class.MonadSTM.Strict
22-
import Control.Monad (when)
2325
import Control.Monad.Class.MonadThrow
24-
26+
import Control.Monad.Trans.Except
2527
import Data.Bifunctor (bimap)
26-
import Data.Either (partitionEithers)
28+
import Data.Either
2729
import Data.Foldable (toList)
2830
import Data.Foldable qualified as Foldable
29-
import Data.Function (on)
30-
import Data.List (find, nubBy)
31-
import Data.Maybe (isJust)
31+
import Data.List (find)
32+
import Data.Maybe (isJust, isNothing, fromJust)
3233
import Data.Sequence (Seq)
3334
import Data.Sequence qualified as Seq
3435
import Data.Set (Set)
3536
import Data.Set qualified as Set
3637
import Data.Typeable (Typeable)
3738

39+
import Ouroboros.Network.Protocol.LocalTxSubmission.Type (SubmitResult (..))
3840
import Ouroboros.Network.SizeInBytes
39-
import Ouroboros.Network.TxSubmission.Inbound.V2.Types
4041
import Ouroboros.Network.TxSubmission.Mempool.Reader
4142

4243

@@ -120,11 +121,35 @@ deriving instance Show InvalidTxsError
120121
instance Exception InvalidTxsError
121122

122123

123-
-- | A simple mempool writer.
124+
-- | A mempool writer which generalizes the tx submission mempool writer
125+
-- TODO: We could replace TxSubmissionMempoolWriter with this at some point
126+
--
127+
data MempoolWriter txid tx failure idx m =
128+
MempoolWriter {
129+
130+
-- | Compute the transaction id from a transaction.
131+
--
132+
-- This is used in the protocol handler to verify a full transaction
133+
-- matches a previously given transaction id.
134+
--
135+
txId :: tx -> txid,
136+
137+
-- | Supply a batch of transactions to the mempool. They are either
138+
-- accepted or rejected individually, but in the order supplied.
139+
--
140+
-- The 'txid's of all transactions that were added successfully are
141+
-- returned.
142+
mempoolAddTxs :: [tx] -> m [SubmitResult failure]
143+
}
144+
145+
146+
-- | A mempool writer with validation harness
147+
-- PRECONDITION: no duplicates given to mempoolAddTxs
124148
--
125149
getWriter :: forall tx txid ctx failure m.
126150
( MonadSTM m
127-
, MonadThrow m
151+
, Exception failure
152+
, MonadThrow (STM m)
128153
, Ord txid
129154
, Typeable txid
130155
, Typeable failure
@@ -134,40 +159,38 @@ getWriter :: forall tx txid ctx failure m.
134159
=> (tx -> txid)
135160
-- ^ get txid of a tx
136161
-> m ctx
137-
-- ^ monadic validation ctx
138-
-> (ctx -> tx -> Either failure ())
139-
-- ^ validate a tx, any failing `tx` throws an exception.
140-
-> (failure -> Bool)
141-
-- ^ return `True` when a failure should throw an exception
162+
-- ^ acquire validation context
163+
-> ([tx] -> ctx -> [Except failure (Either failure ())])
164+
-- ^ validation function
165+
-> Maybe failure
166+
-- ^ replace duplicates if Just
142167
-> Mempool m txid tx
143-
-> TxSubmissionMempoolWriter txid tx Int m
144-
getWriter getTxId getValidationCtx validateTx failureFilterFn (Mempool mempool) =
145-
TxSubmissionMempoolWriter {
146-
txId = getTxId,
147-
148-
mempoolAddTxs = \txs -> do
149-
ctx <- getValidationCtx
150-
(invalidTxIds, validTxs) <- atomically $ do
151-
MempoolSeq { mempoolSet, mempoolSeq } <- readTVar mempool
152-
let (invalidTxIds, validTxs) =
153-
bimap (filter (failureFilterFn . snd))
154-
(nubBy (on (==) getTxId))
155-
. partitionEithers
156-
. map (\tx -> case validateTx ctx tx of
157-
Left e -> Left (getTxId tx, e)
158-
Right _ -> Right tx
159-
)
160-
. filter (\tx -> getTxId tx `Set.notMember` mempoolSet)
161-
$ txs
162-
mempoolTxs' = MempoolSeq {
163-
mempoolSet = Foldable.foldl' (\s tx -> getTxId tx `Set.insert` s)
164-
mempoolSet
165-
validTxs,
166-
mempoolSeq = Foldable.foldl' (Seq.|>) mempoolSeq validTxs
167-
}
168-
writeTVar mempool mempoolTxs'
169-
return (invalidTxIds, map getTxId validTxs)
170-
when (not (null invalidTxIds)) $
171-
throwIO (InvalidTxsError invalidTxIds)
172-
return validTxs
173-
}
168+
-> MempoolWriter txid tx failure Int m
169+
getWriter getTxId acquireCtx validateTxs mDuplicate (Mempool mempool) =
170+
MempoolWriter {
171+
txId = getTxId,
172+
173+
mempoolAddTxs = \txs -> do
174+
ctx <- acquireCtx
175+
let !vTxs = validateTxs txs ctx
176+
atomically $ do
177+
MempoolSeq { mempoolSet, mempoolSeq } <- readTVar mempool
178+
result <- sequence
179+
[if duplicate then
180+
pure . Left . SubmitFail . fromJust $ mDuplicate
181+
else case runExcept vtx of
182+
Left e -> throwSTM e
183+
Right result -> pure $ bimap SubmitFail (const (txid, tx)) result
184+
| (tx, vtx) <- zip txs vTxs
185+
, let txid = getTxId tx
186+
duplicate = txid `Set.member` mempoolSet
187+
, not (duplicate && isNothing mDuplicate)
188+
]
189+
let (validIds, validTxs) = unzip . rights $ result
190+
mempoolTxs' = MempoolSeq {
191+
mempoolSet = Set.union mempoolSet (Set.fromList validIds),
192+
mempoolSeq = Foldable.foldl' (Seq.|>) mempoolSeq validTxs
193+
}
194+
writeTVar mempool mempoolTxs'
195+
return $ fromLeft SubmitSuccess <$> result
196+
}

0 commit comments

Comments
 (0)