
GPU Inference with Layout Predictor Fails - Device Mismatch Error #55

higuhigu-lb opened this issue Nov 25, 2024 · 3 comments


higuhigu-lb commented Nov 25, 2024

I'm attempting to run inference on a GPU using the layout predictor, and I made the following changes to the code:

```
class LayoutPredictor:
    # ... (previous code) ...
    def __init__(self, ...):
        # ... (previous code) ...

        # Device selection
        self.device = torch.device("cuda" if torch.cuda.is_available() and not self._use_cpu_only else "cpu")
        print(f"Using device: {self.device}")  # Added for clarity
        self.model.to(self.device)  # Move model to selected device

    def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
        # ... (previous code) ...

        img = transforms(page_img)[None].to(self.device)  # Move image to device
        orig_size = torch.tensor([w, h])[None].to(self.device)  # Move size to device

        # Predict
        with torch.no_grad():
            labels, boxes, scores = self.model(img, orig_size)

        # Move tensors back to CPU for further processing
        labels = labels.cpu()
        boxes = boxes.cpu()
        scores = scores.cpu()

        # ... (rest of the predict function) ...
```

After making these changes, I encounter a device mismatch error. Could you please help me resolve it? The full error message is as follows:

```
Traceback (most recent call last):

File "/usr/local/lib/python3.10/dist-packages/docling/pipeline/base_pipeline.py", line 149, in _build_document
for p in pipeline_pages: # Must exhaust!

File "/usr/local/lib/python3.10/dist-packages/docling/pipeline/base_pipeline.py", line 116, in _apply_on_pages
yield from page_batch

File "/usr/local/lib/python3.10/dist-packages/docling/models/page_assemble_model.py", line 59, in call
for page in page_batch:

File "/usr/local/lib/python3.10/dist-packages/docling/models/table_structure_model.py", line 93, in call
for page in page_batch:

File "/usr/local/lib/python3.10/dist-packages/docling/models/layout_model.py", line 290, in call
for ix, pred_item in enumerate(

File "/usr/local/lib/python3.10/dist-packages/docling_ibm_models/layoutmodel/layout_predictor.py", line 153, in predict
labels, boxes, scores = self.model(img, orig_size)

File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)

File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
File "code/torch.py", line 13, in forward
postprocessor = self.postprocessor
model = self.model
_0, _1, = (model).forward(images, )
~~~~~~~~~~~~~~ <--- HERE
_2 = (postprocessor).forward(_0, orig_target_sizes, _1, )
_3, _4, _5, = _2
File "code/torch/src/zoo/rtdetr/rtdetr.py", line 15, in forward
backbone = self.backbone
_0, _1, _2, = (backbone).forward(images, )
_3, _4, _5, = (encoder).forward(_0, _1, _2, )
~~~~~~~~~~~~~~~~ <--- HERE
_6, _7, = (decoder).forward(_3, _4, _5, )
return (_6, _7)
File "code/torch/src/zoo/rtdetr/hybrid_encoder.py", line 49, in forward
tensor = torch.permute(torch.flatten(_6, 2), [0, 2, 1])
pos_embed = torch.to(CONSTANTS.c9, dtype=6, layout=0, device=torch.device("cpu"))
_8 = torch.permute((_03).forward(tensor, pos_embed, ), [0, 2, 1])
~~~~~~~~~~~~ <--- HERE
_9 = torch.reshape(_8, [-1, 256, _3, _7])
input = torch.contiguous(_9)
File "code/torch/src/zoo/rtdetr/hybrid_encoder.py", line 74, in forward
layers = self.layers
_0 = getattr(layers, "0")
return (_0).forward(tensor, pos_embed, )
~~~~~~~~~~~ <--- HERE
class TransformerEncoderLayer(Module):
parameters = []
File "code/torch/src/zoo/rtdetr/hybrid_encoder.py", line 101, in forward
dropout1 = self.dropout1
self_attn = self.self_attn
x = torch.add(tensor, pos_embed)
~~~~~~~~~ <--- HERE
_11 = (dropout1).forward((self_attn).forward(x, tensor, ), )
input = torch.add(tensor, _11)

Traceback of TorchScript, original code (most recent call last):
/Users/ahn/gits/RT-DETR/rtdetr_pytorch/tools/../src/zoo/rtdetr/hybrid_encoder.py(141): with_pos_embed
/Users/ahn/gits/RT-DETR/rtdetr_pytorch/tools/../src/zoo/rtdetr/hybrid_encoder.py(147): forward
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1527): _call_impl
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/Users/ahn/gits/RT-DETR/rtdetr_pytorch/tools/../src/zoo/rtdetr/hybrid_encoder.py(174): forward
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1527): _call_impl
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/Users/ahn/gits/RT-DETR/rtdetr_pytorch/tools/../src/zoo/rtdetr/hybrid_encoder.py(299): forward
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1527): _call_impl
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/Users/ahn/gits/RT-DETR/rtdetr_pytorch/tools/../src/zoo/rtdetr/rtdetr.py(34): forward
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1527): _call_impl
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/Users/ahn/gits/RT-DETR/rtdetr_pytorch/tools/export_jit.py(42): forward
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1508): _slow_forward
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1527): _call_impl
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/nn/modules/module.py(1518): _wrapped_call_impl
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/jit/_trace.py(1065): trace_module
/Users/ahn/miniconda3/envs/docfm/lib/python3.9/site-packages/torch/jit/_trace.py(798): trace
/Users/ahn/gits/RT-DETR/rtdetr_pytorch/tools/export_jit.py(49): main
/Users/ahn/gits/RT-DETR/rtdetr_pytorch/tools/export_jit.py(65):
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!


RuntimeError Traceback (most recent call last)
in <cell line: 5>()
6 filename = os.path.basename(pdffile)
7 savename = filename.replace(".pdf", ".md")
----> 8 result = converter.convert(pdffile)
9 with open(os.path.join(savedir, savename), 'w') as f:
10 f.write(result.document.export_to_markdown())

17 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1748
1749 result = None

RuntimeError: The following operation failed in the TorchScript interpreter.
[... TorchScript traceback identical to the one above ...]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

```
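For context, the failure appears to stem from the line visible in the serialized graph: `pos_embed` is created with a hard-coded `device=torch.device("cpu")`, so `.to(self.device)` on the scripted module cannot move that baked-in constant, and `torch.add(tensor, pos_embed)` mixes a CUDA activation with a CPU tensor. A minimal sketch (my own, not Docling code) that reproduces the same class of error:

```python
import torch

# Minimal repro sketch (not Docling code): adding tensors that live on
# different devices raises the same RuntimeError as in the traceback.
# In the traced graph, pos_embed is pinned to the CPU via a hard-coded
# device=torch.device("cpu"), which model.to("cuda") cannot rewrite.
def with_pos_embed(tensor, pos_embed):
    return torch.add(tensor, pos_embed)

tensor = torch.ones(2, 3)
pos_embed = torch.zeros(2, 3)

if torch.cuda.is_available():
    try:
        with_pos_embed(tensor.cuda(), pos_embed)  # CUDA + CPU -> RuntimeError
    except RuntimeError as err:
        print(f"Reproduced: {err}")
else:
    # On a common device the add works fine.
    print(with_pos_embed(tensor, pos_embed).sum().item())  # 6.0
```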
cau-git (Contributor) commented Nov 25, 2024

We are actively working on a resolution for this; it will come with new features and will get rid of the torch.jit model dumps that cause these issues.

higuhigu-lb (Author)

Thank you for the valuable information.
It seems that the resolution is progressing in PR #50.
I am looking forward to the new release.

higuhigu-lb (Author)

This isn't critical, but I noticed that the encoder and decoder blocks contain constant layers pinned to the CPU:

```
pos_embed = torch.to(CONSTANTS.c0, dtype=6, layout=0, device=torch.device("cpu"))
```

This suggests it might be possible to optimize performance by placing only the backbone layers on the GPU...

```
with torch.no_grad():
    i0, i1, i2 = self.model.model.backbone(img)
    i3, i4, i5 = self.model.model.encoder(i0.to("cpu"), i1.to("cpu"), i2.to("cpu"))
    i6, i7 = self.model.model.decoder(i3, i4, i5)
    labels, boxes, scores = self.model.postprocessor(i6, orig_size, i7)
```
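When experimenting with a split like that, it helps to confirm where each submodule's weights actually live. A small helper sketch for that (the `nn.Sequential` model below is just a stand-in, not the actual RT-DETR layout model):

```python
import torch
import torch.nn as nn

# Sketch of a helper that reports which device(s) each submodule's parameters
# and buffers live on, useful for sanity-checking a hybrid CPU/GPU placement
# (e.g. backbone on GPU, encoder/decoder on CPU).
def devices_by_submodule(model: nn.Module) -> dict:
    result = {}
    for name, module in model.named_modules():
        devs = {str(p.device) for p in module.parameters(recurse=False)}
        devs |= {str(b.device) for b in module.buffers(recurse=False)}
        if devs:  # skip containers and parameter-free layers like ReLU
            result[name or "<root>"] = sorted(devs)
    return result

# Illustrative stand-in model, not the layout predictor.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
print(devices_by_submodule(model))  # every entry is ['cpu'] on a CPU-only box
```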
