Skip to content

Commit

Permalink
tensor-based Kfold CV
Browse files Browse the repository at this point in the history
  • Loading branch information
MegaJoctan committed Mar 3, 2024
1 parent c306690 commit a82bf95
Showing 1 changed file with 28 additions and 104 deletions.
132 changes: 28 additions & 104 deletions cross_validation.mqh
Original file line number Diff line number Diff line change
Expand Up @@ -9,134 +9,58 @@
//| defines |
//+------------------------------------------------------------------+
#include <MALE5\Tensors.mqh>
#include <MALE5\MatrixExtend.mqh>

class CCrossValidation_kfold
class CCrossValidation
{
CTensors *folds_tensor;
void XandYSplitMatrices(const matrix &matrix_,matrix &xmatrix,vector &y_vector,int y_column=-1);
void RemoveCol(matrix &mat, ulong col);
uint k_folds;
CTensors *tensors[]; //Keep track of all the tensors in memory

public:
CCrossValidation_kfold(matrix &data_matrix, uint k_folds=5);
~CCrossValidation_kfold(void);
CCrossValidation();
~CCrossValidation(void);

matrix fold(uint index);
uint fold_size;

matrix fold_x(uint index);
vector fold_y(uint index);
CTensors *KFoldCV(matrix &data, uint n_spilts=5);
};
//+------------------------------------------------------------------+
//| |
//+------------------------------------------------------------------+
CCrossValidation_kfold::CCrossValidation_kfold(matrix &data_matrix, uint k_folds_=5)
CCrossValidation::CCrossValidation()
{
this.k_folds = k_folds_;

folds_tensor = new CTensors(k_folds);

ulong rows = data_matrix.Rows();
fold_size = (int)MathCeil(rows/k_folds);

matrix temp_tensor(fold_size, data_matrix.Cols());

int start=0;
for (ulong i=0; i<k_folds; i++)
{
for (ulong j=start, count=0; j<fold_size+start; j++, count++)
{
temp_tensor.Row(data_matrix.Row(j), count);
}

folds_tensor.TensorAdd(temp_tensor, i); //Obtained size=k data matrix


start += (int)fold_size;
}


//#ifdef DEBUG_MODE
// Print("total ",rows," fold_size ",fold_size," k_folds ",k_folds);
// folds_tensor.TensorPrint();
//#endif

}
//+------------------------------------------------------------------+
//| |
//+------------------------------------------------------------------+
CCrossValidation_kfold::~CCrossValidation_kfold(void)
CCrossValidation::~CCrossValidation(void)
{
delete(folds_tensor);
}
//+------------------------------------------------------------------+
//| |
//+------------------------------------------------------------------+
matrix CCrossValidation_kfold::fold(uint index)
{
if (index+1 > this.k_folds)
{
matrix ret={};
Print("k-fold index out of range");
return (ret);
}

return folds_tensor.Tensor(index);
}
//+------------------------------------------------------------------+
//| |
//+------------------------------------------------------------------+
matrix CCrossValidation_kfold::fold_x(uint index)
{
matrix x; vector y;
matrix fold_matrix = this.fold(index);

this.XandYSplitMatrices(fold_matrix, x, y);

return (x);
for (uint i=0; i<tensors.Size(); i++)
if (CheckPointer(tensors[i]) != POINTER_INVALID)
delete (tensors[i]);
}
//+------------------------------------------------------------------+
//| |
//+------------------------------------------------------------------+
vector CCrossValidation_kfold::fold_y(uint index)
CTensors *CCrossValidation::KFoldCV(matrix &data, uint n_spilts=5)
{
matrix x; vector y;
matrix fold_matrix = this.fold(index);
ArrayResize(tensors, tensors.Size()+1);
tensors[tensors.Size()-1] = new CTensors(n_spilts);

this.XandYSplitMatrices(fold_matrix, x, y);
int size = (int)MathFloor(data.Rows() / (double)n_spilts);

return (y);
}
//+------------------------------------------------------------------+
//| |
//+------------------------------------------------------------------+
void CCrossValidation_kfold::XandYSplitMatrices(const matrix &matrix_,matrix &xmatrix,vector &y_vector,int y_column=-1)
{
y_column = int( y_column==-1 ? matrix_.Cols()-1 : y_column);

y_vector = matrix_.Col(y_column);
xmatrix.Copy(matrix_);

RemoveCol(xmatrix, y_column); //Remove the y column
matrix split_data = {};

for (uint k=0, start = 0; k<n_spilts; k++)
{
split_data = MatrixExtend::Get(data, start, (start+size)-1);

tensors[tensors.Size()-1].Add(split_data, k);

start += size;
}
return tensors[tensors.Size()-1];
}
//+------------------------------------------------------------------+
//| |
//+------------------------------------------------------------------+
void CCrossValidation_kfold::RemoveCol(matrix &mat, ulong col)
{
matrix new_matrix(mat.Rows(),mat.Cols()-1); //Remove the one Column

for (ulong i=0, new_col=0; i<mat.Cols(); i++)
{
if (i == col)
continue;
else
{
new_matrix.Col(mat.Col(i),new_col);
new_col++;
}
}
mat.Copy(new_matrix);
}
//+------------------------------------------------------------------+
//| |
//+------------------------------------------------------------------+

0 comments on commit a82bf95

Please sign in to comment.