1
+ function net = fmriDenoiserNetwork(modelPath , pretrainedNetwork )
2
+ % fmriDenoiserNetwork - generate MATLAB Version of the DeepInterpolation fMRI-Denoiser
3
+
4
+ lgraph = layerGraph();
5
+
6
+ tempLayers = [
7
+ image3dInputLayer([7 7 7 5 ]," Name" ," image3dinput" ,' Normalization' ,' none' );
8
+ convolution3dLayer([3 3 3 ],8 ," Name" ," conv3d" ," Padding" ," same" ," Weights" ,mywini(1 )," Bias" ,mybini(1 ));
9
+ reluLayer(" Name" ," relu1" )];
10
+ lgraph = addLayers(lgraph ,tempLayers );
11
+
12
+ tempLayers = [
13
+ maxPooling3dLayer([3 3 3 ]," Name" ," pool1" ," Padding" ," same" )
14
+ convolution3dLayer([3 3 3 ],16 ," Name" ," conv3d_1" ," Padding" ," same" ," Weights" ,mywini(2 )," Bias" ,mybini(2 ))
15
+ reluLayer(" Name" ," relu2" )];
16
+ lgraph = addLayers(lgraph ,tempLayers );
17
+
18
+ tempLayers = [
19
+ maxPooling3dLayer([3 3 3 ]," Name" ," pool2" ," Padding" ," same" )
20
+ convolution3dLayer([3 3 3 ],32 ," Name" ," conv3d_2" ," Padding" ," same" ," Weights" ,mywini(3 )," Bias" ,mybini(3 ))
21
+ reluLayer(" Name" ," relu3" )];
22
+ lgraph = addLayers(lgraph ,tempLayers );
23
+
24
+ tempLayers = [
25
+ depthConcatenationLayer(2 ," Name" ," conc_up_1" )
26
+ convolution3dLayer([3 3 3 ],16 ," Name" ," conv3d_3" ," Padding" ," same" ," Weights" ,mywini(4 )," Bias" ,mybini(4 ))
27
+ reluLayer(" Name" ," relu4" )];
28
+ lgraph = addLayers(lgraph ,tempLayers );
29
+
30
+ tempLayers = [
31
+ depthConcatenationLayer(2 ," Name" ," conc_up_2" )
32
+ convolution3dLayer([3 3 3 ],8 ," Name" ," conv3d_4" ," Padding" ," same" ," Weights" ,mywini(5 )," Bias" ,mybini(5 ))
33
+ reluLayer(" Name" ," relu5" )
34
+ convolution3dLayer([1 1 1 ],1 ," Name" ," conv3d_5" ," Padding" ," same" ," Weights" ,mywini(6 )," Bias" ,mybini(6 ))
35
+ regressionLayer(" Name" ," out_r" )];
36
+ lgraph = addLayers(lgraph ,tempLayers );
37
+
38
+ clear tempLayers ;
39
+
40
+ lgraph = connectLayers(lgraph ," relu1" ," pool1" );
41
+ lgraph = connectLayers(lgraph ," relu1" ," conc_up_2/in2" );
42
+ lgraph = connectLayers(lgraph ," relu2" ," pool2" );
43
+ lgraph = connectLayers(lgraph ," relu2" ," conc_up_1/in2" );
44
+ lgraph = connectLayers(lgraph ," relu3" ," conc_up_1/in1" );
45
+ lgraph = connectLayers(lgraph ," relu4" ," conc_up_2/in1" );
46
+
47
+ net = assembleNetwork(lgraph );
48
+ save(modelPath ," net" )
49
+
50
+ function w = mywini(ilayer )
51
+ lwlnames = {' conv3d' ,' conv3d_1' ,' conv3d_2' ,' conv3d_3' ,' conv3d_4' ,' conv3d_5' }; % layers_with_learnables
52
+ lname = lwlnames{ilayer };
53
+ thisweights = h5read(pretrainedNetwork ,strcat(' /model_weights/' ,lname ,' /' ,lname ,' /kernel:0' ));
54
+ w = permute(thisweights ,[5 ,4 ,3 ,2 ,1 ]);
55
+ end
56
+
57
+ function b = mybini(ilayer )
58
+ lwlnames = {' conv3d' ,' conv3d_1' ,' conv3d_2' ,' conv3d_3' ,' conv3d_4' ,' conv3d_5' }; % layers_with_learnables
59
+ lwldims = [8 16 32 16 8 1 ];
60
+ lname = lwlnames{ilayer };
61
+ b = h5read(pretrainedNetwork ,strcat(' /model_weights/' ,lname ,' /' ,lname ,' /bias:0' ));
62
+ b = reshape(b ,[1 1 1 lwldims(ilayer )]);
63
+ end
64
+
65
+ end
0 commit comments