diff --git a/.gitignore b/.gitignore index 13234a1e..59bbf43b 100755 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ python/build python/dist *.egg-info *.pkl -log*/ \ No newline at end of file +log*/ +*.bat \ No newline at end of file diff --git a/README.md b/README.md index 3abd63e5..72d7a93d 100755 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ JNeRF is an NeRF benchmark based on [Jittor](https://github.com/Jittor/jittor). ## Install JNeRF environment requirements: -* System: **Linux**(e.g. Ubuntu/CentOS/Arch), **macOS**, or **Windows Subsystem of Linux (WSL)** +* System: **Linux**(e.g. Ubuntu/CentOS/Arch), **macOS**, **Windows**, or **Windows Subsystem of Linux (WSL)** * Python version >= 3.7 * CPU compiler (require at least one of the following) * g++ (>=5.4.0) diff --git a/contrib/mipnerf/python/jnerf/models/position_encoders/hash_encoder/grid_encode.py b/contrib/mipnerf/python/jnerf/models/position_encoders/hash_encoder/grid_encode.py index 2d6369d6..4688bfa3 100644 --- a/contrib/mipnerf/python/jnerf/models/position_encoders/hash_encoder/grid_encode.py +++ b/contrib/mipnerf/python/jnerf/models/position_encoders/hash_encoder/grid_encode.py @@ -61,7 +61,10 @@ def __init__(self, hash_func_header, aabb_scale=1, n_pos_dims=3, n_features_per_ self.m_grid_gradient = jt.empty([m_n_params], self.grad_type) self.m_stochastic_interpolation = 0 header_path = os.path.join(os.path.dirname(__file__), 'op_header') - proj_options[f"FLAGS: -I{header_path}"]=1 + if sys.platform == "linux": + proj_options[f"FLAGS: -I{header_path}"]=1 + else: + proj_options[f'FLAGS: -I"{header_path}"']=1 def execute(self, x,m_grid): self.num_elements=x.shape[0] @@ -87,7 +90,7 @@ def execute(self, x,m_grid): const uint32_t blocks = div_round_up(num_elements, threads.x); extract_position<float,N_POS_DIMS><<<blocks, threads, 0, stream>>>( num_elements, - {{in1_p,in1_shape1}}, + {{in1_p,(size_t)in1_shape1}}, m_positions_p ); static constexpr uint32_t N_THREADS_HASHGRID = 512; @@ -115,7 +118,7 @@ def execute(self, x,m_grid): const dim3 threads_transpose = {{ {self.m_n_levels}, 8, 1 }}; const uint32_t blocks_transpose = div_round_up(num_elements, threads_transpose.y); - PitchedPtr<grad_t> outputs{{ out0_p,out0_shape1 }}; + PitchedPtr<grad_t> outputs{{ out0_p,(size_t)out0_shape1 }}; transpose_encoded_position<vector_t<grad_t,N_FEATURES_PER_LEVEL>><<<blocks_transpose, threads_transpose, 0, stream>>>( num_elements, @@ -149,7 +152,7 @@ def grad(self, grad_x): const unsigned int N_FEATURES_PER_LEVEL={self.N_FEATURES_PER_LEVEL}; cudaStream_t stream=0; const dim3 threads_transpose ={{ {self.m_n_levels} , 8, 1}}; - PitchedPtr<grad_t> dL_dy{{ in2_p,in2_shape1 }}; + PitchedPtr<grad_t> dL_dy{{ in2_p,(size_t)in2_shape1 }}; cudaMemsetAsync(out0_p, 0, out0->size); const uint32_t blocks_transpose = div_round_up(num_elements, threads_transpose.y); transpose_gradients<vector_t<grad_t, N_FEATURES_PER_LEVEL>><<<blocks_transpose, threads_transpose, 0, stream>>>( diff --git a/contrib/mipnerf/python/jnerf/models/position_encoders/sh_encoder/sh_encoder.py b/contrib/mipnerf/python/jnerf/models/position_encoders/sh_encoder/sh_encoder.py index fd46f9d1..50b8fd2c 100644 --- a/contrib/mipnerf/python/jnerf/models/position_encoders/sh_encoder/sh_encoder.py +++ b/contrib/mipnerf/python/jnerf/models/position_encoders/sh_encoder/sh_encoder.py @@ -2,6 +2,7 @@ import jittor as jt from jittor import Function import numpy as np +import sys from jnerf.ops.code_ops.global_vars import global_headers,proj_options from jnerf.utils.config import get_cfg from 
jnerf.utils.registry import ENCODERS @@ -20,7 +21,10 @@ def __init__(self) : else: self.grad_type='float32' header_path = os.path.join(os.path.dirname(__file__), 'op_header') - proj_options[f"FLAGS: -I{header_path}"]=1 + if sys.platform == "linux": + proj_options[f"FLAGS: -I{header_path}"]=1 + else: + proj_options[f'FLAGS: -I"{header_path}"']=1 self.out_dim=self.m_n_padded_output_dims def execute(self,x) : @@ -37,8 +41,8 @@ def execute(self,x) : cudaStream_t stream=0; - PitchedPtr<const float> inputs={{in0_p,in0_shape1}}; - PitchedPtr<grad_t> outputs={{out_p,out_shape1}}; + PitchedPtr<const float> inputs={{in0_p,(size_t)in0_shape1}}; + PitchedPtr<grad_t> outputs={{out_p,(size_t)out_shape1}}; float* dy_dx = nullptr; linear_kernel(kernel_sh<grad_t>, 0, stream, num_elements, diff --git a/contrib/mipnerf/python/jnerf/models/samplers/density_grid_sampler/density_grid_sampler.py b/contrib/mipnerf/python/jnerf/models/samplers/density_grid_sampler/density_grid_sampler.py index 607df9aa..215c963f 100644 --- a/contrib/mipnerf/python/jnerf/models/samplers/density_grid_sampler/density_grid_sampler.py +++ b/contrib/mipnerf/python/jnerf/models/samplers/density_grid_sampler/density_grid_sampler.py @@ -91,7 +91,10 @@ def __init__(self, update_den_freq=16, update_block_size=5000000): self.dataset_ray_data = False # 数据集是否包含光线信息 header_path = os.path.join(os.path.dirname(__file__), 'op_header') - proj_options[f"FLAGS: -I{header_path}"]=1 + if sys.platform == "linux": + proj_options[f"FLAGS: -I{header_path}"]=1 + else: + proj_options[f'FLAGS: -I"{header_path}"']=1 self.density_grad_header = f""" inline constexpr __device__ uint32_t NERF_GRIDSIZE() {{ return {self.NERF_GRIDSIZE}; }} // size of the density/occupancy grid. diff --git a/contrib/mipnerf/python/jnerf/ops/code_ops/global_vars.py b/contrib/mipnerf/python/jnerf/ops/code_ops/global_vars.py index 9e1accb4..3b835247 100644 --- a/contrib/mipnerf/python/jnerf/ops/code_ops/global_vars.py +++ b/contrib/mipnerf/python/jnerf/ops/code_ops/global_vars.py @@ -1,27 +1,53 @@ import os +import sys import jittor as jt jt.flags.use_cuda = 1 -global_headers = """ +export_gloabl = '' +import_global = 'extern' +if sys.platform == "win32": + export_gloabl = '__declspec(dllexport)' + import_global = '__declspec(dllimport) extern' + +global_headers = f""" #include "pcg32.h" -namespace jittor { -extern int global_var1; -extern pcg32 rng; -} +namespace jittor {{ +EXTERN_LIB int global_var1; +EXTERN_LIB pcg32 rng; +}} """ -global_src = """ -namespace jittor { -int global_var1 = 123; -pcg32 rng{1337}; -} +global_decl_headers = f""" +#include "pcg32.h" +namespace jittor {{ +{export_gloabl} int global_var1; +{export_gloabl} pcg32 rng; +}} +""" + +global_src = f""" +#include "pcg32.h" +namespace jittor {{ +{export_gloabl} int global_var1 = 123; +{export_gloabl} pcg32 rng{{1337}}; +}} """ proj_path = os.path.join(os.path.dirname(__file__), '..', 'op_include') -proj_options = { f"FLAGS: -I{proj_path}/eigen -I{proj_path}/include -I{proj_path}/pcg32 -I{proj_path}/../op_header -DGLOBAL_VAR --extended-lambda --expt-relaxed-constexpr": 1 } +if sys.platform == "linux": + proj_options = { f"FLAGS: -I{proj_path}/eigen -I{proj_path}/include -I{proj_path}/pcg32 -I{proj_path}/../op_header -DGLOBAL_VAR --extended-lambda --expt-relaxed-constexpr": 1 } +else: + proj_options = { f'FLAGS: -I"{proj_path}/eigen" -I"{proj_path}/include" -I"{proj_path}/pcg32" -I"{proj_path}/../op_header" -DGLOBAL_VAR --extended-lambda --expt-relaxed-constexpr': 1 } + +jt.profiler.start() gv = jt.code([1], int, - 
cuda_header=global_headers+global_src, + cuda_header=global_src, cuda_src=""" """) gv.compile_options = proj_options gv.sync() +jt.profiler.stop() + +if os.name == "nt": + dll_name = jt.profiler.report()[-1][-10].replace(".cc", "") + proj_options[f'FLAGS: -l{dll_name} '] = 1 \ No newline at end of file diff --git a/python/jnerf/__init__.py b/python/jnerf/__init__.py index e36a6547..5f5bc00a 100644 --- a/python/jnerf/__init__.py +++ b/python/jnerf/__init__.py @@ -5,7 +5,7 @@ dirname = os.path.dirname(__file__) LOG.i(f"JNeRF({__version__}) at {dirname}") import sys -assert sys.platform == "linux", "Windows/MacOS is not supported yet, everyone is welcome to contribute to this" +assert sys.platform == "linux" or sys.platform == "win32" # "MacOS is not supported yet, everyone is welcome to contribute to this" sp_char = ' "\'' for char in sp_char: diff --git a/python/jnerf/dataset/dataset.py b/python/jnerf/dataset/dataset.py index 9ae9960e..06b1a79d 100755 --- a/python/jnerf/dataset/dataset.py +++ b/python/jnerf/dataset/dataset.py @@ -113,7 +113,7 @@ def load_data(self,root_dir=None): matrix=np.array(frame['transform_matrix'],np.float32)[:-1, :] self.transforms_gpu.append( self.matrix_nerf2ngp(matrix, self.scale, self.offset)) - + self.resolution=[self.W,self.H] self.resolution_gpu=jt.array(self.resolution) metadata=np.empty([11],np.float32) diff --git a/python/jnerf/models/position_encoders/hash_encoder/grid_encode.py b/python/jnerf/models/position_encoders/hash_encoder/grid_encode.py index 2d6369d6..4688bfa3 100755 --- a/python/jnerf/models/position_encoders/hash_encoder/grid_encode.py +++ b/python/jnerf/models/position_encoders/hash_encoder/grid_encode.py @@ -61,7 +61,10 @@ def __init__(self, hash_func_header, aabb_scale=1, n_pos_dims=3, n_features_per_ self.m_grid_gradient = jt.empty([m_n_params], self.grad_type) self.m_stochastic_interpolation = 0 header_path = os.path.join(os.path.dirname(__file__), 'op_header') - proj_options[f"FLAGS: -I{header_path}"]=1 + if sys.platform == "linux": + proj_options[f"FLAGS: -I{header_path}"]=1 + else: + proj_options[f'FLAGS: -I"{header_path}"']=1 def execute(self, x,m_grid): self.num_elements=x.shape[0] @@ -87,7 +90,7 @@ def execute(self, x,m_grid): const uint32_t blocks = div_round_up(num_elements, threads.x); extract_position<float,N_POS_DIMS><<<blocks, threads, 0, stream>>>( num_elements, - {{in1_p,in1_shape1}}, + {{in1_p,(size_t)in1_shape1}}, m_positions_p ); static constexpr uint32_t N_THREADS_HASHGRID = 512; @@ -115,7 +118,7 @@ def execute(self, x,m_grid): const dim3 threads_transpose = {{ {self.m_n_levels}, 8, 1 }}; const uint32_t blocks_transpose = div_round_up(num_elements, threads_transpose.y); - PitchedPtr<grad_t> outputs{{ out0_p,out0_shape1 }}; + PitchedPtr<grad_t> outputs{{ out0_p,(size_t)out0_shape1 }}; transpose_encoded_position<vector_t<grad_t,N_FEATURES_PER_LEVEL>><<<blocks_transpose, threads_transpose, 0, stream>>>( num_elements, @@ -149,7 +152,7 @@ def grad(self, grad_x): const unsigned int N_FEATURES_PER_LEVEL={self.N_FEATURES_PER_LEVEL}; cudaStream_t stream=0; const dim3 threads_transpose ={{ {self.m_n_levels} , 8, 1}}; - PitchedPtr<grad_t> dL_dy{{ in2_p,in2_shape1 }}; + PitchedPtr<grad_t> dL_dy{{ in2_p,(size_t)in2_shape1 }}; cudaMemsetAsync(out0_p, 0, out0->size); const uint32_t blocks_transpose = div_round_up(num_elements, threads_transpose.y); transpose_gradients<vector_t<grad_t, N_FEATURES_PER_LEVEL>><<<blocks_transpose, threads_transpose, 0, stream>>>( diff --git a/python/jnerf/models/position_encoders/sh_encoder/sh_encoder.py 
b/python/jnerf/models/position_encoders/sh_encoder/sh_encoder.py index fd46f9d1..50b8fd2c 100755 --- a/python/jnerf/models/position_encoders/sh_encoder/sh_encoder.py +++ b/python/jnerf/models/position_encoders/sh_encoder/sh_encoder.py @@ -2,6 +2,7 @@ import jittor as jt from jittor import Function import numpy as np +import sys from jnerf.ops.code_ops.global_vars import global_headers,proj_options from jnerf.utils.config import get_cfg from jnerf.utils.registry import ENCODERS @@ -20,7 +21,10 @@ def __init__(self) : else: self.grad_type='float32' header_path = os.path.join(os.path.dirname(__file__), 'op_header') - proj_options[f"FLAGS: -I{header_path}"]=1 + if sys.platform == "linux": + proj_options[f"FLAGS: -I{header_path}"]=1 + else: + proj_options[f'FLAGS: -I"{header_path}"']=1 self.out_dim=self.m_n_padded_output_dims def execute(self,x) : @@ -37,8 +41,8 @@ def execute(self,x) : cudaStream_t stream=0; - PitchedPtr<const float> inputs={{in0_p,in0_shape1}}; - PitchedPtr<grad_t> outputs={{out_p,out_shape1}}; + PitchedPtr<const float> inputs={{in0_p,(size_t)in0_shape1}}; + PitchedPtr<grad_t> outputs={{out_p,(size_t)out_shape1}}; float* dy_dx = nullptr; linear_kernel(kernel_sh<grad_t>, 0, stream, num_elements, diff --git a/python/jnerf/models/samplers/density_grid_sampler/density_grid_sampler.py b/python/jnerf/models/samplers/density_grid_sampler/density_grid_sampler.py index 607df9aa..7a7f2ad0 100644 --- a/python/jnerf/models/samplers/density_grid_sampler/density_grid_sampler.py +++ b/python/jnerf/models/samplers/density_grid_sampler/density_grid_sampler.py @@ -1,5 +1,6 @@ import os import jittor as jt +import sys from jittor import nn from .ema_grid_samples_nerf import ema_grid_samples_nerf from .generate_grid_samples_nerf_nonuniform import generate_grid_samples_nerf_nonuniform @@ -91,7 +92,10 @@ def __init__(self, update_den_freq=16, update_block_size=5000000): self.dataset_ray_data = False # 数据集是否包含光线信息 header_path = os.path.join(os.path.dirname(__file__), 'op_header') - proj_options[f"FLAGS: -I{header_path}"]=1 + if sys.platform == "linux": + proj_options[f"FLAGS: -I{header_path}"]=1 + else: + proj_options[f'FLAGS: -I"{header_path}"']=1 self.density_grad_header = f""" inline constexpr __device__ uint32_t NERF_GRIDSIZE() {{ return {self.NERF_GRIDSIZE}; }} // size of the density/occupancy grid. 
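Note on the recurring change above: grid_encode.py, sh_encoder.py and density_grid_sampler.py all gain the same three-line platform branch, passing the op_header include path bare on Linux and wrapped in quotes everywhere else, presumably so that Windows paths containing spaces survive the compiler command line unsplit. A minimal sketch of that pattern factored into a shared helper (include_flag is a hypothetical name, not part of this PR):

    import os
    import sys

    # Stand-in for the shared flag dict imported from
    # jnerf.ops.code_ops.global_vars in the real modules.
    proj_options = {}

    def include_flag(header_path):
        # Mirror the per-platform branch added in this PR: quote the path on
        # non-Linux platforms so a directory name containing spaces does not
        # split the -I flag.
        if sys.platform == "linux":
            return f"FLAGS: -I{header_path}"
        return f'FLAGS: -I"{header_path}"'

    header_path = os.path.join(os.path.dirname(__file__), "op_header")
    proj_options[include_flag(header_path)] = 1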
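Note on the global_vars.py rework: on win32 the shared globals (global_var1, rng) are now defined with __declspec(dllexport) inside the source compiled by a bootstrap jt.code op, other ops reference them through EXTERN_LIB declarations in global_headers, and the name of the generated source is recovered from the profiler report so that later ops can link against that DLL via an -l flag. A condensed sketch of that bootstrap flow, assuming the same jittor profiler behaviour the PR relies on (an illustration, not a drop-in replacement for the file):

    import os
    import sys
    import jittor as jt

    jt.flags.use_cuda = 1

    # On Windows the definitions must be exported so that other compiled ops
    # (each living in its own DLL) can import them; on Linux plain externs work.
    export_global = "__declspec(dllexport)" if sys.platform == "win32" else ""

    global_src = f"""
    #include "pcg32.h"
    namespace jittor {{
    {export_global} int global_var1 = 123;
    {export_global} pcg32 rng{{1337}};
    }}
    """

    # Include/flag dict; in the PR this also carries the -I flags for
    # pcg32/eigen and the op_header directories.
    proj_options = {}

    jt.profiler.start()
    gv = jt.code([1], int, cuda_header=global_src, cuda_src="""
    """)
    gv.compile_options = proj_options
    gv.sync()                 # forces compilation of the op that owns the globals
    jt.profiler.stop()

    if os.name == "nt":
        # As in the PR, the profiler report is used to recover the name of the
        # just-compiled generated source; the exact report indices are
        # jittor-version specific.
        dll_name = jt.profiler.report()[-1][-10].replace(".cc", "")
        proj_options[f"FLAGS: -l{dll_name} "] = 1

Splitting the export side (dllexport definitions in global_src) from the import side (EXTERN_LIB declarations in global_headers) is what lets a single compiled op own the definitions while every other op merely links against its DLL.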
diff --git a/python/jnerf/models/samplers/density_grid_sampler/generate_grid_samples_nerf_nonuniform.py b/python/jnerf/models/samplers/density_grid_sampler/generate_grid_samples_nerf_nonuniform.py index 92a694ab..2b3f7d40 100755 --- a/python/jnerf/models/samplers/density_grid_sampler/generate_grid_samples_nerf_nonuniform.py +++ b/python/jnerf/models/samplers/density_grid_sampler/generate_grid_samples_nerf_nonuniform.py @@ -46,11 +46,8 @@ def execute(self, density_grid, n_elements, density_grid_ema_step, max_cascade, """) + # print(proj_options) output[0].compile_options = proj_options output[0].sync() output[1].sync() return output - - - - diff --git a/python/jnerf/ops/code_ops/fully_fused_mlp.py b/python/jnerf/ops/code_ops/fully_fused_mlp.py index 5540de1f..15d18fbc 100644 --- a/python/jnerf/ops/code_ops/fully_fused_mlp.py +++ b/python/jnerf/ops/code_ops/fully_fused_mlp.py @@ -21,8 +21,8 @@ def __init__(self, weights, check_mid="0", output_activation="Activation::None") self.width = 0 self.output_intermediate = None con_weights = [] - self.code_path = pathlib.Path(__file__+"/../op_header").resolve() - self.so_name = os.path.join(pathlib.Path(__file__+"/../op_header").resolve(), "fully_fused_mlp_function.o") + self.code_path = pathlib.Path(__file__+"/../").resolve() + self.so_name = os.path.join(pathlib.Path(__file__+"/../op_header").resolve(), "fully_fused_mlp_function.cc") for i in range(len(weights)): if i == 0: self.weight_shape0 = weights[0].shape[0] @@ -81,7 +81,7 @@ def execute(self, a, con_weights): else: self.padded_input = self.input self.outputs, self.output_intermediate = jt.code([(self.padded_input.shape[0], 16), (self.padded_input.shape[0] * (len(self.weights) - 1), self.width)], [a.dtype, a.dtype], [self.padded_input, con_weights], cuda_header=cuda_header, cuda_src=cuda_src) - self.outputs.compile_options = {f"FLAGS: -I{self.code_path} -Xlinker {self.so_name} ":1} + self.outputs.compile_options = {f"FLAGS: -I{self.code_path}":1} self.con_weights = con_weights return self.outputs[:self.input.shape[0]] @@ -115,7 +115,7 @@ def grad(self, grads): ); ''' output, grad_temps = jt.code([(self.padded_input.shape[0], self.input.shape[1]), ((len(self.weights)-1) * self.padded_input.shape[0], self.width)], [self.input.dtype, self.input.dtype], [grads.transpose(), self.con_weights, self.output_intermediate], cuda_header=cuda_header, cuda_src=cuda_src) - output.compile_options = {f"FLAGS: -I{self.code_path} -Xlinker {self.so_name} ":1} + output.compile_options = {f"FLAGS: -I{self.code_path}":1} if self.check_mid == "1": self.grad_temps = grad_temps if not need_last: diff --git a/python/jnerf/ops/code_ops/fully_fused_mlp_function.cc b/python/jnerf/ops/code_ops/fully_fused_mlp_function.cc new file mode 100644 index 00000000..c61c226b --- /dev/null +++ b/python/jnerf/ops/code_ops/fully_fused_mlp_function.cc @@ -0,0 +1,580 @@ +#pragma once +#undef out +#include "fully_fused_mlp_header.h" +// implement temp GPUMatrix and GPUDynamicMatrix here. 
+#define RM MatrixLayout::kRowMajor +#define CM MatrixLayout::kColumnMajor + +template <typename T> +__device__ __host__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} +// warp activation defined here +template <typename T, typename fragment_t> +__host__ __device__ void warp_activation(Activation activation, const fragment_t& frag, fragment_t& result) { + switch (activation) { + case Activation::ReLU: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)((T)frag.x[t] > (T)0.0f); + } + return; + case Activation::None: result = frag; return; + default: + // Unsupported activation + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + } +} +template <typename T, typename fragment_t> +__host__ __device__ fragment_t warp_activation(Activation activation, const fragment_t& frag) { + fragment_t result; + warp_activation<T>(activation, frag, result); + return result; +} +template <typename T, typename fragment_t, typename forward_fragment_t> +__host__ __device__ void warp_activation_backward(Activation activation, const fragment_t& frag, const forward_fragment_t& forward_frag, fragment_t& result) { + switch (activation) { + case Activation::ReLU: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(forward_frag.x[t] > (T)0.0f); + } + return; + case Activation::None: result = frag; return; + default: + // Unsupported activation + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + } +} +template <typename T, typename fragment_t, typename forward_fragment_t> +__host__ __device__ fragment_t warp_activation_backward(Activation activation, const fragment_t& frag, const forward_fragment_t& forward_frag) { + fragment_t result; + warp_activation_backward<T>(activation, frag, forward_frag, result); + return result; +} +void check_shmem_error(cudaError_t error) { + if (error != cudaSuccess) { + throw std::runtime_error{"FullyFusedMLP: insufficient shared memory available on the GPU. Reduce `n_neurons` or use `CutlassMLP` (better compatibility but slower) instead."}; + } +} +template <int WIDTH, int N_ITERS, typename OUT_T, bool BACKWARD=false> +__device__ void threadblock_layer(Activation activation, __half* __restrict__ act_shmem, const __half* __restrict__ weights_this_layer, OUT_T* __restrict__ out_intermediate_threadblock_this_layer, const OUT_T* __restrict__ activation_aux = nullptr) { + // act_shmem contains the intermediate activations (shared memory) of the thread block's chunk of the batch. + // Can be forward activations or backward activations, depending on caller. + // weights_this_layer points to the weight matrix of the current layer. + // out_intermediate_threadblock_this_layer points to the location where intermediate activations produced by the thread block should be written to. + // Can be nullptr if nothing should be written. + // activation_aux points to additional arguments that the activation function may depend on. Points to the hidden forward activations when computing backward activations. + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + using namespace nvcuda; + // If we're performing the backward pass, weights must be loaded in transposed form, which + // is achieved by interpreting the memory in row_major instead of col_major order. 
+ using weights_layout_t = std::conditional_t<BACKWARD, wmma::row_major, wmma::col_major>; + // Fragments + wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> act_frag; + wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, weights_layout_t> weights_frag[N_BLOCKS]; + wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag[N_ITERS]; + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + const uint32_t weights_col = 16 * wi; + __syncthreads(); + // Load N_BLOCKS chunks of weights from global memory into registers. + #pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) { + if (BACKWARD) { + // If we're performing the backward pass, additional index swizzling is needed to + // load the weights in transposed form. + wmma::load_matrix_sync(weights_frag[i], weights_this_layer + 16 * i * WIDTH + weights_col, WIDTH); + } else { + wmma::load_matrix_sync(weights_frag[i], weights_this_layer + 16 * i + weights_col * WIDTH, WIDTH); + } + } + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::fill_fragment(result_frag[l], 0.0f); + #pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) { + // Load a chunk of intermediate activations from shared memory and multiply with chunk of weights + wmma::load_matrix_sync(act_frag, act_shmem + 16 * i + (16 * l) * (WIDTH + SKEW), WIDTH + SKEW); + wmma::mma_sync(result_frag[l], act_frag, weights_frag[i], result_frag[l]); + } + // Activation + if (BACKWARD) { + // Load the temporary forward matrix for the relu transfer + wmma::load_matrix_sync(act_frag, activation_aux + weights_col + l * 16 * WIDTH, WIDTH); + warp_activation_backward<__half>(activation, result_frag[l], act_frag, result_frag[l]); + } else { + warp_activation<__half>(activation, result_frag[l], result_frag[l]); + } + } + __syncthreads(); + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::store_matrix_sync(act_shmem + weights_col + l * 16 * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major); + } + if (out_intermediate_threadblock_this_layer != nullptr) { + __syncthreads(); + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + *(int4*)&out_intermediate_threadblock_this_layer[lane_offset + (row + 16 * l) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * l) * (WIDTH + SKEW)]; + } + } +} +template <int WIDTH, int N_ITERS> +__device__ void threadblock_load_input_static(__half* __restrict__ act_shmem, const __half* __restrict__ input_threadblock) { + // act_shmem will be filled by the thread block's chunk of input_threadblock + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 
8 : 0; + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + #pragma unroll + for (int i = 0; i < N_ITERS; ++i) { + *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)] = *(int4*)&input_threadblock[lane_offset + (row + 16 * i) * WIDTH]; + } +} +template <int WIDTH, int N_ITERS, Activation ACTIVATION, typename OUTPUT_LAYOUT> +__global__ void kernel_mlp_fused_backward(const __half* __restrict__ dL_doutput, const __half* __restrict__ weights, __half* __restrict__ out_intermediate, const __half* __restrict__ forward, __half* __restrict__ dL_dinput, const __half* __restrict__ weights_first_layer, const uint32_t batch_size, const uint32_t out_width, const uint32_t n_hidden_matmuls, int need_last) { + // `dL_doutput` points to the input matrix of the backward pass, i.e. the loss gradients. Assumed to be 16 neurons wide. + // `weights` points to the weight matrices (contiguous in memory). + // `out_intermediate` points to the memory where backpropagated activation gradients should be written. + // `forward` points to the memory where the intermediate activations of the forward pass are located. (needed for activation backprop) + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + const uint32_t bi = blockIdx.x; // block index + // Shared memory contains the intermediate activations of blockDim.y*16 elements. + // A skew is applied to the matrix storage to avoid bank conflicts. + extern __shared__ __half shmem[]; + __half* act_shmem = shmem; + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + // Multipying one 16-row chunk of intermediate activations with the weight matrix requires all warps of the block. + // Thus, each block computes exactly one 16-row chunk of the next layer's intermediate activations. 
+ const uint32_t elem_idx_base = 16 * bi * N_ITERS; + const uint32_t elem_idx = elem_idx_base; + const uint32_t layer_stride = WIDTH * WIDTH; + const uint32_t output_stride = WIDTH * batch_size; + // Backprop through last layer + if (out_width <= 16) { + using namespace nvcuda; + // Fragments in registers + wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, OUTPUT_LAYOUT> act_frag; + wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> weights_frag; + wmma::fragment<wmma::accumulator, 16, 16, 16, __half> result_frag[N_ITERS]; + // Load the relevant chunk of the last layer's weight matrix from global memory into registers + const uint32_t weights_col = 16 * wi; + wmma::load_matrix_sync(weights_frag, weights + layer_stride * n_hidden_matmuls + weights_col, WIDTH); + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::fill_fragment(result_frag[l], 0.0f); + // Load a chunk of output gradients from shared memory and multiply with previously loaded weights + if (std::is_same<OUTPUT_LAYOUT, wmma::row_major>::value) { + wmma::load_matrix_sync(act_frag, dL_doutput + (elem_idx + 16 * l) * 16, 16); + } else { + wmma::load_matrix_sync(act_frag, dL_doutput + (elem_idx + 16 * l), batch_size); + } + // NOTE: activation transfer of the _output_ activation is expected to be done _prior_ to calling this kernel + // in a separate pass, because the tranfered activation gradient is also needed to compute the weight + // gradient of the last weight matrix (see backward()). + wmma::mma_sync(result_frag[l], act_frag, weights_frag, result_frag[l]); + // Load the temporary forward matrix for the relu transfer + wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> forward_frag; + wmma::load_matrix_sync(forward_frag, forward + output_stride * n_hidden_matmuls + weights_col + (elem_idx + l * 16) * WIDTH, WIDTH); + warp_activation_backward<__half>(ACTIVATION, result_frag[l], forward_frag, result_frag[l]); + } + __syncthreads(); + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::store_matrix_sync(act_shmem + weights_col + (16 * l) * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major); + } + __syncthreads(); + #pragma unroll + for (int i = 0; i < N_ITERS; ++i) { + *(int4*)&out_intermediate[lane_offset + (row + elem_idx + i * 16) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)]; + } + } + // else { + // // If the output width is larger than 16, we will have used CUTLASS for backpropping through the last layer. + // // Load the resulting gradients. + // threadblock_load_input_static<WIDTH, N_ITERS>(act_shmem, out_intermediate + elem_idx * WIDTH); + // } + // Backprop through hidden layers + for (uint32_t k = 0; k < n_hidden_matmuls; ++k) { + threadblock_layer<WIDTH, N_ITERS, __half, true>(ACTIVATION, act_shmem, weights + layer_stride * (n_hidden_matmuls - k - 1), out_intermediate + output_stride * (k + 1) + elem_idx_base * WIDTH, forward + output_stride * (n_hidden_matmuls - k - 1) + elem_idx_base * WIDTH); + } + // Compute loss gradients w.r.t. input if desired. + // THIS CODE ASSUMES THAT THE INPUT WIDTH IS THE SAME AS THE NETWORK WIDTH + // AND THAT THE INPUT LAYOUT IS THE SAME AS THE HIDDEN LAYOUT. + // DON'T PASS A NON-NULL dL_dinput IF THIS REQUIREMENT IS NOT MET. 
+ if (dL_dinput != nullptr && need_last) { + threadblock_layer<WIDTH, N_ITERS, __half, true>(Activation::None, act_shmem, weights_first_layer, dL_dinput + elem_idx_base * WIDTH); + } +} +template <int WIDTH, typename T, Activation ACTIVATION> +void mlp_fused_backward( + cudaStream_t stream, + T* weights_first_layer, + T* weights, + T* dL_doutput, + T* temps, + T* forward, + T* dL_dinput, + const uint32_t n_hidden_matmuls, + int grad_shape0, + int grad_shape1, + int need_last +) { + const uint32_t batch_size = grad_shape0; + const uint32_t out_width = grad_shape1; + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + // if (forward.cols() != batch_size) { + // throw std::runtime_error{"Batch size of matrices dL_doutput and temporaries doesn't match."}; + // } + const int N_ITERS = WIDTH >= 256 ? 2 : 8; + // if (batch_size % (16 * N_ITERS) != 0) { + // throw std::runtime_error{"Batch size must be a multiple of " + std::to_string(16 * N_ITERS) + "."}; + // } + const dim3 threads = { 32u, N_BLOCKS, 1 }; // 32 threads = 1 warp, 8 warps per block for 16 rows, up to 2x 8 warps can share input (does not help vs. 1) + uint32_t n_elems_per_block = 16 * N_ITERS; + uint32_t n_blocks = div_round_up(batch_size, n_elems_per_block); + int shmem_size = sizeof(__half) * ((16 * N_ITERS) * (WIDTH + SKEW)); // WIDTH rows of input and 16 * threads.z rows of weights + const dim3 blocks = { n_blocks, 1u, 1u }; + // The kernels operate with transposed layouts compared with the MLP code + // if (dL_doutput.layout() == RM) { + // check_shmem_error(cudaFuncSetAttribute(kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::col_major>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + // kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::col_major><<<blocks, threads, shmem_size, stream>>>(dL_doutput.data(), weights.data(), temporaries.data(), forward.data(), dL_dinput ? dL_dinput->data() : nullptr, weights_first_layer.data(), dL_doutput.stride(), batch_size, out_width, n_hidden_matmuls); + // } else { + // check_shmem_error(cudaFuncSetAttribute(kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::row_major>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + // kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::row_major><<<blocks, threads, shmem_size, stream>>>(dL_doutput.data(), weights.data(), temporaries.data(), forward.data(), dL_dinput ? 
dL_dinput->data() : nullptr, weights_first_layer.data(), dL_doutput.stride(), batch_size, out_width, n_hidden_matmuls); + // } + check_shmem_error(cudaFuncSetAttribute(kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::col_major>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::col_major><<<blocks, threads, shmem_size, stream>>>(dL_doutput, weights, temps, forward, dL_dinput, weights_first_layer, batch_size, out_width, n_hidden_matmuls, need_last); +} +template <int WIDTH, int N_ITERS, typename OUT_T, typename INPUT_LAYOUT> +__device__ void threadblock_input_layer_forward_dynamic(Activation activation, __half* __restrict__ act_shmem, const __half* __restrict__ input_threadblock, const __half* __restrict__ weights_this_layer, OUT_T* __restrict__ out_intermediate_threadblock_this_layer, const uint32_t in_width, const uint32_t batch_size) { + // act_shmem contains the intermediate activations (shared memory) of the thread block's chunk of the batch + // input_threadblock points to the thread block's chunk of the input batch in global memory + // weights_this_layer points to the weight matrix of the current layer + // out_intermediate_threadblock_this_layer points to the location where intermediate activations produced by the thread block should be written to. + // Can be nullptr if nothing should be written. + // in_width is the dynamic width of the input layer + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + constexpr uint32_t INPUT_SKEW = 8; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + using namespace nvcuda; + // Fragments + wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, INPUT_LAYOUT> act_frag; + wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> weights_frag; + wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag[N_ITERS]; + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + const uint32_t weights_col = 16 * wi; + __half* __restrict__ weights_shmem = act_shmem + 16 * (in_width + INPUT_SKEW); + // Load input weight matrix (fits completely into shared memory) + // Each thread can load 8 fp16 elements (16 bytes) at once; we have N_BLOCKS warps + const uint32_t n_elems_per_load = N_BLOCKS * 32 * 8; + const uint32_t thread_elem_idx = (li + wi * 32) * 8; + const uint32_t n_elems_b = WIDTH * in_width; + #pragma unroll + for (uint32_t idx = thread_elem_idx; idx < n_elems_b; idx += n_elems_per_load) { + const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW; + *(int4*)&weights_shmem[idx_skewed] = *(int4*)&weights_this_layer[idx]; + } + const uint32_t n_tensor_ops = in_width / 16; + if (std::is_same<INPUT_LAYOUT, wmma::col_major>::value) { + __syncthreads(); + } + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + if (std::is_same<INPUT_LAYOUT, wmma::row_major>::value) { + // Load chunk of inputs into shmem. + // This is faster than loading it from gmem directly, even though it is only used once. + // (Possibly due to latency hiding through staging.) 
+ const uint32_t n_elems_a = 16 * in_width; + #pragma unroll + for (uint32_t idx = thread_elem_idx; idx < n_elems_a; idx += n_elems_per_load) { + const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW; + *(int4*)&act_shmem[idx_skewed] = *(int4*)&input_threadblock[l * n_elems_a + idx]; + } + __syncthreads(); + } + wmma::fill_fragment(result_frag[l], 0.0f); + #pragma unroll + for (uint32_t i = 0; i < n_tensor_ops; ++i) { + // Load chunk of inputs and weights from shared memory and multiply them + if (std::is_same<INPUT_LAYOUT, wmma::row_major>::value) { + wmma::load_matrix_sync(act_frag, act_shmem + 16 * i, in_width + INPUT_SKEW); + } else { + wmma::load_matrix_sync(act_frag, input_threadblock + 16 * i * batch_size + 16 * l, batch_size); + } + wmma::load_matrix_sync(weights_frag, weights_shmem + 16 * i + weights_col * (in_width + INPUT_SKEW), in_width + INPUT_SKEW); + wmma::mma_sync(result_frag[l], act_frag, weights_frag, result_frag[l]); + } + if (std::is_same<INPUT_LAYOUT, wmma::row_major>::value) { + __syncthreads(); + } + warp_activation<__half>(activation, result_frag[l], result_frag[l]); + } + if (std::is_same<INPUT_LAYOUT, wmma::col_major>::value) { + __syncthreads(); + } + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::store_matrix_sync(act_shmem + weights_col + (16 * l) * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major); + } + if (out_intermediate_threadblock_this_layer != nullptr) { + __syncthreads(); + #pragma unroll + for (int i = 0; i < N_ITERS; ++i) { + *(int4*)&out_intermediate_threadblock_this_layer[lane_offset + (row + 16 * i) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)]; + } + } +} + +template <int WIDTH, int N_ITERS, typename OUT_T> +__device__ void threadblock_last_layer_forward(Activation activation, __half* __restrict__ act_shmem, const __half* __restrict__ weights_this_layer, OUT_T* __restrict__ out, const uint32_t batch_size, const nvcuda::wmma::layout_t output_layout, const uint32_t output_stride) { + // act_shmem contains the intermediate activations (shared memory) of the thread block's chunk of the batch + // weights_this_layer points to the weight matrix of the current layer + // out points to the location where the result produced by the thread block should be written to. + // Can be nullptr if nothing should be written. + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + using namespace nvcuda; + // Fragments + wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> act_frag; + wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> weights_frag[N_BLOCKS]; + wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag; + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + __half* __restrict__ weights_shmem = act_shmem + N_ITERS * 16 * (WIDTH + SKEW); + const uint32_t weights_row = (8 * li) % WIDTH; + const uint32_t weights_col = (8 * li + 8 * 32 * wi) / WIDTH; + // Load weight matrix into shared memory for the last multiplication. + // Loading into shared memory as opposed to directly into registers is faster + // because unlike in the previous layers, each warp uses the same entries of the weight matrix. 
+ *(int4*)&weights_shmem[weights_row + weights_col * (WIDTH + SKEW)] = *(int4*)&weights_this_layer[weights_row + weights_col * WIDTH]; + __syncthreads(); + #pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) + wmma::load_matrix_sync(weights_frag[i], weights_shmem + 16 * i, WIDTH + SKEW); + // Perform last layer by parallelizing over iters + for (uint32_t idx = wi; idx < N_ITERS; idx += N_BLOCKS) { + wmma::fill_fragment(result_frag, 0.0f); + #pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) { + // Load a chunk of intermediate activations from shared memory and multiply with chunk of the weight matrix + wmma::load_matrix_sync(act_frag, act_shmem + 16 * i + (16 * idx) * (WIDTH + SKEW), WIDTH + SKEW); + wmma::mma_sync(result_frag, act_frag, weights_frag[i], result_frag); + } + + warp_activation<__half>(activation, result_frag, result_frag); + if (output_layout == wmma::mem_row_major) { + wmma::store_matrix_sync(out + idx * 16 * output_stride, result_frag, output_stride, output_layout); + } else { + wmma::store_matrix_sync(out + idx * 16, result_frag, batch_size, output_layout); + } + } +} +template <int WIDTH, int N_ITERS, typename OUT_T, Activation ACTIVATION, bool INFERENCE> +__global__ void kernel_mlp_fused(const Activation output_activation, const __half* __restrict__ input, const __half* __restrict__ weights, OUT_T* __restrict__ out_intermediate, OUT_T* __restrict__ out, const uint32_t batch_size, const uint32_t in_width, const uint32_t out_width, const uint32_t n_hidden_matmuls, const nvcuda::wmma::layout_t input_layout, const nvcuda::wmma::layout_t output_layout, const int output_stride) { + // `input` points to the input matrix. Can be any width. + // `weights` points to the weight matrices (contiguous in memory). + // `out_intermediate` points to the memory where intermediate activations should be written. When performing inference, a value of nullptr is expected (intermediate results are not written). + // `out` points to the memory where the network output should be written. (Output width is assumed to be 16 neurons.) + // Commented out due to isolated strange side-effects on Windows + // if (INFERENCE) { + // assert(out_intermediate == nullptr); + // } else { + // assert(out_intermediate); + // } + // Shared memory contains the intermediate activations of blockDim.y*16 elements. + // In some cases, it also contains the weight matrix for the first and last layer. + extern __shared__ __half shmem[]; + __half* act_shmem = shmem; + // Each block computes exactly one 16-element chunk of the batch. + const uint32_t elem_idx = 16 * blockIdx.x * N_ITERS; + // First layer + if (input_layout == nvcuda::wmma::mem_col_major || in_width != WIDTH) { + if (input_layout == nvcuda::wmma::mem_row_major) { + threadblock_input_layer_forward_dynamic<WIDTH, N_ITERS, OUT_T, nvcuda::wmma::row_major>(ACTIVATION, act_shmem, input + elem_idx * in_width, weights, !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr, in_width, batch_size); + } else { + threadblock_input_layer_forward_dynamic<WIDTH, N_ITERS, OUT_T, nvcuda::wmma::col_major>(ACTIVATION, act_shmem, input + elem_idx, weights, !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr, in_width, batch_size); + } + } else { + // If the input has the same width & layout as the hidden layers, we can simply use the network's regular layer routine (with static size) + // instead of using the slower dynamic input layer routine. 
+ threadblock_load_input_static<WIDTH, N_ITERS>(act_shmem, input + elem_idx * WIDTH); + threadblock_layer<WIDTH, N_ITERS, OUT_T>(ACTIVATION, act_shmem, weights, !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr); + } + const uint32_t first_layer_size = WIDTH * in_width; + const uint32_t weights_stride = WIDTH * WIDTH; + const uint32_t layer_stride = WIDTH * batch_size; + // Hidden layers + for (uint32_t k = 0; k < n_hidden_matmuls; ++k) { + threadblock_layer<WIDTH, N_ITERS, OUT_T>(ACTIVATION, act_shmem, weights + first_layer_size + weights_stride * k, !INFERENCE ? (out_intermediate + layer_stride * (k + 1) + elem_idx * WIDTH) : nullptr); + } + // Last layer + if (output_layout == nvcuda::wmma::mem_row_major) { + threadblock_last_layer_forward<WIDTH, N_ITERS, OUT_T>(output_activation, act_shmem, weights + first_layer_size + weights_stride * n_hidden_matmuls, out + elem_idx * output_stride, output_stride, output_layout, output_stride); + } else { + threadblock_last_layer_forward<WIDTH, N_ITERS, OUT_T>(output_activation, act_shmem, weights + first_layer_size + weights_stride * n_hidden_matmuls, out + elem_idx, batch_size, output_layout, output_stride); + } + +} +template <int WIDTH, typename T, Activation ACTIVATION, bool INFERENCE> +void mlp_fused_forward( + cudaStream_t stream, + Activation output_activation, + T* weights, + T* input, + T* output_intermediate, + T* output, + const uint32_t n_hidden_layers, + int input_shape0, + int input_shape1, + int weights_shape0, + int weights_shape1, + int output_shape0, + int output_shape1 +) { + const uint32_t batch_size = input_shape0; + const uint32_t in_width = input_shape1; + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; // <- always going to be 8 as we only support multiple-of-16 widths + constexpr uint32_t INPUT_SKEW = 8; // <- likewise with inputs + constexpr uint32_t N_BLOCK_ROWS = WIDTH / 16; + // LOGir << batch_size << " " << in_width << " " << WIDTH << " " << weights_shape0 << " " <<weights_shape1; + // static_assert(WIDTH % 16 == 0, "Width must be a multiply of 16."); + // if (in_width % 16 != 0) { + // throw std::runtime_error{"Inputs must have a multiple-of-16 elements."}; + // } + // if (weights.rows() != WIDTH) { + // throw std::runtime_error{"The fully fused forward pass only works with WIDTH-sized matrices."}; + // } + // if (weights.cols() % 16 != 0) { + // throw std::runtime_error{std::string("weights must have a multiple-of-16 number of columns. ") + std::to_string(weights.cols())}; + // } + // if (output_intermediate.cols() != batch_size) { + // throw std::runtime_error{"Batch size of inputs and output_intermediate doesn't match."}; + // } + // if (output && output->cols() != batch_size) { + // throw std::runtime_error{"Batch size of inputs and outputs doesn't match."}; + // } + const int N_ITERS = WIDTH >= 256 ? 2 : 8; + if (batch_size % (16 * N_ITERS) != 0) { + throw std::runtime_error{"Batch size must be a multiple of " + std::to_string(16 * N_ITERS) + "."}; + } + const dim3 threads = { 32u, N_BLOCK_ROWS, 1 }; // 32 threads = 1 warp, N_BLOCK_ROWS warps per block for 16 rows, up to 2x 8 warps can share input (does not help vs. 
1) + uint32_t n_elems_per_block = 16 * N_ITERS; + uint32_t n_blocks = div_round_up(batch_size, n_elems_per_block); + size_t shmem_size = sizeof(__half) * (16 + 16 * N_ITERS) * (WIDTH + SKEW); // 16*WIDTH rows of weights (for the last layer; others are in registers only) + 16*WIDTH*N_ITERS rows of intermediate activations + shmem_size = std::max(shmem_size, sizeof(__half) * (WIDTH + 16) * (in_width + INPUT_SKEW)); + const dim3 blocks = { n_blocks, 1u, 1u }; + check_shmem_error(cudaFuncSetAttribute(kernel_mlp_fused<WIDTH, N_ITERS, __half, ACTIVATION, INFERENCE>, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shmem_size)); + kernel_mlp_fused<WIDTH, N_ITERS, __half, ACTIVATION, INFERENCE><<<blocks, threads, shmem_size, stream>>>( + output_activation, + input, + weights, + output_intermediate, + output ? output : nullptr, + batch_size, + in_width, + output ? output_shape1 : 0, + n_hidden_layers, + // The kernels operate with transposed layouts compared with the MLP code + nvcuda::wmma::mem_row_major, + nvcuda::wmma::mem_row_major, + output_shape1 + ); +} + + +void mlp_fused_forward_func( + int WIDTH, + Activation ACTIVATION, + bool INFERENCE, + cudaStream_t stream, + Activation output_activation, + __half* weights, + __half* input, + __half* output_intermediate, + __half* output, + const uint32_t n_hidden_layers, + int input_shape0, + int input_shape1, + int weights_shape0, + int weights_shape1, + int output_shape0, + int output_shape1 +) { + + #define FORWARD_ARGS \ + stream,\ + output_activation,\ + weights,\ + input,\ + output_intermediate,\ + output,\ + n_hidden_layers,\ + input_shape0,\ + input_shape1,\ + weights_shape0,\ + weights_shape1,\ + output_shape0,\ + output_shape1 + if (WIDTH == 64 && ACTIVATION==Activation::ReLU && INFERENCE==false) { + mlp_fused_forward<64, __half, Activation::ReLU, false>(FORWARD_ARGS); + } + else { + std::string msg = "mlp_fused_forward_func error not supported WIDTH=" + std::string("WIDTH") + " ACTIVATION=" + std::string("ACTIVATION") + " INFERENCE=" + std::string("INFERENCE") + ", please contact us to add this support."; + } +} + + +void mlp_fused_backward_func( + int WIDTH, + Activation ACTIVATION, + cudaStream_t stream, + __half* weights_first_layer, + __half* weights, + __half* dL_doutput, + __half* temps, + __half* forward, + __half* dL_dinput, + const uint32_t n_hidden_matmuls, + int grad_shape0, + int grad_shape1, + int need_last +) { + #define BACKWARD_ARGS \ + stream, \ + weights_first_layer, \ + weights, \ + dL_doutput, \ + temps, \ + forward, \ + dL_dinput, \ + n_hidden_matmuls, \ + grad_shape0, \ + grad_shape1, \ + need_last + if (WIDTH == 64 && ACTIVATION==Activation::ReLU) { + mlp_fused_backward<64, __half, Activation::ReLU>(BACKWARD_ARGS); + } + else { + std::string msg = "mlp_fused_forward_func error not supported WIDTH=" + std::string("WIDTH") + " ACTIVATION=" + std::string("ACTIVATION") + " INFERENCE=" + std::string("INFERENCE") + ", please contact us to add this support."; + } +} \ No newline at end of file diff --git a/python/jnerf/ops/code_ops/fully_fused_mlp_header.h b/python/jnerf/ops/code_ops/fully_fused_mlp_header.h new file mode 100644 index 00000000..ca5bfd7c --- /dev/null +++ b/python/jnerf/ops/code_ops/fully_fused_mlp_header.h @@ -0,0 +1,637 @@ +#pragma once +#undef out +#include <map> +#include <type_traits> +#include <stdint.h> +#include <algorithm> +#include <stdexcept> +#include <mma.h> +#define RM MatrixLayout::kRowMajor +#define CM MatrixLayout::kColumnMajor +typedef enum MatrixLayout { +kColumnMajor, +kRowMajor +} 
MatrixLayout; + +typedef enum Activation { + ReLU, + Exponential, + Sine, + Sigmoid, + Squareplus, + Softplus, + None, +} Activation; + +void mlp_fused_backward_func( + int WIDTH, + Activation ACTIVATION, + cudaStream_t stream, + __half* weights_first_layer, + __half* weights, + __half* dL_doutput, + __half* temps, + __half* forward, + __half* dL_dinput, + const uint32_t n_hidden_matmuls, + int grad_shape0, + int grad_shape1, + int need_last +); + + +void mlp_fused_forward_func( + int WIDTH, + Activation ACTIVATION, + bool INFERENCE, + cudaStream_t stream, + Activation output_activation, + __half* weights, + __half* input, + __half* output_intermediate, + __half* output, + const uint32_t n_hidden_layers, + int input_shape0, + int input_shape1, + int weights_shape0, + int weights_shape1, + int output_shape0, + int output_shape1 +); + +#define RM MatrixLayout::kRowMajor +#define CM MatrixLayout::kColumnMajor + +template <typename T> +__device__ __host__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} +// warp activation defined here +template <typename T, typename fragment_t> +__host__ __device__ void warp_activation(Activation activation, const fragment_t& frag, fragment_t& result) { + switch (activation) { + case Activation::ReLU: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)((T)frag.x[t] > (T)0.0f); + } + return; + case Activation::None: result = frag; return; + default: + // Unsupported activation + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + } +} +template <typename T, typename fragment_t> +__host__ __device__ fragment_t warp_activation(Activation activation, const fragment_t& frag) { + fragment_t result; + warp_activation<T>(activation, frag, result); + return result; +} +template <typename T, typename fragment_t, typename forward_fragment_t> +__host__ __device__ void warp_activation_backward(Activation activation, const fragment_t& frag, const forward_fragment_t& forward_frag, fragment_t& result) { + switch (activation) { + case Activation::ReLU: + #pragma unroll + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(forward_frag.x[t] > (T)0.0f); + } + return; + case Activation::None: result = frag; return; + default: + // Unsupported activation + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + } +} +template <typename T, typename fragment_t, typename forward_fragment_t> +__host__ __device__ fragment_t warp_activation_backward(Activation activation, const fragment_t& frag, const forward_fragment_t& forward_frag) { + fragment_t result; + warp_activation_backward<T>(activation, frag, forward_frag, result); + return result; +} +void check_shmem_error(cudaError_t error) { + if (error != cudaSuccess) { + throw std::runtime_error{"FullyFusedMLP: insufficient shared memory available on the GPU. Reduce `n_neurons` or use `CutlassMLP` (better compatibility but slower) instead."}; + } +} +template <int WIDTH, int N_ITERS, typename OUT_T, bool BACKWARD=false> +__device__ void threadblock_layer(Activation activation, __half* __restrict__ act_shmem, const __half* __restrict__ weights_this_layer, OUT_T* __restrict__ out_intermediate_threadblock_this_layer, const OUT_T* __restrict__ activation_aux = nullptr) { + // act_shmem contains the intermediate activations (shared memory) of the thread block's chunk of the batch. + // Can be forward activations or backward activations, depending on caller. 
+ // weights_this_layer points to the weight matrix of the current layer. + // out_intermediate_threadblock_this_layer points to the location where intermediate activations produced by the thread block should be written to. + // Can be nullptr if nothing should be written. + // activation_aux points to additional arguments that the activation function may depend on. Points to the hidden forward activations when computing backward activations. + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + using namespace nvcuda; + // If we're performing the backward pass, weights must be loaded in transposed form, which + // is achieved by interpreting the memory in row_major instead of col_major order. + using weights_layout_t = std::conditional_t<BACKWARD, wmma::row_major, wmma::col_major>; + // Fragments + wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> act_frag; + wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, weights_layout_t> weights_frag[N_BLOCKS]; + wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag[N_ITERS]; + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + const uint32_t weights_col = 16 * wi; + __syncthreads(); + // Load N_BLOCKS chunks of weights from global memory into registers. + #pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) { + if (BACKWARD) { + // If we're performing the backward pass, additional index swizzling is needed to + // load the weights in transposed form. + wmma::load_matrix_sync(weights_frag[i], weights_this_layer + 16 * i * WIDTH + weights_col, WIDTH); + } else { + wmma::load_matrix_sync(weights_frag[i], weights_this_layer + 16 * i + weights_col * WIDTH, WIDTH); + } + } + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::fill_fragment(result_frag[l], 0.0f); + #pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) { + // Load a chunk of intermediate activations from shared memory and multiply with chunk of weights + wmma::load_matrix_sync(act_frag, act_shmem + 16 * i + (16 * l) * (WIDTH + SKEW), WIDTH + SKEW); + wmma::mma_sync(result_frag[l], act_frag, weights_frag[i], result_frag[l]); + } + // Activation + if (BACKWARD) { + // Load the temporary forward matrix for the relu transfer + wmma::load_matrix_sync(act_frag, activation_aux + weights_col + l * 16 * WIDTH, WIDTH); + warp_activation_backward<__half>(activation, result_frag[l], act_frag, result_frag[l]); + } else { + warp_activation<__half>(activation, result_frag[l], result_frag[l]); + } + } + __syncthreads(); + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::store_matrix_sync(act_shmem + weights_col + l * 16 * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major); + } + if (out_intermediate_threadblock_this_layer != nullptr) { + __syncthreads(); + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + *(int4*)&out_intermediate_threadblock_this_layer[lane_offset + (row + 16 * l) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * l) * (WIDTH + SKEW)]; + } + } +} +template <int WIDTH, int N_ITERS> +__device__ void threadblock_load_input_static(__half* __restrict__ act_shmem, const __half* __restrict__ input_threadblock) { + // act_shmem will be filled by the thread block's chunk of input_threadblock + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 
8 : 0; + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + #pragma unroll + for (int i = 0; i < N_ITERS; ++i) { + *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)] = *(int4*)&input_threadblock[lane_offset + (row + 16 * i) * WIDTH]; + } +} +template <int WIDTH, int N_ITERS, Activation ACTIVATION, typename OUTPUT_LAYOUT> +__global__ void kernel_mlp_fused_backward(const __half* __restrict__ dL_doutput, const __half* __restrict__ weights, __half* __restrict__ out_intermediate, const __half* __restrict__ forward, __half* __restrict__ dL_dinput, const __half* __restrict__ weights_first_layer, const uint32_t batch_size, const uint32_t out_width, const uint32_t n_hidden_matmuls, int need_last) { + // `dL_doutput` points to the input matrix of the backward pass, i.e. the loss gradients. Assumed to be 16 neurons wide. + // `weights` points to the weight matrices (contiguous in memory). + // `out_intermediate` points to the memory where backpropagated activation gradients should be written. + // `forward` points to the memory where the intermediate activations of the forward pass are located. (needed for activation backprop) + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + const uint32_t bi = blockIdx.x; // block index + // Shared memory contains the intermediate activations of blockDim.y*16 elements. + // A skew is applied to the matrix storage to avoid bank conflicts. + extern __shared__ __half shmem[]; + __half* act_shmem = shmem; + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + // Multipying one 16-row chunk of intermediate activations with the weight matrix requires all warps of the block. + // Thus, each block computes exactly one 16-row chunk of the next layer's intermediate activations. 
+ const uint32_t elem_idx_base = 16 * bi * N_ITERS; + const uint32_t elem_idx = elem_idx_base; + const uint32_t layer_stride = WIDTH * WIDTH; + const uint32_t output_stride = WIDTH * batch_size; + // Backprop through last layer + if (out_width <= 16) { + using namespace nvcuda; + // Fragments in registers + wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, OUTPUT_LAYOUT> act_frag; + wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> weights_frag; + wmma::fragment<wmma::accumulator, 16, 16, 16, __half> result_frag[N_ITERS]; + // Load the relevant chunk of the last layer's weight matrix from global memory into registers + const uint32_t weights_col = 16 * wi; + wmma::load_matrix_sync(weights_frag, weights + layer_stride * n_hidden_matmuls + weights_col, WIDTH); + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::fill_fragment(result_frag[l], 0.0f); + // Load a chunk of output gradients from shared memory and multiply with previously loaded weights + if (std::is_same<OUTPUT_LAYOUT, wmma::row_major>::value) { + wmma::load_matrix_sync(act_frag, dL_doutput + (elem_idx + 16 * l) * 16, 16); + } else { + wmma::load_matrix_sync(act_frag, dL_doutput + (elem_idx + 16 * l), batch_size); + } + // NOTE: activation transfer of the _output_ activation is expected to be done _prior_ to calling this kernel + // in a separate pass, because the tranfered activation gradient is also needed to compute the weight + // gradient of the last weight matrix (see backward()). + wmma::mma_sync(result_frag[l], act_frag, weights_frag, result_frag[l]); + // Load the temporary forward matrix for the relu transfer + wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> forward_frag; + wmma::load_matrix_sync(forward_frag, forward + output_stride * n_hidden_matmuls + weights_col + (elem_idx + l * 16) * WIDTH, WIDTH); + warp_activation_backward<__half>(ACTIVATION, result_frag[l], forward_frag, result_frag[l]); + } + __syncthreads(); + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::store_matrix_sync(act_shmem + weights_col + (16 * l) * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major); + } + __syncthreads(); + #pragma unroll + for (int i = 0; i < N_ITERS; ++i) { + *(int4*)&out_intermediate[lane_offset + (row + elem_idx + i * 16) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)]; + } + } + // else { + // // If the output width is larger than 16, we will have used CUTLASS for backpropping through the last layer. + // // Load the resulting gradients. + // threadblock_load_input_static<WIDTH, N_ITERS>(act_shmem, out_intermediate + elem_idx * WIDTH); + // } + // Backprop through hidden layers + for (uint32_t k = 0; k < n_hidden_matmuls; ++k) { + threadblock_layer<WIDTH, N_ITERS, __half, true>(ACTIVATION, act_shmem, weights + layer_stride * (n_hidden_matmuls - k - 1), out_intermediate + output_stride * (k + 1) + elem_idx_base * WIDTH, forward + output_stride * (n_hidden_matmuls - k - 1) + elem_idx_base * WIDTH); + } + // Compute loss gradients w.r.t. input if desired. + // THIS CODE ASSUMES THAT THE INPUT WIDTH IS THE SAME AS THE NETWORK WIDTH + // AND THAT THE INPUT LAYOUT IS THE SAME AS THE HIDDEN LAYOUT. + // DON'T PASS A NON-NULL dL_dinput IF THIS REQUIREMENT IS NOT MET. 
+ if (dL_dinput != nullptr && need_last) { + threadblock_layer<WIDTH, N_ITERS, __half, true>(Activation::None, act_shmem, weights_first_layer, dL_dinput + elem_idx_base * WIDTH); + } +} +template <int WIDTH, typename T, Activation ACTIVATION> +void mlp_fused_backward( + cudaStream_t stream, + T* weights_first_layer, + T* weights, + T* dL_doutput, + T* temps, + T* forward, + T* dL_dinput, + const uint32_t n_hidden_matmuls, + int grad_shape0, + int grad_shape1, + int need_last +) { + const uint32_t batch_size = grad_shape0; + const uint32_t out_width = grad_shape1; + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + // if (forward.cols() != batch_size) { + // throw std::runtime_error{"Batch size of matrices dL_doutput and temporaries doesn't match."}; + // } + const int N_ITERS = WIDTH >= 256 ? 2 : 8; + // if (batch_size % (16 * N_ITERS) != 0) { + // throw std::runtime_error{"Batch size must be a multiple of " + std::to_string(16 * N_ITERS) + "."}; + // } + const dim3 threads = { 32u, N_BLOCKS, 1 }; // 32 threads = 1 warp, 8 warps per block for 16 rows, up to 2x 8 warps can share input (does not help vs. 1) + uint32_t n_elems_per_block = 16 * N_ITERS; + uint32_t n_blocks = div_round_up(batch_size, n_elems_per_block); + int shmem_size = sizeof(__half) * ((16 * N_ITERS) * (WIDTH + SKEW)); // WIDTH rows of input and 16 * threads.z rows of weights + const dim3 blocks = { n_blocks, 1u, 1u }; + // The kernels operate with transposed layouts compared with the MLP code + // if (dL_doutput.layout() == RM) { + // check_shmem_error(cudaFuncSetAttribute(kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::col_major>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + // kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::col_major><<<blocks, threads, shmem_size, stream>>>(dL_doutput.data(), weights.data(), temporaries.data(), forward.data(), dL_dinput ? dL_dinput->data() : nullptr, weights_first_layer.data(), dL_doutput.stride(), batch_size, out_width, n_hidden_matmuls); + // } else { + // check_shmem_error(cudaFuncSetAttribute(kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::row_major>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + // kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::row_major><<<blocks, threads, shmem_size, stream>>>(dL_doutput.data(), weights.data(), temporaries.data(), forward.data(), dL_dinput ? 
dL_dinput->data() : nullptr, weights_first_layer.data(), dL_doutput.stride(), batch_size, out_width, n_hidden_matmuls); + // } + check_shmem_error(cudaFuncSetAttribute(kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::col_major>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + kernel_mlp_fused_backward<WIDTH, N_ITERS, ACTIVATION, nvcuda::wmma::col_major><<<blocks, threads, shmem_size, stream>>>(dL_doutput, weights, temps, forward, dL_dinput, weights_first_layer, batch_size, out_width, n_hidden_matmuls, need_last); +} +template <int WIDTH, int N_ITERS, typename OUT_T, typename INPUT_LAYOUT> +__device__ void threadblock_input_layer_forward_dynamic(Activation activation, __half* __restrict__ act_shmem, const __half* __restrict__ input_threadblock, const __half* __restrict__ weights_this_layer, OUT_T* __restrict__ out_intermediate_threadblock_this_layer, const uint32_t in_width, const uint32_t batch_size) { + // act_shmem contains the intermediate activations (shared memory) of the thread block's chunk of the batch + // input_threadblock points to the thread block's chunk of the input batch in global memory + // weights_this_layer points to the weight matrix of the current layer + // out_intermediate_threadblock_this_layer points to the location where intermediate activations produced by the thread block should be written to. + // Can be nullptr if nothing should be written. + // in_width is the dynamic width of the input layer + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + constexpr uint32_t INPUT_SKEW = 8; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + using namespace nvcuda; + // Fragments + wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, INPUT_LAYOUT> act_frag; + wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> weights_frag; + wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag[N_ITERS]; + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + const uint32_t weights_col = 16 * wi; + __half* __restrict__ weights_shmem = act_shmem + 16 * (in_width + INPUT_SKEW); + // Load input weight matrix (fits completely into shared memory) + // Each thread can load 8 fp16 elements (16 bytes) at once; we have N_BLOCKS warps + const uint32_t n_elems_per_load = N_BLOCKS * 32 * 8; + const uint32_t thread_elem_idx = (li + wi * 32) * 8; + const uint32_t n_elems_b = WIDTH * in_width; + #pragma unroll + for (uint32_t idx = thread_elem_idx; idx < n_elems_b; idx += n_elems_per_load) { + const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW; + *(int4*)&weights_shmem[idx_skewed] = *(int4*)&weights_this_layer[idx]; + } + const uint32_t n_tensor_ops = in_width / 16; + if (std::is_same<INPUT_LAYOUT, wmma::col_major>::value) { + __syncthreads(); + } + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + if (std::is_same<INPUT_LAYOUT, wmma::row_major>::value) { + // Load chunk of inputs into shmem. + // This is faster than loading it from gmem directly, even though it is only used once. + // (Possibly due to latency hiding through staging.) 
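The idx + idx / in_width * INPUT_SKEW expression used for the shared-memory copies above (and again for the input chunk just below) inserts INPUT_SKEW padding halves after every in_width-element row. A small host-side sketch of the resulting layout, illustrative only (in_width = 32 chosen as an example):

#include <cstdio>

int main() {
    const unsigned in_width = 32, INPUT_SKEW = 8;      // example input width; padding as above
    for (unsigned row = 0; row < 4; ++row) {
        const unsigned idx = row * in_width;           // first element of each logical row
        const unsigned idx_skewed = idx + idx / in_width * INPUT_SKEW;
        printf("row %u starts at shared-memory element %u\n", row, idx_skewed);
    }
    return 0;
}

Rows therefore start in_width + INPUT_SKEW elements apart (0, 40, 80, ...), which is exactly the pitch the load_matrix_sync calls pass as their stride.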
+ const uint32_t n_elems_a = 16 * in_width; + #pragma unroll + for (uint32_t idx = thread_elem_idx; idx < n_elems_a; idx += n_elems_per_load) { + const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW; + *(int4*)&act_shmem[idx_skewed] = *(int4*)&input_threadblock[l * n_elems_a + idx]; + } + __syncthreads(); + } + wmma::fill_fragment(result_frag[l], 0.0f); + #pragma unroll + for (uint32_t i = 0; i < n_tensor_ops; ++i) { + // Load chunk of inputs and weights from shared memory and multiply them + if (std::is_same<INPUT_LAYOUT, wmma::row_major>::value) { + wmma::load_matrix_sync(act_frag, act_shmem + 16 * i, in_width + INPUT_SKEW); + } else { + wmma::load_matrix_sync(act_frag, input_threadblock + 16 * i * batch_size + 16 * l, batch_size); + } + wmma::load_matrix_sync(weights_frag, weights_shmem + 16 * i + weights_col * (in_width + INPUT_SKEW), in_width + INPUT_SKEW); + wmma::mma_sync(result_frag[l], act_frag, weights_frag, result_frag[l]); + } + if (std::is_same<INPUT_LAYOUT, wmma::row_major>::value) { + __syncthreads(); + } + warp_activation<__half>(activation, result_frag[l], result_frag[l]); + } + if (std::is_same<INPUT_LAYOUT, wmma::col_major>::value) { + __syncthreads(); + } + #pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::store_matrix_sync(act_shmem + weights_col + (16 * l) * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major); + } + if (out_intermediate_threadblock_this_layer != nullptr) { + __syncthreads(); + #pragma unroll + for (int i = 0; i < N_ITERS; ++i) { + *(int4*)&out_intermediate_threadblock_this_layer[lane_offset + (row + 16 * i) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)]; + } + } +} + +template <int WIDTH, int N_ITERS, typename OUT_T> +__device__ void threadblock_last_layer_forward(Activation activation, __half* __restrict__ act_shmem, const __half* __restrict__ weights_this_layer, OUT_T* __restrict__ out, const uint32_t batch_size, const nvcuda::wmma::layout_t output_layout, const uint32_t output_stride) { + // act_shmem contains the intermediate activations (shared memory) of the thread block's chunk of the batch + // weights_this_layer points to the weight matrix of the current layer + // out points to the location where the result produced by the thread block should be written to. + // Can be nullptr if nothing should be written. + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + using namespace nvcuda; + // Fragments + wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> act_frag; + wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> weights_frag[N_BLOCKS]; + wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag; + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + __half* __restrict__ weights_shmem = act_shmem + N_ITERS * 16 * (WIDTH + SKEW); + const uint32_t weights_row = (8 * li) % WIDTH; + const uint32_t weights_col = (8 * li + 8 * 32 * wi) / WIDTH; + // Load weight matrix into shared memory for the last multiplication. + // Loading into shared memory as opposed to directly into registers is faster + // because unlike in the previous layers, each warp uses the same entries of the weight matrix. 
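The single vectorized copy performed just below works because the block has exactly enough threads to move the whole last-layer weight block (WIDTH * 16 halves) in one shot. A quick host-side consistency check, illustrative only (WIDTH = 64, as in the dispatchers below):

#include <cassert>
#include <cstdio>

int main() {
    const unsigned WIDTH = 64;
    const unsigned N_BLOCKS = WIDTH / 16;              // warps per block, as in the kernel
    const unsigned halves_per_thread = 8;              // one int4 copy moves 8 __half values
    const unsigned halves_copied = N_BLOCKS * 32 * halves_per_thread;
    assert(halves_copied == WIDTH * 16);               // exactly the 16-output weight block
    printf("%u threads move %u halves = one WIDTH x 16 weight matrix\n",
           N_BLOCKS * 32, halves_copied);
    return 0;
}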
+ *(int4*)&weights_shmem[weights_row + weights_col * (WIDTH + SKEW)] = *(int4*)&weights_this_layer[weights_row + weights_col * WIDTH]; + __syncthreads(); + #pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) + wmma::load_matrix_sync(weights_frag[i], weights_shmem + 16 * i, WIDTH + SKEW); + // Perform last layer by parallelizing over iters + for (uint32_t idx = wi; idx < N_ITERS; idx += N_BLOCKS) { + wmma::fill_fragment(result_frag, 0.0f); + #pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) { + // Load a chunk of intermediate activations from shared memory and multiply with chunk of the weight matrix + wmma::load_matrix_sync(act_frag, act_shmem + 16 * i + (16 * idx) * (WIDTH + SKEW), WIDTH + SKEW); + wmma::mma_sync(result_frag, act_frag, weights_frag[i], result_frag); + } + + warp_activation<__half>(activation, result_frag, result_frag); + if (output_layout == wmma::mem_row_major) { + wmma::store_matrix_sync(out + idx * 16 * output_stride, result_frag, output_stride, output_layout); + } else { + wmma::store_matrix_sync(out + idx * 16, result_frag, batch_size, output_layout); + } + } +} +template <int WIDTH, int N_ITERS, typename OUT_T, Activation ACTIVATION, bool INFERENCE> +__global__ void kernel_mlp_fused(const Activation output_activation, const __half* __restrict__ input, const __half* __restrict__ weights, OUT_T* __restrict__ out_intermediate, OUT_T* __restrict__ out, const uint32_t batch_size, const uint32_t in_width, const uint32_t out_width, const uint32_t n_hidden_matmuls, const nvcuda::wmma::layout_t input_layout, const nvcuda::wmma::layout_t output_layout, const int output_stride) { + // `input` points to the input matrix. Can be any width. + // `weights` points to the weight matrices (contiguous in memory). + // `out_intermediate` points to the memory where intermediate activations should be written. When performing inference, a value of nullptr is expected (intermediate results are not written). + // `out` points to the memory where the network output should be written. (Output width is assumed to be 16 neurons.) + // Commented out due to isolated strange side-effects on Windows + // if (INFERENCE) { + // assert(out_intermediate == nullptr); + // } else { + // assert(out_intermediate); + // } + // Shared memory contains the intermediate activations of blockDim.y*16 elements. + // In some cases, it also contains the weight matrix for the first and last layer. + extern __shared__ __half shmem[]; + __half* act_shmem = shmem; + // Each block computes exactly one 16-element chunk of the batch. + const uint32_t elem_idx = 16 * blockIdx.x * N_ITERS; + // First layer + if (input_layout == nvcuda::wmma::mem_col_major || in_width != WIDTH) { + if (input_layout == nvcuda::wmma::mem_row_major) { + threadblock_input_layer_forward_dynamic<WIDTH, N_ITERS, OUT_T, nvcuda::wmma::row_major>(ACTIVATION, act_shmem, input + elem_idx * in_width, weights, !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr, in_width, batch_size); + } else { + threadblock_input_layer_forward_dynamic<WIDTH, N_ITERS, OUT_T, nvcuda::wmma::col_major>(ACTIVATION, act_shmem, input + elem_idx, weights, !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr, in_width, batch_size); + } + } else { + // If the input has the same width & layout as the hidden layers, we can simply use the network's regular layer routine (with static size) + // instead of using the slower dynamic input layer routine. 
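The hidden-layer loop that follows walks one contiguous weight buffer: an input-layer block of WIDTH * in_width values, then n_hidden_matmuls blocks of WIDTH * WIDTH, then the last-layer block feeding the (at most 16-wide) output. A sketch of the offsets it assumes, with illustrative sizes (WIDTH = 64, in_width = 32, n_hidden_matmuls = 2):

#include <cstdio>

int main() {
    const unsigned WIDTH = 64, in_width = 32, n_hidden_matmuls = 2;   // example sizes
    const unsigned first_layer_size = WIDTH * in_width;               // as in kernel_mlp_fused
    const unsigned weights_stride   = WIDTH * WIDTH;
    printf("input layer weights start at element 0 (%u values)\n", first_layer_size);
    for (unsigned k = 0; k < n_hidden_matmuls; ++k)
        printf("hidden matmul %u weights start at element %u\n", k, first_layer_size + weights_stride * k);
    printf("last layer weights start at element %u\n", first_layer_size + weights_stride * n_hidden_matmuls);
    return 0;
}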
+ threadblock_load_input_static<WIDTH, N_ITERS>(act_shmem, input + elem_idx * WIDTH); + threadblock_layer<WIDTH, N_ITERS, OUT_T>(ACTIVATION, act_shmem, weights, !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr); + } + const uint32_t first_layer_size = WIDTH * in_width; + const uint32_t weights_stride = WIDTH * WIDTH; + const uint32_t layer_stride = WIDTH * batch_size; + // Hidden layers + for (uint32_t k = 0; k < n_hidden_matmuls; ++k) { + threadblock_layer<WIDTH, N_ITERS, OUT_T>(ACTIVATION, act_shmem, weights + first_layer_size + weights_stride * k, !INFERENCE ? (out_intermediate + layer_stride * (k + 1) + elem_idx * WIDTH) : nullptr); + } + // Last layer + if (output_layout == nvcuda::wmma::mem_row_major) { + threadblock_last_layer_forward<WIDTH, N_ITERS, OUT_T>(output_activation, act_shmem, weights + first_layer_size + weights_stride * n_hidden_matmuls, out + elem_idx * output_stride, output_stride, output_layout, output_stride); + } else { + threadblock_last_layer_forward<WIDTH, N_ITERS, OUT_T>(output_activation, act_shmem, weights + first_layer_size + weights_stride * n_hidden_matmuls, out + elem_idx, batch_size, output_layout, output_stride); + } + +} +template <int WIDTH, typename T, Activation ACTIVATION, bool INFERENCE> +void mlp_fused_forward( + cudaStream_t stream, + Activation output_activation, + T* weights, + T* input, + T* output_intermediate, + T* output, + const uint32_t n_hidden_layers, + int input_shape0, + int input_shape1, + int weights_shape0, + int weights_shape1, + int output_shape0, + int output_shape1 +) { + const uint32_t batch_size = input_shape0; + const uint32_t in_width = input_shape1; + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; // <- always going to be 8 as we only support multiple-of-16 widths + constexpr uint32_t INPUT_SKEW = 8; // <- likewise with inputs + constexpr uint32_t N_BLOCK_ROWS = WIDTH / 16; + // LOGir << batch_size << " " << in_width << " " << WIDTH << " " << weights_shape0 << " " <<weights_shape1; + // static_assert(WIDTH % 16 == 0, "Width must be a multiply of 16."); + // if (in_width % 16 != 0) { + // throw std::runtime_error{"Inputs must have a multiple-of-16 elements."}; + // } + // if (weights.rows() != WIDTH) { + // throw std::runtime_error{"The fully fused forward pass only works with WIDTH-sized matrices."}; + // } + // if (weights.cols() % 16 != 0) { + // throw std::runtime_error{std::string("weights must have a multiple-of-16 number of columns. ") + std::to_string(weights.cols())}; + // } + // if (output_intermediate.cols() != batch_size) { + // throw std::runtime_error{"Batch size of inputs and output_intermediate doesn't match."}; + // } + // if (output && output->cols() != batch_size) { + // throw std::runtime_error{"Batch size of inputs and outputs doesn't match."}; + // } + const int N_ITERS = WIDTH >= 256 ? 2 : 8; + if (batch_size % (16 * N_ITERS) != 0) { + throw std::runtime_error{"Batch size must be a multiple of " + std::to_string(16 * N_ITERS) + "."}; + } + const dim3 threads = { 32u, N_BLOCK_ROWS, 1 }; // 32 threads = 1 warp, N_BLOCK_ROWS warps per block for 16 rows, up to 2x 8 warps can share input (does not help vs. 
1) + uint32_t n_elems_per_block = 16 * N_ITERS; + uint32_t n_blocks = div_round_up(batch_size, n_elems_per_block); + size_t shmem_size = sizeof(__half) * (16 + 16 * N_ITERS) * (WIDTH + SKEW); // 16*WIDTH rows of weights (for the last layer; others are in registers only) + 16*WIDTH*N_ITERS rows of intermediate activations + shmem_size = std::max(shmem_size, sizeof(__half) * (WIDTH + 16) * (in_width + INPUT_SKEW)); + const dim3 blocks = { n_blocks, 1u, 1u }; + check_shmem_error(cudaFuncSetAttribute(kernel_mlp_fused<WIDTH, N_ITERS, __half, ACTIVATION, INFERENCE>, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shmem_size)); + kernel_mlp_fused<WIDTH, N_ITERS, __half, ACTIVATION, INFERENCE><<<blocks, threads, shmem_size, stream>>>( + output_activation, + input, + weights, + output_intermediate, + output ? output : nullptr, + batch_size, + in_width, + output ? output_shape1 : 0, + n_hidden_layers, + // The kernels operate with transposed layouts compared with the MLP code + nvcuda::wmma::mem_row_major, + nvcuda::wmma::mem_row_major, + output_shape1 + ); +} + + +void mlp_fused_forward_func( + int WIDTH, + Activation ACTIVATION, + bool INFERENCE, + cudaStream_t stream, + Activation output_activation, + __half* weights, + __half* input, + __half* output_intermediate, + __half* output, + const uint32_t n_hidden_layers, + int input_shape0, + int input_shape1, + int weights_shape0, + int weights_shape1, + int output_shape0, + int output_shape1 +) { + + #define FORWARD_ARGS \ + stream,\ + output_activation,\ + weights,\ + input,\ + output_intermediate,\ + output,\ + n_hidden_layers,\ + input_shape0,\ + input_shape1,\ + weights_shape0,\ + weights_shape1,\ + output_shape0,\ + output_shape1 + if (WIDTH == 64 && ACTIVATION==Activation::ReLU && INFERENCE==false) { + mlp_fused_forward<64, __half, Activation::ReLU, false>(FORWARD_ARGS); + } + else { + throw std::runtime_error{"mlp_fused_forward_func: unsupported WIDTH=" + std::to_string(WIDTH) + " ACTIVATION=" + std::to_string((int)ACTIVATION) + " INFERENCE=" + std::to_string(INFERENCE) + ", please contact us to add this support."}; + } +} + + +void mlp_fused_backward_func( + int WIDTH, + Activation ACTIVATION, + cudaStream_t stream, + __half* weights_first_layer, + __half* weights, + __half* dL_doutput, + __half* temps, + __half* forward, + __half* dL_dinput, + const uint32_t n_hidden_matmuls, + int grad_shape0, + int grad_shape1, + int need_last +) { + #define BACKWARD_ARGS \ + stream, \ + weights_first_layer, \ + weights, \ + dL_doutput, \ + temps, \ + forward, \ + dL_dinput, \ + n_hidden_matmuls, \ + grad_shape0, \ + grad_shape1, \ + need_last + if (WIDTH == 64 && ACTIVATION==Activation::ReLU) { + mlp_fused_backward<64, __half, Activation::ReLU>(BACKWARD_ARGS); + } + else { + throw std::runtime_error{"mlp_fused_backward_func: unsupported WIDTH=" + std::to_string(WIDTH) + " ACTIVATION=" + std::to_string((int)ACTIVATION) + ", please contact us to add this support."}; + } +} \ No newline at end of file diff --git a/python/jnerf/ops/code_ops/global_vars.py b/python/jnerf/ops/code_ops/global_vars.py index 9e1accb4..3b835247 100644 --- a/python/jnerf/ops/code_ops/global_vars.py +++ b/python/jnerf/ops/code_ops/global_vars.py @@ -1,27 +1,53 @@ import os +import sys import jittor as jt jt.flags.use_cuda = 1 -global_headers = """ +export_gloabl = '' +import_global = 'extern' +if sys.platform == "win32": + export_gloabl = '__declspec(dllexport)' + import_global = '__declspec(dllimport) extern' + +global_headers = f""" #include "pcg32.h"
-namespace jittor { -extern int global_var1; -extern pcg32 rng; -} +namespace jittor {{ +EXTERN_LIB int global_var1; +EXTERN_LIB pcg32 rng; +}} """ -global_src = """ -namespace jittor { -int global_var1 = 123; -pcg32 rng{1337}; -} +global_decl_headers = f""" +#include "pcg32.h" +namespace jittor {{ +{export_gloabl} int global_var1; +{export_gloabl} pcg32 rng; +}} +""" + +global_src = f""" +#include "pcg32.h" +namespace jittor {{ +{export_gloabl} int global_var1 = 123; +{export_gloabl} pcg32 rng{{1337}}; +}} """ proj_path = os.path.join(os.path.dirname(__file__), '..', 'op_include') -proj_options = { f"FLAGS: -I{proj_path}/eigen -I{proj_path}/include -I{proj_path}/pcg32 -I{proj_path}/../op_header -DGLOBAL_VAR --extended-lambda --expt-relaxed-constexpr": 1 } +if sys.platform == "linux": + proj_options = { f"FLAGS: -I{proj_path}/eigen -I{proj_path}/include -I{proj_path}/pcg32 -I{proj_path}/../op_header -DGLOBAL_VAR --extended-lambda --expt-relaxed-constexpr": 1 } +else: + proj_options = { f'FLAGS: -I"{proj_path}/eigen" -I"{proj_path}/include" -I"{proj_path}/pcg32" -I"{proj_path}/../op_header" -DGLOBAL_VAR --extended-lambda --expt-relaxed-constexpr': 1 } + +jt.profiler.start() gv = jt.code([1], int, - cuda_header=global_headers+global_src, + cuda_header=global_src, cuda_src=""" """) gv.compile_options = proj_options gv.sync() +jt.profiler.stop() + +if os.name == "nt": + dll_name = jt.profiler.report()[-1][-10].replace(".cc", "") + proj_options[f'FLAGS: -l{dll_name} '] = 1 \ No newline at end of file diff --git a/python/jnerf/ops/code_ops/op_header/calc_rgb.o b/python/jnerf/ops/code_ops/op_header/calc_rgb.o deleted file mode 100644 index f268a288..00000000 Binary files a/python/jnerf/ops/code_ops/op_header/calc_rgb.o and /dev/null differ diff --git a/python/jnerf/ops/code_ops/op_header/fully_fused_mlp_function.o b/python/jnerf/ops/code_ops/op_header/fully_fused_mlp_function.o deleted file mode 100644 index e02f6ff4..00000000 Binary files a/python/jnerf/ops/code_ops/op_header/fully_fused_mlp_function.o and /dev/null differ diff --git a/requirements.txt b/requirements.txt index b76a21fa..3fe959e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -jittor>=1.3.4.13 numpy tqdm opencv-python @@ -9,3 +8,5 @@ PyMCubes trimesh plyfile open3d +pyhocon +icecream \ No newline at end of file
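For completeness, a minimal host-side sketch of driving the mlp_fused_forward_func dispatcher added in fully_fused_mlp_function above. Illustrative only: the buffer sizes, the WIDTH = 64 / ReLU configuration, and the batch size being a multiple of 128 follow the constraints in the code above, while weight initialization and error handling are omitted.

#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Assumes the declarations from the fused-MLP source above (Activation, mlp_fused_forward_func).
int run_fused_forward_example() {
    const int WIDTH = 64, in_width = 64, out_width = 16;
    const int batch = 256;                 // must be a multiple of 16 * N_ITERS = 128 for WIDTH 64
    const int n_hidden_matmuls = 2;

    __half *weights, *input, *inter, *output;
    cudaMalloc((void**)&weights, sizeof(__half) * (WIDTH * in_width + n_hidden_matmuls * WIDTH * WIDTH + WIDTH * 16));
    cudaMalloc((void**)&input,   sizeof(__half) * batch * in_width);
    cudaMalloc((void**)&inter,   sizeof(__half) * batch * WIDTH * (n_hidden_matmuls + 1));
    cudaMalloc((void**)&output,  sizeof(__half) * batch * out_width);

    mlp_fused_forward_func(WIDTH, Activation::ReLU, /*INFERENCE=*/false, /*stream=*/0,
                           /*output_activation=*/Activation::None,
                           weights, input, inter, output,
                           n_hidden_matmuls,
                           batch, in_width,          // input shape
                           WIDTH, in_width,          // weights shape (not used by the launch itself)
                           batch, out_width);        // output shape
    return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}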