Skip to content

Commit 981c818

Browse files
authored
Merge pull request #433 from obackhouse/thc
Tensor hypercontraction
2 parents 01de514 + 99e0e42 commit 981c818

File tree

8 files changed

+187
-11
lines changed

8 files changed

+187
-11
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,8 @@ set(SeQuant_src
381381
SeQuant/domain/mbpt/rules/csv.hpp
382382
SeQuant/domain/mbpt/rules/df.cpp
383383
SeQuant/domain/mbpt/rules/df.hpp
384+
SeQuant/domain/mbpt/rules/thc.cpp
385+
SeQuant/domain/mbpt/rules/thc.hpp
384386
SeQuant/domain/mbpt/space_qns.hpp
385387
SeQuant/domain/mbpt/spin.cpp
386388
SeQuant/domain/mbpt/spin.hpp

SeQuant/domain/mbpt/convention.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ void add_batching_spaces(std::shared_ptr<IndexSpaceRegistry>& isr) {
107107
isr->add(IndexSpace{L"z", 0b100000, BatchingQNS::batch}); // Batching Space
108108
}
109109

110+
void add_thc_spaces(std::shared_ptr<IndexSpaceRegistry>& isr) {
111+
isr->add(IndexSpace{L"L", 0b000001, TensorFactorizationQNS::thc}) // THC AO
112+
;
113+
}
114+
110115
std::shared_ptr<IndexSpaceRegistry> make_min_sr_spaces(SpinConvention spconv) {
111116
auto isr = std::make_shared<IndexSpaceRegistry>();
112117

SeQuant/domain/mbpt/convention.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ void add_ao_spaces(std::shared_ptr<IndexSpaceRegistry>& isr, bool vbs = false,
5252
/// @brief add DF spaces to registry
5353
void add_df_spaces(std::shared_ptr<IndexSpaceRegistry>& isr);
5454

55+
/// @brief add THC spaces to registry
56+
void add_thc_spaces(std::shared_ptr<IndexSpaceRegistry>& isr);
57+
5558
/// @brief add PAO spaces to registry
5659

5760
/// expects \p isr to have a defined particle space

SeQuant/domain/mbpt/rules/df.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ ExprPtr density_fit(ExprPtr const& expr, IndexSpace aux_space,
5252
if (tensor.label() == tensor_label //
5353
&& tensor.bra_rank() == 2 //
5454
&& tensor.ket_rank() == 2)
55-
return density_fit_impl(
56-
tensor, Index(aux_space.base_key() + L"_1", aux_space), factor_label);
55+
return density_fit_impl(tensor, Index(aux_space, 1), factor_label);
5756
else
5857
return expr;
5958
} else if (expr->is<Product>()) {
@@ -63,10 +62,10 @@ ExprPtr density_fit(ExprPtr const& expr, IndexSpace aux_space,
6362
result.scale(prod.scalar());
6463
size_t aux_ix = 0;
6564
for (auto&& f : prod.factors())
66-
if (f.is<Tensor>() && f.as<Tensor>().label() == L"g") {
65+
if (f.is<Tensor>() && f.as<Tensor>().label() == tensor_label) {
6766
auto const& g = f->as<Tensor>();
68-
auto g_df = density_fit_impl(
69-
g, Index(std::to_wstring(++aux_ix), aux_space), factor_label);
67+
auto g_df =
68+
density_fit_impl(g, Index(aux_space, ++aux_ix), factor_label);
7069
result.append(1, std::move(g_df), Product::Flatten::Yes);
7170
} else {
7271
result.append(1, f, Product::Flatten::No);

SeQuant/domain/mbpt/rules/thc.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
//
2+
// Created by Oliver Backhouse on 22/11/2025.
3+
//
4+
5+
#include <SeQuant/domain/mbpt/rules/thc.hpp>
6+
7+
#include <SeQuant/core/expr.hpp>
8+
#include <SeQuant/core/space.hpp>
9+
#include <SeQuant/core/utility/macros.hpp>
10+
11+
#include <range/v3/view.hpp>
12+
13+
#include <string_view>
14+
15+
namespace sequant::mbpt {
16+
17+
ExprPtr tensor_hypercontract_impl(Tensor const& tnsr, Index const& aux_idx_1,
18+
Index const& aux_idx_2,
19+
std::wstring_view factor_label,
20+
std::wstring_view aux_label) {
21+
SEQUANT_ASSERT(tnsr.bra_rank() == 2 //
22+
&& tnsr.ket_rank() == 2 //
23+
&& tnsr.aux_rank() == 0);
24+
25+
auto t1 = ex<Tensor>(factor_label, bra({ranges::front(tnsr.bra())}), ket(),
26+
aux({aux_idx_1}));
27+
auto t2 = ex<Tensor>(factor_label, bra(), ket({ranges::front(tnsr.ket())}),
28+
aux({aux_idx_1}));
29+
auto t3 = ex<Tensor>(factor_label, bra({ranges::back(tnsr.bra())}), ket(),
30+
aux({aux_idx_2}));
31+
auto t4 = ex<Tensor>(factor_label, bra(), ket({ranges::back(tnsr.ket())}),
32+
aux({aux_idx_2}));
33+
auto z = ex<Tensor>(aux_label, bra(), ket(), aux({aux_idx_1, aux_idx_2}));
34+
35+
if (tnsr.symmetry() == Symmetry::Antisymm) {
36+
auto t1a = ex<Tensor>(factor_label, bra({ranges::back(tnsr.bra())}), ket(),
37+
aux({aux_idx_1}));
38+
auto t3a = ex<Tensor>(factor_label, bra({ranges::front(tnsr.bra())}), ket(),
39+
aux({aux_idx_2}));
40+
41+
return (t1 * t2 * z * t3 * t4) - (t1a * t2 * z * t3a * t4);
42+
}
43+
44+
return t1 * t2 * z * t3 * t4;
45+
}
46+
47+
ExprPtr tensor_hypercontract(ExprPtr const& expr, IndexSpace aux_space,
48+
std::wstring_view tensor_label,
49+
std::wstring_view outer_factor_label,
50+
std::wstring_view core_tensor_label) {
51+
using ranges::views::transform;
52+
53+
if (expr->is<Sum>())
54+
return ex<Sum>(*expr | transform([&](auto&& x) {
55+
return tensor_hypercontract(x, aux_space, tensor_label,
56+
outer_factor_label, core_tensor_label);
57+
}));
58+
59+
else if (expr->is<Tensor>()) {
60+
auto const& tensor = expr->as<Tensor>();
61+
if (tensor.label() == tensor_label //
62+
&& tensor.bra_rank() == 2 //
63+
&& tensor.ket_rank() == 2)
64+
return tensor_hypercontract_impl(tensor, Index(aux_space, 1),
65+
Index(aux_space, 2), outer_factor_label,
66+
core_tensor_label);
67+
else
68+
return expr;
69+
} else if (expr->is<Product>()) {
70+
auto const& prod = expr->as<Product>();
71+
72+
Product result;
73+
result.scale(prod.scalar());
74+
size_t aux_ix = 0;
75+
for (auto&& f : prod.factors()) {
76+
if (f->is<Tensor>() && f->as<Tensor>().label() == tensor_label) {
77+
auto const& g = f->as<Tensor>();
78+
auto index1 = Index(aux_space, ++aux_ix);
79+
auto index2 = Index(aux_space, ++aux_ix);
80+
auto g_thc = tensor_hypercontract_impl(
81+
g, index1, index2, outer_factor_label, core_tensor_label);
82+
result.append(1, std::move(g_thc), Product::Flatten::Yes);
83+
} else {
84+
result.append(1, f, Product::Flatten::No);
85+
}
86+
}
87+
return ex<Product>(std::move(result));
88+
} else {
89+
return expr;
90+
}
91+
}
92+
93+
} // namespace sequant::mbpt

SeQuant/domain/mbpt/rules/thc.hpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//
2+
// Created by Oliver Backhouse on 22/11/2025.
3+
//
4+
5+
#ifndef SEQUANT_DOMAIN_MBPT_RULES_THC_HPP
6+
#define SEQUANT_DOMAIN_MBPT_RULES_THC_HPP
7+
8+
#include <SeQuant/core/expr_fwd.hpp>
9+
#include <SeQuant/core/space.hpp>
10+
11+
#include <string_view>
12+
13+
namespace sequant::mbpt {
14+
15+
// clang-format off
16+
///
17+
/// Factorizes 2-particle tensors into five rank-2 tensors using
18+
/// the ``tensor hypercontraction'' topology (see [DOI 10.1063/1.4732310](https://doi.org/10.1063/1.4732310);
19+
/// also see [DOI 10.1021/acs.jctc.0c01310](https://doi.org/10.1021/acs.jctc.0c01310) for discussion
20+
/// in the context of related pseudospectral and CP factorizations).
21+
/// Namely, \f$ g_{b_1 b_2}^{k_1 k_2} \f$ is factorized into
22+
/// \f$ B_{b_1}[r_1] B^{k_1}[r_1] B_{b_2}[r_2] B^{k_2}[r_2] C[r_1,r_2] \f$
23+
///
24+
/// \param expr The expression to be tensor-hyper-contracted.
25+
/// \param aux_space The index space representing the auxiliary indices (\f$ r_1, r_2 \f$ in the example above)
26+
/// introduced through the decomposition.
27+
/// \param tensor_label The label of the tensor that shall be decomposed
28+
/// \param outer_factor_label The label of the outer factor tensors (\f$ B \f$ in the example above).
29+
/// \param core_tensor_label The label of the core tensor (\f$ C \f$ in the example above).
30+
// clang-format on
31+
[[nodiscard]] ExprPtr tensor_hypercontract(ExprPtr const& expr,
32+
IndexSpace aux_space,
33+
std::wstring_view tensor_label,
34+
std::wstring_view outer_factor_label,
35+
std::wstring_view core_tensor_label);
36+
37+
} // namespace sequant::mbpt
38+
39+
#endif // SEQUANT_DOMAIN_MBPT_RULES_THC_HPP

SeQuant/domain/mbpt/space_qns.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,21 +64,23 @@ struct mask<LCAOQNS> {
6464
};
6565

6666
/// quantum numbers tags related to tensor factorization basis traits
67-
/// \note TensorFactorization basis traits use 5th rightmost bits
67+
/// \note TensorFactorization basis traits use 5th and 6th rightmost bits
6868
enum class TensorFactorizationQNS : bitset_t {
69-
df = 0b010000,
69+
df = 0b010000, // density fitting
70+
thc = 0b100000 // tensor hypercontraction
7071
};
7172

7273
template <>
7374
struct mask<TensorFactorizationQNS> {
7475
using type = std::underlying_type_t<TensorFactorizationQNS>;
75-
static constexpr type value = static_cast<type>(TensorFactorizationQNS::df);
76+
static constexpr type value = static_cast<type>(TensorFactorizationQNS::df) |
77+
static_cast<type>(TensorFactorizationQNS::thc);
7678
};
7779

7880
/// tags related to batching
79-
/// \note BatchingQNS uses the 6th rightmost bit
81+
/// \note BatchingQNS uses the 7th rightmost bit
8082
enum class BatchingQNS : bitset_t {
81-
batch = 0b100000, // for batching tensors
83+
batch = 0b1000000, // for batching tensors
8284
};
8385

8486
template <>

tests/unit/test_mbpt.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <SeQuant/domain/mbpt/convention.hpp>
1616
#include <SeQuant/domain/mbpt/op.hpp>
1717
#include <SeQuant/domain/mbpt/rules/df.hpp>
18+
#include <SeQuant/domain/mbpt/rules/thc.hpp>
1819
#include <SeQuant/domain/mbpt/utils.hpp>
1920

2021
#include <catch2/catch_test_macros.hpp>
@@ -794,5 +795,37 @@ SECTION("rules") {
794795
REQUIRE_THAT(actual, EquivalentTo(expected.at(i)));
795796
}
796797
}
797-
}
798+
799+
SECTION("tensor-hypercontract") {
800+
const std::vector<std::wstring> inputs = {
801+
L"t{a1,a2;i1,i2} t{a3;i3}",
802+
L"t{a1,a2;i1,i2} g{i1,i2;a1,a2}",
803+
L"t{a1,a2;i1,i2} g{i1,i2;a1,a2}:A",
804+
};
805+
const std::vector<std::wstring> expected = {
806+
L"t{a1,a2;i1,i2} t{a3;i3}",
807+
L"t{a1,a2;i1,i2} B{i1;;x_1} B{;a1;x_1} C{;;x_1,x_2} B{i2;;x_2} "
808+
L"B{;a2;x_2}",
809+
L"t{a1,a2;i1,i2} (B{i1;;x_1} B{;a1;x_1} C{;;x_1,x_2} B{i2;;x_2} "
810+
L"B{;a2;x_2}"
811+
" - B{i2;;x_1} B{;a1;x_1} C{;;x_1,x_2} B{i1;;x_2} B{;a2;x_2})",
812+
};
813+
814+
REQUIRE(inputs.size() == expected.size());
815+
816+
for (std::size_t i = 0; i < inputs.size(); ++i) {
817+
CAPTURE(inputs.at(i));
818+
819+
ExprPtr input_expr = parse_expr(inputs.at(i));
820+
821+
const IndexSpace aux_space =
822+
get_default_context().index_space_registry()->retrieve(L"x");
823+
824+
ExprPtr actual =
825+
mbpt::tensor_hypercontract(input_expr, aux_space, L"g", L"B", L"C");
826+
827+
REQUIRE_THAT(actual, EquivalentTo(expected.at(i)));
828+
}
829+
}
830+
} // SECTION("rules")
798831
}

0 commit comments

Comments
 (0)