Skip to content

Commit

Permalink
Bugfix setting of device by defaulting to "cpu" (#1182)
Browse files Browse the repository at this point in the history
* Default the device to "cpu" when no GPU is available, even if use_gpu is set to True

Co-authored-by: C V Goudar <[email protected]>
  • Loading branch information
cvgoudar and C V Goudar authored Jun 16, 2021
1 parent 6cd4910 commit f9c4083
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 17 deletions.
11 changes: 3 additions & 8 deletions haystack/generator/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from haystack import Document
from haystack.generator.base import BaseGenerator
from haystack.retriever.dense import DensePassageRetriever
from haystack.utils import get_device

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -118,10 +119,7 @@ def __init__(

self.top_k = top_k

if use_gpu and torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
self.device = get_device(use_gpu)

self.tokenizer = RagTokenizer.from_pretrained(model_name_or_path)

Expand Down Expand Up @@ -379,10 +377,7 @@ def __init__(

self.top_k = top_k

if use_gpu and torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
self.device = get_device(use_gpu)

Seq2SeqGenerator._register_converters(model_name_or_path, input_converter)

Expand Down
11 changes: 3 additions & 8 deletions haystack/retriever/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from haystack.document_store.base import BaseDocumentStore
from haystack import Document
from haystack.retriever.base import BaseRetriever
from haystack.utils import get_device

from farm.infer import Inferencer
from farm.modeling.tokenization import Tokenizer
Expand Down Expand Up @@ -120,10 +121,7 @@ def __init__(self,
"We recommend you use dot_product instead. "
"This can be set when initializing the DocumentStore")

if use_gpu and torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
self.device = get_device(use_gpu)

self.infer_tokenizer_classes = infer_tokenizer_classes
tokenizers_default_classes = {
Expand Down Expand Up @@ -630,10 +628,7 @@ def __init__(
"For details see https://github.com/UKPLab/sentence-transformers ")
# pretrained embedding models coming from: https://github.com/UKPLab/sentence-transformers#pretrained-models
# e.g. 'roberta-base-nli-stsb-mean-tokens'
if retriever.use_gpu:
device = "cuda"
else:
device = "cpu"
device = get_device(retriever.use_gpu)
self.embedding_model = SentenceTransformer(retriever.embedding_model, device=device)
self.show_progress_bar = retriever.progress_bar
document_store = retriever.document_store
Expand Down
9 changes: 8 additions & 1 deletion haystack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from haystack.document_store.sql import DocumentORM
import subprocess
import time
import torch


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -198,4 +199,10 @@ def get_batches_from_generator(iterable, n):
x = tuple(islice(it, n))
while x:
yield x
x = tuple(islice(it, n))
x = tuple(islice(it, n))


def get_device(use_gpu: bool = True) -> str:
    """Return the torch device string to run models on.

    :param use_gpu: If True, prefer CUDA when it is available.
    :return: ``"cuda"`` when a GPU was requested and is available,
             otherwise ``"cpu"``.
    """
    if use_gpu:
        if torch.cuda.is_available():
            return "cuda"
        # Requested GPU is not available — fall back to CPU instead of
        # crashing, but make the silent downgrade visible in the logs.
        logging.getLogger(__name__).warning(
            "use_gpu was set to True, but no CUDA device is available. Falling back to CPU."
        )
    return "cpu"

0 comments on commit f9c4083

Please sign in to comment.