In other programs, the data used is 3D image data. If I want to use the `class MambaLayer(nn.Module):` part of the code, is it necessary to divide the original data into patches first?
#52
In other programs, the data used is 3D image data. If I want to use the `MambaLayer` class below, is it necessary to divide the original data into patches first?
```python
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from mamba_ssm import Mamba


class MambaLayer(nn.Module):
    def __init__(self, dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.mamba = Mamba(
            d_model=dim,  # Model dimension d_model
            d_state=d_state,  # SSM state expansion factor
            d_conv=d_conv,  # Local convolution width
            expand=expand,  # Block expansion factor
        )

    @autocast(enabled=False)
    def forward(self, x):
        # Autocast is disabled for this method, so fp16 inputs are upcast explicitly.
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        B, C = x.shape[:2]
        assert C == self.dim
        n_tokens = x.shape[2:].numel()  # D * H * W for a 3D volume
        img_dims = x.shape[2:]
        # (B, C, D, H, W) -> (B, D*H*W, C): one token per voxel
        x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
        x_norm = self.norm(x_flat)
        x_mamba = self.mamba(x_norm)
        # (B, D*H*W, C) -> (B, C, D, H, W)
        out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims)
        return out
```
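As written, `forward` does not patchify anything: it reshapes a `(B, C, D, H, W)` feature map into a `(B, D*H*W, C)` sequence, so every voxel position becomes one token, and then reshapes back. The sketch below reproduces only that reshape in plain PyTorch (no `mamba_ssm` needed) so the shapes can be checked on CPU. The helper names `to_tokens`, `from_tokens` and `PatchEmbed3D` are hypothetical, not part of the code above. `PatchEmbed3D` is an optional, ViT-style strided `Conv3d` patch embedding, shown here only as one assumed way to shorten the sequence when `D*H*W` is too large for memory; it is not something `MambaLayer` itself requires.

```python
import torch
import torch.nn as nn


def to_tokens(x: torch.Tensor) -> torch.Tensor:
    """(B, C, D, H, W) -> (B, D*H*W, C): same reshape as MambaLayer.forward."""
    B, C = x.shape[:2]
    return x.reshape(B, C, -1).transpose(-1, -2)


def from_tokens(t: torch.Tensor, spatial) -> torch.Tensor:
    """(B, L, C) -> (B, C, *spatial): inverse of to_tokens."""
    B, _, C = t.shape
    return t.transpose(-1, -2).reshape(B, C, *spatial)


class PatchEmbed3D(nn.Module):
    """Hypothetical optional step: non-overlapping p x p x p patches via a strided Conv3d."""

    def __init__(self, in_ch: int, dim: int, patch: int = 4):
        super().__init__()
        self.proj = nn.Conv3d(in_ch, dim, kernel_size=patch, stride=patch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


x = torch.randn(2, 32, 16, 16, 16)  # (B, C, D, H, W)

# Without patching: one token per voxel, 16 * 16 * 16 = 4096 tokens.
tokens = to_tokens(x)
print(tokens.shape)  # torch.Size([2, 4096, 32])
assert torch.equal(from_tokens(tokens, x.shape[2:]), x)  # lossless round trip

# With 4x4x4 patches: 4 * 4 * 4 = 64 tokens of width 64.
patched = PatchEmbed3D(in_ch=32, dim=64, patch=4)(x)
print(patched.shape)  # torch.Size([2, 64, 4, 4, 4])
print(to_tokens(patched).shape)  # torch.Size([2, 64, 64])
```

In this sketch the only hard requirement for feeding a volume to a layer like `MambaLayer` is that the channel dimension equals `dim`; patching (or applying the layer after some downsampling) mainly trades spatial resolution for a shorter sequence.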