Skip to content

Commit efe09c2

Browse files
authored
feat: add stdArray bounds checking (#3873)
1 parent 298cc45 commit efe09c2

File tree

3 files changed

+117
-3
lines changed

3 files changed

+117
-3
lines changed

src/coreComponents/codingUtilities/tests/testGeosxTraits.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ TEST( testGeosxTraits, HasMemberFunction_insert )
7070
static_assert( HasMemberFunction_insert< std::list< stdVector< int > > >, "Should be true." );
7171

7272
static_assert( !HasMemberFunction_insert< int >, "Should be false." );
73-
static_assert( !HasMemberFunction_insert< std::array< int, 5 > >, "Should be false." );
73+
static_assert( !HasMemberFunction_insert< stdArray< int, 5 > >, "Should be false." );
7474
}
7575

7676

src/coreComponents/common/StdContainerWrappers.hpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#ifndef GEOS_COMMON_STD_CONTAINER_WRAPPERS_HPP
22
#define GEOS_COMMON_STD_CONTAINER_WRAPPERS_HPP
33

4+
#include <array>
5+
#include <cstddef>
46
#include <vector>
57
#include <map>
68
#include <unordered_map>
@@ -346,6 +348,118 @@ class mapBase< TKEY, TVAL, std::integral_constant< bool, false > > : public std:
346348
};
347349
/// @endcond
348350

351+
namespace internal
352+
{
353+
354+
/**
355+
* Wrapper for the underlying std::aray that allows toggling between bounds-checked access
356+
* (using at()) and unchecked access (using operator[]).
357+
* @tparam T Type of elements in stdArray
358+
* @tparam N The number of fixed element in the array
359+
* @tparam USE_STD_CONTAINER_BOUNDS_CHECKING A boolean flag to enable or disable bounds checking.
360+
*/
361+
template< class T,
362+
std::size_t N,
363+
bool USE_BOUNDS_CHECKING = false >
364+
struct StdArrayWrapper : public std::array< T, N >
365+
{
366+
public:
367+
/// Type alias for the base class (i.e., stdArray)
368+
using Base = std::array< T, N >;
369+
370+
/**
371+
* Access element at index with bounds checking if USE_STD_CONTAINER_BOUNDS_CHECKING is true.
372+
* Otherwise, uses operator[] for unchecked access.
373+
* @param index Index of the element to access.
374+
* @return Const reference to the element at the specified index.
375+
* @throws std::out_of_range if index is out of bounds.
376+
*/
377+
constexpr T & operator[]( size_t const index )
378+
{
379+
if constexpr (USE_BOUNDS_CHECKING)
380+
{
381+
return this->at( index );
382+
}
383+
else
384+
{
385+
return Base::operator[]( index );
386+
}
387+
}
388+
389+
/**
390+
* Access element at index with bounds checking if USE_STD_CONTAINER_BOUNDS_CHECKING is true.
391+
* Otherwise, uses operator[] for unchecked access.
392+
* @param index Index of the element to access.
393+
* @return Const reference to the element at the specified index.
394+
* @throws std::out_of_range if index is out of bounds.
395+
*/
396+
constexpr T const & operator[]( size_t const index ) const
397+
{
398+
if constexpr (USE_BOUNDS_CHECKING)
399+
{
400+
return this->at( index );
401+
}
402+
else
403+
{
404+
return Base::operator[]( index );
405+
}
406+
}
407+
408+
};
409+
} //namespace internal
410+
411+
/**
412+
* @tparam T Type of elements in the array.
413+
* @tparam N The number of fixed element in the array.
414+
* @note we use a struct rather than an alias for taking advantage of deduction guide
415+
*/
416+
template< class T, std::size_t N >
417+
struct stdArray : public internal::StdArrayWrapper< T, N, true >
418+
{};
419+
420+
/**
421+
* @brief Deduction guide for stdArray
422+
* Allows the element type and array size to be automatically deduced from the initialization list.
423+
* @code
424+
* stdArray a{1, 2, 3}; // deduces stdArray<int, 3>
425+
* @endcode
426+
* @tparam _Tp Type of the first element provided.
427+
* @tparam _Up Types of the other elements provided.
428+
*/
429+
template< typename _Tp, typename ... _Up >
430+
stdArray( _Tp, _Up ... )
431+
->stdArray< std::enable_if_t< (std::is_same_v< _Tp, _Up > && ...), _Tp >,
432+
1 + sizeof...(_Up) >;
433+
434+
/**
435+
* @brief Helper function that convert a std::array into a stdArray by expanding its elements
436+
* @tparam T The type of the elements
437+
* @tparam N The number of fixed element in the array
438+
* @tparam Is An integer parameter pack representing the indices
439+
* @param arr The input std::array to be converted.
440+
* @return A constexpr stdArray< T, N >
441+
*/
442+
template< typename T, std::size_t N, std::size_t... Is >
443+
constexpr stdArray< T, N > to_stdArray_impl( std::array< T, N > const & arr, std::index_sequence< Is... > )
444+
{
445+
return {{arr[Is] ...}};
446+
}
447+
448+
/**
449+
* @brief Convert an std::array to an stdArray.
450+
* @tparam T Type of elements in stdArray
451+
* @tparam N The number of fixed element in the array
452+
* @param arr The std::array to convert
453+
* @return A stdArray< T, N >
454+
@note we don't want implicitly convert an std::array into a stdArray.
455+
*/
456+
template< typename T, std::size_t N >
457+
constexpr stdArray< T, N > to_stdArray( std::array< T, N > const & arr )
458+
{
459+
return to_stdArray_impl( arr, std::make_index_sequence< N >{} );
460+
}
461+
462+
349463
} // namespace geos
350464

351465
#endif /* GEOS_COMMON_STD_CONTAINER_WRAPPERS_HPP */

src/coreComponents/dataRepository/wrapperHelpers.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ pushDataToConduitNode( Array< T, NDIM, PERMUTATION > const & var,
425425
node[ "__dimensions__" ].set( dimensionType, temp );
426426

427427
// Create a copy of the permutation
428-
constexpr std::array< camp::idx_t, NDIM > const perm = RAJA::as_array< PERMUTATION >::get();
428+
constexpr std::array< camp::idx_t, NDIM > const perm = to_stdArray( RAJA::as_array< PERMUTATION >::get());
429429
for( int i = 0; i < NDIM; ++i )
430430
{
431431
temp[ i ] = perm[ i ];
@@ -454,7 +454,7 @@ pullDataFromConduitNode( Array< T, NDIM, PERMUTATION > & var,
454454
conduit::Node const & permutationNode = node.fetch_existing( "__permutation__" );
455455
GEOS_ERROR_IF_NE( permutationNode.dtype().number_of_elements(), totalNumDimensions );
456456

457-
constexpr std::array< camp::idx_t, NDIM > const perm = RAJA::as_array< PERMUTATION >::get();
457+
constexpr std::array< camp::idx_t, NDIM > const perm = to_stdArray( RAJA::as_array< PERMUTATION >::get());
458458
camp::idx_t const * const permFromConduit = permutationNode.value();
459459
for( int i = 0; i < NDIM; ++i )
460460
{

0 commit comments

Comments
 (0)