Skip to content

Commit d1e82eb

Browse files
authored
Added example with redoing the factorization through KLU (#77)
* example with KLU redoing factorization * fixing memory leaks in LinSolverDirectRocSolverRf
1 parent fa10d84 commit d1e82eb

5 files changed

+301
-15
lines changed

examples/CMakeLists.txt

+6-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ if(RESOLVE_USE_HIP)
4848
# Build example with KLU factorization, rocsolver Rf refactorization, and FGMRES iterative refinement
4949
add_executable(klu_rocsolverrf_fgmres.exe r_KLU_rocSolverRf_FGMRES.cpp)
5050
target_link_libraries(klu_rocsolverrf_fgmres.exe PRIVATE ReSolve)
51+
52+
# Build example in which the KLU factorization is redone when the solution residual is too large
53+
add_executable(klu_rocsolverrf_check_redo.exe r_KLU_rocsolverrf_redo_factorization.cpp)
54+
target_link_libraries(klu_rocsolverrf_check_redo.exe PRIVATE ReSolve)
55+
5156
endif(RESOLVE_USE_HIP)
5257

5358
# Install all examples in bin directory
@@ -58,7 +63,7 @@ if(RESOLVE_USE_CUDA)
5863
endif(RESOLVE_USE_CUDA)
5964

6065
if(RESOLVE_USE_HIP)
61-
set(installable_executables ${installable_executables} klu_rocsolverrf.exe)
66+
set(installable_executables ${installable_executables} klu_rocsolverrf.exe klu_rocsolverrf_fgmres.exe klu_rocsolverrf_check_redo.exe)
6267
endif(RESOLVE_USE_HIP)
6368

6469
install(TARGETS ${installable_executables}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
#include <string>
2+
#include <iostream>
3+
#include <iomanip>
4+
5+
#include <resolve/matrix/Coo.hpp>
6+
#include <resolve/matrix/Csr.hpp>
7+
#include <resolve/matrix/Csc.hpp>
8+
#include <resolve/vector/Vector.hpp>
9+
#include <resolve/matrix/io.hpp>
10+
#include <resolve/matrix/MatrixHandler.hpp>
11+
#include <resolve/vector/VectorHandler.hpp>
12+
#include <resolve/LinSolverDirectKLU.hpp>
13+
#include <resolve/LinSolverDirectRocSolverRf.hpp>
14+
#include <resolve/workspace/LinAlgWorkspace.hpp>
15+
16+
using namespace ReSolve::constants;
17+
18+
int main(int argc, char *argv[] )
19+
{
20+
// Use the same data types as those you specified in ReSolve build.
21+
using index_type = ReSolve::index_type;
22+
using real_type = ReSolve::real_type;
23+
using vector_type = ReSolve::vector::Vector;
24+
25+
(void) argc; // TODO: Check if the number of input parameters is correct.
26+
std::string matrixFileName = argv[1];
27+
std::string rhsFileName = argv[2];
28+
29+
index_type numSystems = atoi(argv[3]);
30+
std::cout<<"Family mtx file name: "<< matrixFileName << ", total number of matrices: "<<numSystems<<std::endl;
31+
std::cout<<"Family rhs file name: "<< rhsFileName << ", total number of RHSes: " << numSystems<<std::endl;
32+
33+
std::string fileId;
34+
std::string rhsId;
35+
std::string matrixFileNameFull;
36+
std::string rhsFileNameFull;
37+
38+
ReSolve::matrix::Coo* A_coo;
39+
ReSolve::matrix::Csr* A;
40+
41+
ReSolve::LinAlgWorkspaceHIP* workspace_HIP = new ReSolve::LinAlgWorkspaceHIP;
42+
workspace_HIP->initializeHandles();
43+
ReSolve::MatrixHandler* matrix_handler = new ReSolve::MatrixHandler(workspace_HIP);
44+
ReSolve::VectorHandler* vector_handler = new ReSolve::VectorHandler(workspace_HIP);
45+
real_type* rhs = nullptr;
46+
real_type* x = nullptr;
47+
48+
vector_type* vec_rhs;
49+
vector_type* vec_x;
50+
vector_type* vec_r;
51+
52+
ReSolve::LinSolverDirectKLU* KLU = new ReSolve::LinSolverDirectKLU;
53+
ReSolve::LinSolverDirectRocSolverRf* Rf = new ReSolve::LinSolverDirectRocSolverRf(workspace_HIP);
54+
55+
real_type res_nrm;
56+
real_type b_nrm;
57+
58+
for (int i = 0; i < numSystems; ++i)
59+
{
60+
index_type j = 4 + i * 2;
61+
fileId = argv[j];
62+
rhsId = argv[j + 1];
63+
64+
matrixFileNameFull = "";
65+
rhsFileNameFull = "";
66+
67+
// Read matrix first
68+
matrixFileNameFull = matrixFileName + fileId + ".mtx";
69+
rhsFileNameFull = rhsFileName + rhsId + ".mtx";
70+
std::cout << std::endl << std::endl << std::endl;
71+
std::cout << "========================================================================================================================"<<std::endl;
72+
std::cout << "Reading: " << matrixFileNameFull << std::endl;
73+
std::cout << "========================================================================================================================"<<std::endl;
74+
std::cout << std::endl;
75+
// Read first matrix
76+
std::ifstream mat_file(matrixFileNameFull);
77+
if(!mat_file.is_open())
78+
{
79+
std::cout << "Failed to open file " << matrixFileNameFull << "\n";
80+
return -1;
81+
}
82+
std::ifstream rhs_file(rhsFileNameFull);
83+
if(!rhs_file.is_open())
84+
{
85+
std::cout << "Failed to open file " << rhsFileNameFull << "\n";
86+
return -1;
87+
}
88+
if (i == 0) {
89+
A_coo = ReSolve::io::readMatrixFromFile(mat_file);
90+
A = new ReSolve::matrix::Csr(A_coo->getNumRows(),
91+
A_coo->getNumColumns(),
92+
A_coo->getNnz(),
93+
A_coo->symmetric(),
94+
A_coo->expanded());
95+
96+
rhs = ReSolve::io::readRhsFromFile(rhs_file);
97+
x = new real_type[A->getNumRows()];
98+
vec_rhs = new vector_type(A->getNumRows());
99+
vec_x = new vector_type(A->getNumRows());
100+
vec_r = new vector_type(A->getNumRows());
101+
}
102+
else {
103+
ReSolve::io::readAndUpdateMatrix(mat_file, A_coo);
104+
ReSolve::io::readAndUpdateRhs(rhs_file, &rhs);
105+
}
106+
std::cout<<"Finished reading the matrix and rhs, size: "<<A->getNumRows()<<" x "<<A->getNumColumns()<< ", nnz: "<< A->getNnz()<< ", symmetric? "<<A->symmetric()<< ", Expanded? "<<A->expanded()<<std::endl;
107+
mat_file.close();
108+
rhs_file.close();
109+
110+
//Now convert to CSR.
111+
if (i < 2) {
112+
matrix_handler->coo2csr(A_coo, A, "cpu");
113+
vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::HOST);
114+
vec_rhs->setDataUpdated(ReSolve::memory::HOST);
115+
} else {
116+
matrix_handler->coo2csr(A_coo, A, "hip");
117+
vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
118+
}
119+
std::cout<<"COO to CSR completed. Expanded NNZ: "<< A->getNnzExpanded()<<std::endl;
120+
//Now call direct solver
121+
if (i == 0) {
122+
KLU->setupParameters(1, 0.1, false);
123+
}
124+
int status;
125+
if (i < 2){
126+
KLU->setup(A);
127+
status = KLU->analyze();
128+
std::cout<<"KLU analysis status: "<<status<<std::endl;
129+
status = KLU->factorize();
130+
std::cout<<"KLU factorization status: "<<status<<std::endl;
131+
status = KLU->solve(vec_rhs, vec_x);
132+
std::cout<<"KLU solve status: "<<status<<std::endl;
133+
if (i == 1) {
134+
ReSolve::matrix::Csc* L = (ReSolve::matrix::Csc*) KLU->getLFactor();
135+
ReSolve::matrix::Csc* U = (ReSolve::matrix::Csc*) KLU->getUFactor();
136+
index_type* P = KLU->getPOrdering();
137+
index_type* Q = KLU->getQOrdering();
138+
vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
139+
Rf->setup(A, L, U, P, Q, vec_rhs);
140+
Rf->refactorize();
141+
}
142+
} else {
143+
std::cout<<"Using rocsolver rf"<<std::endl;
144+
status = Rf->refactorize();
145+
std::cout<<"rocsolver rf refactorization status: "<<status<<std::endl;
146+
status = Rf->solve(vec_rhs, vec_x);
147+
std::cout<<"rocsolver rf solve status: "<<status<<std::endl;
148+
}
149+
vec_r->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
150+
151+
matrix_handler->setValuesChanged(true, "hip");
152+
153+
matrix_handler->matvec(A, vec_x, vec_r, &ONE, &MINUSONE,"csr", "hip");
154+
res_nrm = sqrt(vector_handler->dot(vec_r, vec_r, "hip"));
155+
b_nrm = sqrt(vector_handler->dot(vec_rhs, vec_rhs, "hip"));
156+
std::cout << "\t 2-Norm of the residual: "
157+
<< std::scientific << std::setprecision(16)
158+
<< res_nrm/b_nrm << "\n";
159+
if (!isnan(res_nrm)) {
160+
if (res_nrm/b_nrm > 1e-7 ) {
161+
std::cout << "\n \t !!! ALERT !!! Residual norm is too large; redoing KLU symbolic and numeric factorization. !!! ALERT !!! \n \n";
162+
163+
KLU->setup(A);
164+
status = KLU->analyze();
165+
std::cout<<"KLU analysis status: "<<status<<std::endl;
166+
status = KLU->factorize();
167+
std::cout<<"KLU factorization status: "<<status<<std::endl;
168+
status = KLU->solve(vec_rhs, vec_x);
169+
std::cout<<"KLU solve status: "<<status<<std::endl;
170+
171+
vec_rhs->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
172+
vec_r->update(rhs, ReSolve::memory::HOST, ReSolve::memory::DEVICE);
173+
174+
matrix_handler->setValuesChanged(true, "hip");
175+
176+
matrix_handler->matvec(A, vec_x, vec_r, &ONE, &MINUSONE,"csr", "hip");
177+
res_nrm = sqrt(vector_handler->dot(vec_r, vec_r, "hip"));
178+
179+
std::cout<<"\t New residual norm: "
180+
<< std::scientific << std::setprecision(16)
181+
<< res_nrm/b_nrm << "\n";
182+
183+
184+
ReSolve::matrix::Csc* L = (ReSolve::matrix::Csc*) KLU->getLFactor();
185+
ReSolve::matrix::Csc* U = (ReSolve::matrix::Csc*) KLU->getUFactor();
186+
187+
index_type* P = KLU->getPOrdering();
188+
index_type* Q = KLU->getQOrdering();
189+
190+
Rf->setup(A, L, U, P, Q, vec_rhs);
191+
}
192+
}
193+
194+
195+
} // for (int i = 0; i < numSystems; ++i)
196+
197+
//now DELETE
198+
delete A;
199+
delete A_coo;
200+
delete KLU;
201+
delete Rf;
202+
delete [] x;
203+
delete [] rhs;
204+
delete vec_r;
205+
delete vec_x;
206+
delete workspace_HIP;
207+
delete matrix_handler;
208+
delete vector_handler;
209+
return 0;
210+
}

resolve/LinSolverDirectKLU.cpp

+48
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ namespace ReSolve
99
{
1010
Symbolic_ = nullptr;
1111
Numeric_ = nullptr;
12+
13+
L_ = nullptr;
14+
U_ = nullptr;
15+
1216
klu_defaults(&Common_) ;
1317
}
1418

@@ -40,8 +44,24 @@ namespace ReSolve
4044

4145
int LinSolverDirectKLU::analyze()
4246
{
47+
// in case we called this function AGAIN
48+
if (Symbolic_ != nullptr) {
49+
klu_free_symbolic(&Symbolic_, &Common_);
50+
}
51+
4352
Symbolic_ = klu_analyze(A_->getNumRows(), A_->getRowData(memory::HOST), A_->getColData(memory::HOST), &Common_) ;
4453

54+
factors_extracted_ = false;
55+
if (L_ != nullptr) {
56+
delete L_;
57+
L_ = nullptr;
58+
}
59+
60+
if (U_ != nullptr) {
61+
delete U_;
62+
U_ = nullptr;
63+
}
64+
4565
if (Symbolic_ == nullptr){
4666
printf("Symbolic_ factorization crashed withCommon_.status = %d \n", Common_.status);
4767
return 1;
@@ -51,8 +71,24 @@ namespace ReSolve
5171

5272
int LinSolverDirectKLU::factorize()
5373
{
74+
if (Numeric_ != nullptr) {
75+
klu_free_numeric(&Numeric_, &Common_);
76+
}
77+
5478
Numeric_ = klu_factor(A_->getRowData(memory::HOST), A_->getColData(memory::HOST), A_->getValues(memory::HOST), Symbolic_, &Common_);
5579

80+
factors_extracted_ = false;
81+
82+
if (L_ != nullptr) {
83+
delete L_;
84+
L_ = nullptr;
85+
}
86+
87+
if (U_ != nullptr) {
88+
delete U_;
89+
U_ = nullptr;
90+
}
91+
5692
if (Numeric_ == nullptr){
5793
return 1;
5894
}
@@ -63,6 +99,18 @@ namespace ReSolve
6399
{
64100
int kluStatus = klu_refactor (A_->getRowData(memory::HOST), A_->getColData(memory::HOST), A_->getValues(memory::HOST), Symbolic_, Numeric_, &Common_);
65101

102+
factors_extracted_ = false;
103+
104+
if (L_ != nullptr) {
105+
delete L_;
106+
L_ = nullptr;
107+
}
108+
109+
if (U_ != nullptr) {
110+
delete U_;
111+
U_ = nullptr;
112+
}
113+
66114
if (!kluStatus){
67115
//display error
68116
return 1;

resolve/LinSolverDirectRocSolverRf.cpp

+32-9
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,19 @@ namespace ReSolve
3838
//set matrix info
3939
rocsolver_create_rfinfo(&infoM_, workspace_->getRocblasHandle());
4040
//create combined factor
41-
addFactors(L,U);
41+
42+
addFactors(L, U);
43+
4244
M_->setUpdated(ReSolve::memory::HOST);
4345
M_->copyData(ReSolve::memory::DEVICE);
44-
mem_.allocateArrayOnDevice(&d_P_, n);
45-
mem_.allocateArrayOnDevice(&d_Q_, n);
4646

47+
if (d_P_ == nullptr) {
48+
mem_.allocateArrayOnDevice(&d_P_, n);
49+
}
50+
51+
if (d_Q_ == nullptr) {
52+
mem_.allocateArrayOnDevice(&d_Q_, n);
53+
}
4754
mem_.copyArrayHostToDevice(d_P_, P, n);
4855
mem_.copyArrayHostToDevice(d_Q_, Q, n);
4956

@@ -70,12 +77,22 @@ namespace ReSolve
7077

7178
// tri solve setup
7279
if (solve_mode_ == 1) { // fast mode
73-
L_csr_ = new ReSolve::matrix::Csr(L->getNumRows(), L->getNumColumns(), L->getNnz());
74-
U_csr_ = new ReSolve::matrix::Csr(U->getNumRows(), U->getNumColumns(), U->getNnz());
7580

81+
if (L_csr_ != nullptr) {
82+
delete L_csr_;
83+
}
84+
85+
L_csr_ = new ReSolve::matrix::Csr(L->getNumRows(), L->getNumColumns(), L->getNnz());
7686
L_csr_->allocateMatrixData(ReSolve::memory::DEVICE);
87+
88+
if (U_csr_ != nullptr) {
89+
delete U_csr_;
90+
}
91+
92+
U_csr_ = new ReSolve::matrix::Csr(U->getNumRows(), U->getNumColumns(), U->getNnz());
7793
U_csr_->allocateMatrixData(ReSolve::memory::DEVICE);
7894

95+
7996
rocsparse_create_mat_descr(&(descr_L_));
8097
rocsparse_set_mat_fill_mode(descr_L_, rocsparse_fill_mode_lower);
8198
rocsparse_set_mat_index_base(descr_L_, rocsparse_index_base_zero);
@@ -161,9 +178,12 @@ namespace ReSolve
161178
error_sum += status_rocsparse_;
162179
if (status_rocsparse_!=0)printf("status after analysis 2 %d \n", status_rocsparse_);
163180
//allocate aux data
164-
165-
mem_.allocateArrayOnDevice(&d_aux1_,n);
166-
mem_.allocateArrayOnDevice(&d_aux2_,n);
181+
if (d_aux1_ == nullptr) {
182+
mem_.allocateArrayOnDevice(&d_aux1_,n);
183+
}
184+
if (d_aux2_ == nullptr) {
185+
mem_.allocateArrayOnDevice(&d_aux2_,n);
186+
}
167187

168188
}
169189
return error_sum;
@@ -193,7 +213,7 @@ namespace ReSolve
193213

194214
if (solve_mode_ == 1) {
195215
//split M, fill L and U with correct values
196-
printf("solve mode 1, splitting the factors again \n");
216+
printf("solve mode 1, splitting the factors again \n");
197217
status_rocblas_ = rocsolver_dcsrrf_splitlu(workspace_->getRocblasHandle(),
198218
A_->getNumRows(),
199219
M_->getNnzExpanded(),
@@ -360,6 +380,9 @@ printf("solve mode 1, splitting the factors again \n");
360380
index_type* Li = L->getRowData(ReSolve::memory::HOST);
361381
index_type* Up = U->getColData(ReSolve::memory::HOST);
362382
index_type* Ui = U->getRowData(ReSolve::memory::HOST);
383+
if (M_ != nullptr) {
384+
delete M_;
385+
}
363386

364387
index_type nnzM = ( L->getNnz() + U->getNnz() - n );
365388
M_ = new matrix::Csr(n, n, nnzM);

0 commit comments

Comments
 (0)