Skip to content

Commit

Permalink
Support for cpu as compute device
Browse files Browse the repository at this point in the history
  • Loading branch information
Sulabh Kumra committed Jun 8, 2020
1 parent 3c75ebf commit c2db619
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 7 deletions.
6 changes: 5 additions & 1 deletion evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch.utils.data

from hardware.device import get_device
from inference.post_process import post_process_output
from utils.data import get_dataset
from utils.dataset_processing import evaluation, grasp
Expand Down Expand Up @@ -32,6 +33,7 @@ def parse_args():
parser.add_argument('--iou-eval', action='store_true', help='Compute success based on IoU metric.')
parser.add_argument('--jacquard-output', action='store_true', help='Jacquard-dataset style output')
parser.add_argument('--vis', action='store_true', help='Visualise the network output')
parser.add_argument('--cpu', dest='force_cpu', action='store_true', default=False, help='force code to run in CPU mode')

args = parser.parse_args()

Expand All @@ -46,9 +48,11 @@ def parse_args():
if __name__ == '__main__':
args = parse_args()

# Get the compute device
device = get_device(args.force_cpu)

# Load Network
net = torch.load(args.network)
device = torch.device("cuda:0")

# Load Dataset
logging.info('Loading {} Dataset...'.format(args.dataset.title()))
Expand Down
19 changes: 19 additions & 0 deletions hardware/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import logging

import torch

logging.basicConfig(level=logging.INFO)


def get_device(force_cpu):
# Check if CUDA can be used
if torch.cuda.is_available() and not force_cpu:
logging.info("CUDA detected. Running with GPU acceleration.")
device = torch.device("cuda")
elif force_cpu:
logging.info("CUDA detected, but overriding with option '--cpu'. Running with only CPU.")
device = torch.device("cpu")
else:
logging.info("CUDA is *NOT* detected. Running with only CPU.")
device = torch.device("cpu")
return device
4 changes: 3 additions & 1 deletion inference/grasp_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch

from hardware.camera import RealSenseCamera
from hardware.device import get_device
from inference.post_process import post_process_output
from utils.data.camera_data import CameraData
from utils.dataset_processing.grasp import detect_grasps
Expand Down Expand Up @@ -43,7 +44,8 @@ def __init__(self, saved_model_path, cam_id, visualize=False):
def load_model(self):
print('Loading model... ')
self.model = torch.load(self.saved_model_path)
self.device = torch.device("cuda:0")
# Get the compute device
self.device = get_device(force_cpu=False)

def generate(self):
# Get RGB-D image from camera
Expand Down
9 changes: 7 additions & 2 deletions run_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import torch.utils.data
from PIL import Image

from hardware.device import get_device
from inference.post_process import post_process_output
from utils.data.camera_data import CameraData
from utils.visualisation.plot import plot_results, save_results

logging.basicConfig(level=logging.INFO)


Expand All @@ -17,10 +19,11 @@ def parse_args():
parser.add_argument('--network', type=str, default='cornell_rgbd_iou_0.95', help='Path to saved network to evaluate')
parser.add_argument('--rgb_path', type=str, default='cornell/08/pcd0845r.png', help='RGB Image path')
parser.add_argument('--depth_path', type=str, default='cornell/08/pcd0845d.tiff', help='Depth Image path')
parser.add_argument('--use-depth', type=int, default=1, help='Use Depth image for evaluation (0/1)')
parser.add_argument('--use-depth', type=int, default=1, help='Use Depth image for evaluation (1/0)')
parser.add_argument('--use-rgb', type=int, default=1, help='Use RGB image for evaluation (1/0)')
parser.add_argument('--n-grasps', type=int, default=1, help='Number of grasps to consider per image')
parser.add_argument('--save', type=int, default=0, help='Save the results')
parser.add_argument('--cpu', dest='force_cpu', action='store_true', default=False, help='force code to run in CPU mode')

args = parser.parse_args()
return args
Expand All @@ -39,9 +42,11 @@ def parse_args():
# Load Network
logging.info('Loading model...')
net = torch.load(args.network)
device = torch.device("cuda:0")
logging.info('Done')

# Get the compute device
device = get_device(args.force_cpu)

img_data = CameraData(include_depth=args.use_depth, include_rgb=args.use_rgb)

x, depth_img, rgb_img = img_data.get_data(rgb=rgb, depth=depth)
Expand Down
7 changes: 5 additions & 2 deletions run_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.utils.data

from hardware.camera import RealSenseCamera
from hardware.device import get_device
from inference.post_process import post_process_output
from utils.data.camera_data import CameraData
from utils.visualisation.plot import save_results, plot_results
Expand All @@ -19,6 +20,7 @@ def parse_args():
parser.add_argument('--use-depth', type=int, default=1, help='Use Depth image for evaluation (1/0)')
parser.add_argument('--use-rgb', type=int, default=1, help='Use RGB image for evaluation (1/0)')
parser.add_argument('--n-grasps', type=int, default=1, help='Number of grasps to consider per image')
parser.add_argument('--cpu', dest='force_cpu', action='store_true', default=False, help='force code to run in CPU mode')

args = parser.parse_args()
return args
Expand All @@ -36,10 +38,11 @@ def parse_args():
# Load Network
logging.info('Loading model...')
net = torch.load(args.network)
device = torch.device("cuda:0")

logging.info('Done')

# Get the compute device
device = get_device(args.force_cpu)

try:
fig = plt.figure(figsize=(10, 10))
while True:
Expand Down
5 changes: 4 additions & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torchsummary import summary

from inference.model import GenerativeResnet
from hardware.device import get_device
from inference.post_process import post_process_output
from utils.data import get_dataset
from utils.dataset_processing import evaluation
Expand Down Expand Up @@ -43,6 +44,7 @@ def parse_args():
parser.add_argument('--outdir', type=str, default='output/models/', help='Training Output Directory')
parser.add_argument('--logdir', type=str, default='tensorboard/', help='Log directory')
parser.add_argument('--vis', action='store_true', help='Visualise the training process')
parser.add_argument('--cpu', dest='force_cpu', action='store_true', default=False, help='force code to run in CPU mode')

args = parser.parse_args()
return args
Expand Down Expand Up @@ -184,6 +186,8 @@ def run():
if not os.path.exists(save_folder):
os.makedirs(save_folder)
tb = tensorboardX.SummaryWriter(os.path.join(args.logdir, net_desc))
# Get the compute device
device = get_device(args.force_cpu)

# Load Dataset
logging.info('Loading {} Dataset...'.format(args.dataset.title()))
Expand Down Expand Up @@ -224,7 +228,6 @@ def run():
input_channels = 1*args.use_depth + 3*args.use_rgb

net = GenerativeResnet(input_channels=input_channels)
device = torch.device("cuda:0")
net = net.to(device)
optimizer = optim.Adam(net.parameters())
logging.info('Done')
Expand Down

0 comments on commit c2db619

Please sign in to comment.