How to use image size other than 224 #2

Open
alqurri77 opened this issue Jan 24, 2024 · 4 comments


@alqurri77

Hi,
How can I use an image size other than 224?

Thanks

@BinYCn (Owner) commented Jan 25, 2024

Thank you for your question.

To use an image size other than 224, you can adjust the WINDOW_SIZE parameter in the configuration file located at
./configs/*.yaml. For example, set WINDOW_SIZE to 16 when the image size is 512.
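
For context, the rule of thumb behind these numbers can be sketched as follows (this assumes the default patch size of 4 and four Swin stages; the helper function below is only an illustration, not code from this repository):

def suggested_window_size(img_size, patch_size=4, num_stages=4):
    # After patch embedding the feature map is img_size // patch_size per side;
    # the deepest stage is downsampled by a further factor of 2 ** (num_stages - 1),
    # and WINDOW_SIZE must divide the resolution at every stage.
    side = img_size // patch_size
    return side // (2 ** (num_stages - 1))

print(suggested_window_size(224))   # 7  (the default)
print(suggested_window_size(512))   # 16
print(suggested_window_size(1024))  # 32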

@alqurri77 (Author)

I did that, but now I'm getting this error:

my_weights torch.Size([256, 3, 256, 256])
error new_shape torch.Size([256, 256, 768])
  0%|                                         | 0/150 [00:05<?, ?it/s]
---------------------------------------------------------------------------
UnboundLocalError                         Traceback (most recent call last)
Input In [17], in <cell line: 116>()
    148 #net.load_from(config)
    151 trainer = {'Synapse': trainer_synapse,}
--> 152 trainer[dataset_name](args, net, args.output_dir)

Input In [16], in trainer_synapse(args, model, snapshot_path)
     54 image_batch, label_batch = image_batch.to(device), label_batch.to(device)
     55 #d1, d2, d3 = model(image_batch)
     56 
     57 
   (...)
     60 
     61 #loss_ce =  muti_bc_loss_fusion(d1, d2, d3,  label_batch ,ce_loss)
---> 63 ds1,ds2,ds3,d1, d2, d3,d4,d5,d6= model(image_batch)
     65 loss_dice =muti_dice_loss_fusion(ds1,ds2,ds3,d1, d2, d3,d4,d5,d6, label_batch,dice_loss)
     67 loss_ce =  muti_bc_loss_fusion  (ds1,ds2,ds3,d1, d2, d3,d4, d5,d6, label_batch ,ce_loss) #ce_loss(d4, y[:].long())

File /scratch/ahmed/lib/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /scratch/ahmed/lib/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Input In [10], in SwinUnet.forward(self, x)
   1411 if x.size()[1] == 1:
   1412     x = x.repeat(1, 3, 1, 1)
-> 1413 logits = self.swin_unet(x)
   1414 return logits

File /scratch/ahmed/lib/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /scratch/ahmed/lib/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Input In [10], in SwinTransformerSys.forward(self, x)
   1168 output4 = self.final4(x0_4_unet2)
   1170 #--------------------------------------------------------------------------------
-> 1174 x, x_downsample,x_downsample_my_weights = self.forward_features(x)
   1175 U_decoder_thing = []
   1177 #print("encoders_results[0] ",encoders_results[0].shape)
   1178 #print("encoders_results[1] ",encoders_results[1].shape)
   1179 #print("encoders_results[2] ",encoders_results[2].shape)
   1180 #print("encoders_results[3] ",encoders_results[3].shape)

Input In [10], in SwinTransformerSys.forward_features(self, x)
   1036 for layer in self.layers:
   1037     x_downsample.append(x)
-> 1039     x,my_weights = layer(x)
   1040     x_downsample_my_weights.append(my_weights)
   1043 x = self.norm(x)  # B L C

File /scratch/ahmed/lib/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /scratch/ahmed/lib/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Input In [10], in BasicLayer.forward(self, x)
    612         x = checkpoint.checkpoint(blk, x)
    613     else:
--> 614         x,my_weights = blk(x)
    615 if self.downsample is not None:
    616     x = self.downsample(x)

File /scratch/ahmed/lib/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /scratch/ahmed/lib/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Input In [10], in SwinTransformerBlock.forward(self, x)
    401 #elif new_shape2.shape[2]==576:
    402 #  new_shape3=self.chaneshape5(new_shape2)
    403 #elif new_shape2.shape[2]==1024:
    404 #  new_shape3=self.chaneshape6(new_shape2)
    405 else:
    406     print("error new_shape",new_shape2.shape)
--> 407 print("my_weights3:",new_shape3.shape)
    408 my_weights=new_shape3
    409 # merge windows

UnboundLocalError: local variable 'new_shape3' referenced before assignment

I think the problem is with this code:

if new_shape2.shape[2] == 147:
    new_shape3 = self.chaneshape(new_shape2)
elif new_shape2.shape[2] == 294:
    new_shape3 = self.chaneshape2(new_shape2)
elif new_shape2.shape[2] == 588:
    new_shape3 = self.chaneshape3(new_shape2)
elif new_shape2.shape[2] == 1176:
    new_shape3 = self.chaneshape4(new_shape2)
else:
    print("error new_shape", new_shape2.shape)

For an image size of 512, the size will be 1024, but there is no 1024 branch there, so new_shape3 is never assigned and the UnboundLocalError above follows.
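
For what it's worth, the traceback above already shows commented-out branches for 576 and 1024 (chaneshape5 / chaneshape6). A minimal sketch of the kind of fix that implies is below; the layer type (nn.Linear) and the output dimension are assumptions for illustration only, and the real layers would have to be defined in SwinTransformerBlock.__init__:

import torch
import torch.nn as nn

# Hypothetical sketch: restore the commented-out 576/1024 branches and define
# the matching layers in __init__.  nn.Linear and the output dimension (49)
# are assumptions, not taken from the repository.
class BranchFixSketch(nn.Module):
    def __init__(self, out_dim=49):
        super().__init__()
        self.chaneshape5 = nn.Linear(576, out_dim)
        self.chaneshape6 = nn.Linear(1024, out_dim)

    def forward(self, new_shape2):
        # ... the existing 147 / 294 / 588 / 1176 branches stay as they are ...
        if new_shape2.shape[2] == 576:
            new_shape3 = self.chaneshape5(new_shape2)
        elif new_shape2.shape[2] == 1024:
            new_shape3 = self.chaneshape6(new_shape2)
        else:
            raise ValueError(f"unhandled size {new_shape2.shape[2]}")
        return new_shape3

print(BranchFixSketch()(torch.randn(2, 8, 1024)).shape)  # torch.Size([2, 8, 49])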

@BinYCn (Owner) commented Jan 26, 2024

You can set WINDOW_SIZE to 32 when the image size is 1024.

@alqurri77 (Author)

Thank you. I meant the code in the SwinTransformerBlock class: it's not working even if you set the window size correctly. It gets stuck on this code fragment:

new_shape2 = new_shape.reshape(new_shape.shape[0], new_shape.shape[1], -1)

if new_shape2.shape[2] == 147:
    new_shape3 = self.chaneshape(new_shape2)
elif new_shape2.shape[2] == 294:
    new_shape3 = self.chaneshape2(new_shape2)
elif new_shape2.shape[2] == 588:
    new_shape3 = self.chaneshape3(new_shape2)
elif new_shape2.shape[2] == 1176:
    new_shape3 = self.chaneshape4(new_shape2)
else:
    print("error new_shape", new_shape2.shape)
