@@ -439,16 +439,70 @@ namespace xt
439439 using requested_value_type = detail::conditional_promote_to_complex_t <e1_value_type, e2_requested_value_type>;
440440 };
441441
442+ /* *********************************
443+ * Expression Order Optimizations *
444+ **********************************/
445+
446+ class optimize_expression
447+ {
448+ private:
449+
450+ template <class E1 , class E2 >
451+ struct equal_rank
452+ {
453+ static constexpr bool value = get_rank<E1 >::value == get_rank<E2 >::value;
454+ };
455+
456+ template <class E1 , class ... E>
457+ struct all_equal_rank
458+ {
459+ static constexpr bool value = xtl::conjunction<equal_rank<E1 , E>...>::value
460+ && (get_rank<E1 >::value != SIZE_MAX);
461+ };
462+
463+ template <class F , class ... CT, class ... S, size_t ... I, size_t ... J>
464+ inline auto
465+ impl_reorder_function (const xfunction<F, CT...>& e, std::tuple<S...> slices, std::index_sequence<I...>, std::index_sequence<J...>)
466+ {
467+ return make_lambda_xfunction (F (), view (std::get<I>(e.arguments ()), std::get<J>(slices)...)...);
468+ }
469+
470+ public:
471+
472+ // when we have a view of a function where the closures of the functions are of equal rank (i.e no
473+ // broadcasting) we can flip the order of the function and the view such that we have a function of
474+ // views of containers which can be linearly assigned unlike the inverse.
475+ template <class F , class ... CT, class ... S, class = std::enable_if_t <all_equal_rank<std::decay_t <CT>...>::value>>
476+ inline auto reorder (const xview<xfunction<F, CT...>, S...>& e)
477+ {
478+ return impl_reorder_function (
479+ e.expression (),
480+ e.slices (),
481+ std::make_index_sequence<sizeof ...(CT)>(),
482+ std::make_index_sequence<sizeof ...(S)>()
483+ );
484+ }
485+
486+ // base case no applicable optimization
487+ template <class E >
488+ inline auto & reorder (E&& e)
489+ {
490+ return std::forward<E>(e);
491+ }
492+ };
493+
442494 template <class E1 , class E2 >
443495 inline void xexpression_assigner_base<xtensor_expression_tag>::assign_data(
444496 xexpression<E1 >& e1 ,
445497 const xexpression<E2 >& e2 ,
446498 bool trivial
447499 )
448500 {
449- E1 & de1 = e1 .derived_cast ();
450- const E2 & de2 = e2 .derived_cast ();
451- using traits = xassign_traits<E1 , E2 >;
501+ auto & de1 = e1 .derived_cast ();
502+ const auto & de2 = optimize_expression ().reorder (e2 .derived_cast ());
503+ using dst_type = typename std::decay_t <decltype (de1)>;
504+ using src_type = typename std::decay_t <decltype (de2)>;
505+ using traits = xassign_traits<dst_type, src_type>;
452506
453507 bool linear_assign = traits::linear_assign (de1, de2, trivial);
454508 constexpr bool simd_assign = traits::simd_assign ();
0 commit comments