Skip to content

Commit 045b62b

Browse files
maxkozlovskyjhunsaker
authored andcommitted
Optimize MPT node allocation with custom allocator and SharedPtr
Replace UniquePtr→SharedPtr conversions with direct SharedPtr creation using std::allocate_shared and custom variable_size_allocator. This eliminates double memory allocations (separate control block and object) by allocating both in a single block. - Add variable_size_allocator for single-allocation SharedPtr creation - Create make_shared Node creation functions returning SharedPtr - Modify existing Node creation function to return SharedPtr Some functions now have both shared and unique ptr versions for optimal usage Partially generated using Claude Sonnet 4.5
1 parent 034e382 commit 045b62b

15 files changed

+171
-60
lines changed

category/core/mem/allocators.hpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,100 @@ namespace allocators
292292
throw;
293293
}
294294
}
295+
296+
/**************************************************************************/
297+
//! \brief A STL allocator for use with `std::allocate_shared` that
298+
//! allocates extra bytes beyond sizeof(T) for trailing variable-length
299+
//! data.
300+
//!
301+
//! This allocator is designed for types with flexible array members or
302+
//! trailing data. When used with `std::allocate_shared`, it ensures the
303+
//! control block and object with trailing data are allocated together.
304+
//!
305+
//! \tparam T The type to allocate
306+
template <typename T>
307+
struct variable_size_allocator
308+
{
309+
using value_type = T;
310+
using size_type = std::size_t;
311+
using difference_type = std::ptrdiff_t;
312+
313+
//! \brief Construct allocator with total storage size
314+
//! \param storage_bytes Total bytes needed (sizeof(T) + trailing data)
315+
//!
316+
//! The extra_bytes_ member stores the additional bytes beyond sizeof(T)
317+
//! needed for trailing data (path, value, child data, etc.)
318+
explicit variable_size_allocator(size_t storage_bytes) noexcept
319+
: extra_bytes_(storage_bytes - sizeof(T))
320+
{
321+
MONAD_ASSERT(storage_bytes >= sizeof(T));
322+
}
323+
324+
//! \brief Rebind constructor for allocator conversion
325+
template <typename U>
326+
// NOLINTNEXTLINE(google-explicit-constructor)
327+
variable_size_allocator(
328+
variable_size_allocator<U> const &other) noexcept
329+
: extra_bytes_(other.extra_bytes_)
330+
{
331+
}
332+
333+
//! \brief Allocate memory for n objects of type T plus extra bytes
334+
//! \param n Number of objects to allocate (must be 1)
335+
//!
336+
//! For std::allocate_shared:
337+
//! - If T is the object type: allocates sizeof(T) + extra_bytes
338+
//! - If T is control block: allocates sizeof(control_block) +
339+
//! extra_bytes
340+
//! (control block already includes sizeof(object), so this gives
341+
//! control block + object trailing data)
342+
[[nodiscard]] T *allocate(size_type n)
343+
{
344+
MONAD_ASSERT(n == 1);
345+
size_t const bytes = sizeof(T) + extra_bytes_;
346+
347+
if constexpr (alignof(T) > alignof(max_align_t)) {
348+
return reinterpret_cast<T *>(
349+
::operator new(bytes, std::align_val_t{alignof(T)}));
350+
}
351+
return reinterpret_cast<T *>(::operator new(bytes));
352+
}
353+
354+
//! \brief Deallocate memory
355+
void deallocate(T *p, size_type) noexcept
356+
{
357+
if constexpr (alignof(T) > alignof(max_align_t)) {
358+
::operator delete(p, std::align_val_t{alignof(T)});
359+
}
360+
else {
361+
::operator delete(p);
362+
}
363+
}
364+
365+
//! \brief Rebind to allocate different types
366+
template <typename U>
367+
struct rebind
368+
{
369+
using other = variable_size_allocator<U>;
370+
};
371+
372+
bool operator==(variable_size_allocator const &other) const noexcept
373+
{
374+
return extra_bytes_ == other.extra_bytes_;
375+
}
376+
377+
bool operator!=(variable_size_allocator const &other) const noexcept
378+
{
379+
return !(*this == other);
380+
}
381+
382+
template <typename U>
383+
friend struct variable_size_allocator;
384+
385+
private:
386+
//! Extra bytes beyond sizeof(T) for trailing data
387+
size_t extra_bytes_;
388+
};
295389
}
296390

297391
MONAD_NAMESPACE_END

category/mpt/copy_trie.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@
3333

3434
MONAD_MPT_NAMESPACE_BEGIN
3535

36-
Node::UniquePtr create_node_add_new_branch(
36+
Node::SharedPtr create_node_add_new_branch(
3737
UpdateAuxImpl &aux, Node *const node, unsigned char const new_branch,
38-
Node::UniquePtr new_child, uint64_t const new_version,
38+
Node::SharedPtr new_child, uint64_t const new_version,
3939
std::optional<byte_string_view> opt_value)
4040
{
4141
uint16_t const mask =
@@ -80,9 +80,9 @@ Node::UniquePtr create_node_add_new_branch(
8080
static_cast<int64_t>(new_version));
8181
}
8282

83-
Node::UniquePtr create_node_with_two_children(
83+
Node::SharedPtr create_node_with_two_children(
8484
UpdateAuxImpl &aux, NibblesView const path, unsigned char const branch0,
85-
Node::UniquePtr child0, unsigned char const branch1, Node::UniquePtr child1,
85+
Node::SharedPtr child0, unsigned char const branch1, Node::SharedPtr child1,
8686
uint64_t const new_version, std::optional<byte_string_view> opt_value)
8787
{
8888
// mismatch: split node's path: turn node to a branch node with two
@@ -158,7 +158,7 @@ Node::SharedPtr copy_trie_impl(
158158
Node *parent = nullptr;
159159
unsigned char branch = INVALID_BRANCH;
160160
Node::SharedPtr node = dest_root;
161-
Node::UniquePtr new_node{};
161+
Node::SharedPtr new_node{};
162162
unsigned prefix_index = 0;
163163
unsigned node_prefix_index = 0;
164164

@@ -216,7 +216,7 @@ Node::SharedPtr copy_trie_impl(
216216
if (node->mask & (1u << nibble)) {
217217
auto const index = node->to_child_index(nibble);
218218
if (node->next(index) == nullptr) {
219-
Node::UniquePtr next_node_ondisk =
219+
auto next_node_ondisk =
220220
read_node_blocking(aux, node->fnext(index), dest_version);
221221
MONAD_ASSERT(next_node_ondisk != nullptr);
222222
node->set_next(index, std::move(next_node_ondisk));

category/mpt/deserialize_node_from_receiver_result.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,12 @@ namespace detail
5454
}
5555

5656
template <class NodeType, class ResultType>
57-
inline NodeType::UniquePtr deserialize_node_from_receiver_result(
57+
inline NodeType::SharedPtr deserialize_node_from_receiver_result(
5858
ResultType buffer_, uint16_t buffer_off,
5959
MONAD_ASYNC_NAMESPACE::erased_connected_operation *io_state)
6060
{
6161
MONAD_ASSERT(buffer_);
62-
typename NodeType::UniquePtr node;
62+
typename NodeType::SharedPtr node;
6363
if constexpr (std::is_same_v<
6464
std::decay_t<ResultType>,
6565
typename monad::async::read_single_buffer_sender::

category/mpt/find.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ find_cursor_result_type find_blocking(
5252
MONAD_ASSERT(aux.is_on_disk());
5353
auto g2(g.upgrade());
5454
if (g2.upgrade_was_atomic() || !node->next(idx)) {
55-
Node::UniquePtr next_node_ondisk =
55+
auto next_node_ondisk =
5656
read_node_blocking(aux, node->fnext(idx), version);
5757
if (!next_node_ondisk) {
5858
return {

category/mpt/node.cpp

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ void ChildData::erase()
456456
}
457457

458458
void ChildData::finalize(
459-
Node::UniquePtr node, Compute &compute, bool const cache)
459+
Node::SharedPtr node, Compute &compute, bool const cache)
460460
{
461461
MONAD_DEBUG_ASSERT(is_valid());
462462
ptr = std::move(node);
@@ -488,13 +488,13 @@ void ChildData::copy_old_child(Node *const old, unsigned const i)
488488
MONAD_DEBUG_ASSERT(is_valid());
489489
}
490490

491-
Node::UniquePtr make_node(
491+
Node::SharedPtr make_node(
492492
Node &from, NibblesView const path,
493493
std::optional<byte_string_view> const value, int64_t const version)
494494
{
495495
auto const value_size =
496496
value.transform(&byte_string_view::size).value_or(0);
497-
auto node = Node::make(
497+
auto node = Node::make_shared(
498498
calculate_node_size(
499499
from.number_of_children(),
500500
from.child_data_len(),
@@ -530,17 +530,11 @@ Node::UniquePtr make_node(
530530
return node;
531531
}
532532

533-
Node::UniquePtr make_node(
533+
Node::SharedPtr make_node(
534534
uint16_t const mask, std::span<ChildData> const children,
535535
NibblesView const path, std::optional<byte_string_view> const value,
536536
size_t const data_size, int64_t const version)
537537
{
538-
MONAD_DEBUG_ASSERT(data_size <= KECCAK256_SIZE);
539-
if (value.has_value()) {
540-
MONAD_DEBUG_ASSERT(
541-
value->size() <=
542-
std::numeric_limits<decltype(Node::value_len)>::max());
543-
}
544538
for (size_t i = 0; i < 16; ++i) {
545539
MONAD_DEBUG_ASSERT(
546540
!std::ranges::contains(children, i, &ChildData::branch) ||
@@ -558,7 +552,7 @@ Node::UniquePtr make_node(
558552
}
559553
}
560554

561-
auto node = Node::make(
555+
auto node = Node::make_shared(
562556
calculate_node_size(
563557
number_of_children,
564558
total_child_data_size,
@@ -596,7 +590,7 @@ Node::UniquePtr make_node(
596590
return node;
597591
}
598592

599-
Node::UniquePtr make_node(
593+
Node::SharedPtr make_node(
600594
uint16_t const mask, std::span<ChildData> const children,
601595
NibblesView const path, std::optional<byte_string_view> const value,
602596
byte_string_view const data, int64_t const version)
@@ -608,7 +602,7 @@ Node::UniquePtr make_node(
608602

609603
// all children's offset are set before creating parent
610604
// create node with at least one child
611-
Node::UniquePtr create_node_with_children(
605+
Node::SharedPtr create_node_with_children(
612606
Compute &comp, uint16_t const mask, std::span<ChildData> const children,
613607
NibblesView const path, std::optional<byte_string_view> const value,
614608
int64_t const version)

category/mpt/node.hpp

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class NodeBase
158158
max_disk_size + max_number_of_children * KECCAK256_SIZE;
159159

160160
template <node_type DestNodeType, node_type SrcNodeType>
161-
friend DestNodeType::UniquePtr copy_node(SrcNodeType const *const from);
161+
friend DestNodeType::SharedPtr copy_node(SrcNodeType const *const from);
162162

163163
/* 16-bit mask for children */
164164
uint16_t mask{0};
@@ -327,6 +327,16 @@ class Node final : public NodeBase
327327
std::forward<Args>(args)...);
328328
}
329329

330+
template <class... Args>
331+
static SharedPtr make_shared(size_t bytes, Args &&...args)
332+
{
333+
allocators::variable_size_allocator<Node> alloc(bytes);
334+
return std::allocate_shared<Node>(
335+
alloc,
336+
prevent_public_construction_tag{},
337+
std::forward<Args>(args)...);
338+
}
339+
330340
SharedPtr *child_ptr(unsigned index) noexcept;
331341
SharedPtr const *child_ptr(unsigned index) const noexcept;
332342

@@ -351,6 +361,7 @@ class CacheNode final : public NodeBase
351361
using Deleter = allocators::unique_ptr_aliasing_allocator_deleter<
352362
&allocators::aliasing_allocator_pair<CacheNode>>;
353363
using UniquePtr = std::unique_ptr<CacheNode, Deleter>;
364+
using SharedPtr = std::shared_ptr<CacheNode>;
354365

355366
CacheNode(prevent_public_construction_tag)
356367
: NodeBase()
@@ -368,6 +379,16 @@ class CacheNode final : public NodeBase
368379
std::forward<Args>(args)...);
369380
}
370381

382+
template <class... Args>
383+
static SharedPtr make_shared(size_t bytes, Args &&...args)
384+
{
385+
allocators::variable_size_allocator<CacheNode> alloc(bytes);
386+
return std::allocate_shared<CacheNode>(
387+
alloc,
388+
prevent_public_construction_tag{},
389+
std::forward<Args>(args)...);
390+
}
391+
371392
void *next(size_t const index) const noexcept;
372393
void set_next(unsigned const index, void *const ptr) noexcept;
373394

@@ -398,7 +419,7 @@ struct ChildData
398419

399420
bool is_valid() const;
400421
void erase();
401-
void finalize(Node::UniquePtr, Compute &, bool cache);
422+
void finalize(Node::SharedPtr, Compute &, bool cache);
402423
void copy_old_child(Node *old, unsigned i);
403424
};
404425

@@ -431,21 +452,21 @@ constexpr size_t MAX_VALUE_LEN_OF_LEAF =
431452
0 /* number_of_children */, 0 /* child_data_size */, 0 /* value_size */,
432453
KECCAK256_SIZE /* path_size */, KECCAK256_SIZE /* data_size*/);
433454

434-
Node::UniquePtr make_node(
455+
Node::SharedPtr make_node(
435456
Node &from, NibblesView path, std::optional<byte_string_view> value,
436457
int64_t version);
437458

438-
Node::UniquePtr make_node(
459+
Node::SharedPtr make_node(
439460
uint16_t mask, std::span<ChildData>, NibblesView path,
440461
std::optional<byte_string_view> value, size_t data_size, int64_t version);
441462

442-
Node::UniquePtr make_node(
463+
Node::SharedPtr make_node(
443464
uint16_t mask, std::span<ChildData>, NibblesView path,
444465
std::optional<byte_string_view> value, byte_string_view data,
445466
int64_t version);
446467

447468
// create node: either branch/extension, with or without leaf
448-
Node::UniquePtr create_node_with_children(
469+
Node::SharedPtr create_node_with_children(
449470
Compute &, uint16_t mask, std::span<ChildData> children, NibblesView path,
450471
std::optional<byte_string_view> value, int64_t version);
451472

@@ -454,7 +475,7 @@ void serialize_node_to_buffer(
454475
uint32_t disk_size, unsigned offset = 0);
455476

456477
template <node_type NodeType>
457-
inline NodeType::UniquePtr
478+
inline NodeType::SharedPtr
458479
deserialize_node_from_buffer(unsigned char const *read_pos, size_t max_bytes)
459480
{
460481
for (size_t n = 0; n < max_bytes; n += 64) {
@@ -473,7 +494,7 @@ deserialize_node_from_buffer(unsigned char const *read_pos, size_t max_bytes)
473494
if constexpr (std::same_as<NodeType, Node>) {
474495
auto const alloc_size = round_up_align<3>(base_size) +
475496
number_of_children * sizeof(Node::SharedPtr);
476-
auto node = NodeType::make(alloc_size);
497+
auto node = NodeType::make_shared(alloc_size);
477498
std::copy_n(read_pos, base_size, (unsigned char *)node.get());
478499
for (unsigned i = 0; i < node->number_of_children(); ++i) {
479500
new (node->child_ptr(i)) Node::SharedPtr();
@@ -484,7 +505,7 @@ deserialize_node_from_buffer(unsigned char const *read_pos, size_t max_bytes)
484505
else {
485506
auto const alloc_size = round_up_align<3>(base_size) +
486507
number_of_children * sizeof(NodeType *);
487-
auto node = NodeType::make(alloc_size);
508+
auto node = NodeType::make_shared(alloc_size);
488509
std::copy_n(read_pos, base_size, (unsigned char *)node.get());
489510
std::memset(
490511
node->next_data_aligned(),
@@ -496,7 +517,7 @@ deserialize_node_from_buffer(unsigned char const *read_pos, size_t max_bytes)
496517
}
497518

498519
template <node_type DestNodeType, node_type SrcNodeType>
499-
DestNodeType::UniquePtr copy_node(SrcNodeType const *const from)
520+
DestNodeType::SharedPtr copy_node(SrcNodeType const *const from)
500521
{
501522
auto const number_of_children = from->number_of_children();
502523
auto const base_size = static_cast<unsigned>(
@@ -505,7 +526,7 @@ DestNodeType::UniquePtr copy_node(SrcNodeType const *const from)
505526
if constexpr (std::same_as<DestNodeType, Node>) {
506527
auto const alloc_size = round_up_align<3>(base_size) +
507528
number_of_children * sizeof(Node::SharedPtr);
508-
auto node_copy = DestNodeType::make(alloc_size);
529+
auto node_copy = DestNodeType::make_shared(alloc_size);
509530
std::copy_n(
510531
(unsigned char *)from, base_size, (unsigned char *)node_copy.get());
511532
for (unsigned i = 0; i < number_of_children; ++i) {
@@ -516,7 +537,7 @@ DestNodeType::UniquePtr copy_node(SrcNodeType const *const from)
516537
else {
517538
auto const next_ptrs_size = number_of_children * sizeof(void *);
518539
auto const alloc_size = round_up_align<3>(base_size) + next_ptrs_size;
519-
auto node_copy = DestNodeType::make(alloc_size);
540+
auto node_copy = DestNodeType::make_shared(alloc_size);
520541
std::copy_n(
521542
(unsigned char *)from, base_size, (unsigned char *)node_copy.get());
522543
// reset all in memory children

category/mpt/read_node_blocking.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
MONAD_MPT_NAMESPACE_BEGIN
3131

32-
Node::UniquePtr read_node_blocking(
32+
Node::SharedPtr read_node_blocking(
3333
UpdateAuxImpl const &aux, chunk_offset_t const node_offset,
3434
uint64_t const version)
3535
{
@@ -69,7 +69,7 @@ Node::UniquePtr read_node_blocking(
6969
return aux.version_is_valid_ondisk(version)
7070
? deserialize_node_from_buffer<Node>(
7171
buffer + buffer_off, size_t(bytes_read) - buffer_off)
72-
: Node::UniquePtr{};
72+
: Node::SharedPtr{};
7373
}
7474

7575
MONAD_MPT_NAMESPACE_END

0 commit comments

Comments
 (0)