Commit 6abee0d

Updated train/val split method to RandomSampler based method

Sulabh Kumra committed Jul 6, 2020
1 parent fa1d7da commit 6abee0d

Showing 4 changed files with 85 additions and 92 deletions.
128 changes: 65 additions & 63 deletions train_network.py
@@ -6,6 +6,7 @@
 import sys
 
 import cv2
+import numpy as np
 import tensorboardX
 import torch
 import torch.optim as optim
@@ -37,26 +38,27 @@ def parse_args():
     parser.add_argument('--channel-size', type=int, default=32,
                         help='Internal channel size for the network')
 
-    # Dataset & Data & Training
+    # Datasets
     parser.add_argument('--dataset', type=str,
                         help='Dataset Name ("cornell" or "jaquard")')
     parser.add_argument('--dataset-path', type=str,
                         help='Path to dataset')
     parser.add_argument('--split', type=float, default=0.9,
                         help='Fraction of data for training (remainder is validation)')
+    parser.add_argument('--ds-shuffle', action='store_true', default=True,
+                        help='Shuffle the dataset')
     parser.add_argument('--ds-rotate', type=float, default=0.0,
-                        help='Shift the start point of the dataset to use a different test/train split'
-                             'for cross validation.')
+                        help='Shift the start point of the dataset to use a different test/train split')
     parser.add_argument('--num-workers', type=int, default=8,
                         help='Dataset workers')
+
+    # Training
     parser.add_argument('--batch-size', type=int, default=8,
                         help='Batch size')
-    parser.add_argument('--epochs', type=int, default=30,
+    parser.add_argument('--epochs', type=int, default=50,
                         help='Training epochs')
     parser.add_argument('--batches-per-epoch', type=int, default=1000,
                         help='Batches per Epoch')
-    parser.add_argument('--val-batches', type=int, default=250,
-                        help='Validation Batches')
     parser.add_argument('--optim', type=str, default='adam',
                         help='Optmizer for the training. (adam or SGD)')

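One behavior of the new --ds-shuffle flag is worth noting: with action='store_true' and default=True, the option is already on, and passing the flag can only set it to True again, so shuffling cannot be disabled from the command line. A standalone argparse sketch of that interaction (not code from the repository):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--ds-shuffle', action='store_true', default=True,
                        help='Shuffle the dataset')

    print(parser.parse_args([]).ds_shuffle)                # True (the default)
    print(parser.parse_args(['--ds-shuffle']).ds_shuffle)  # True; the flag cannot turn it off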
@@ -69,18 +71,19 @@ def parse_args():
                         help='Visualise the training process')
     parser.add_argument('--cpu', dest='force_cpu', action='store_true', default=False,
                         help='Force code to run in CPU mode')
+    parser.add_argument('--random-seed', type=int, default=123,
+                        help='Random seed for numpy')
 
     args = parser.parse_args()
     return args
 
 
-def validate(net, device, val_data, batches_per_epoch):
+def validate(net, device, val_data):
     """
     Run validation.
     :param net: Network
     :param device: Torch device
     :param val_data: Validation Dataset
-    :param batches_per_epoch: Number of batches to run
     :return: Successes, Failures and Losses
     """
     net.eval()
@@ -97,38 +100,32 @@ def validate(net, device, val_data, batches_per_epoch):
     ld = len(val_data)
 
     with torch.no_grad():
-        batch_idx = 0
-        while batch_idx < batches_per_epoch:
-            for x, y, didx, rot, zoom_factor in val_data:
-                batch_idx += 1
-                if batches_per_epoch is not None and batch_idx >= batches_per_epoch:
-                    break
-
-                xc = x.to(device)
-                yc = [yy.to(device) for yy in y]
-                lossd = net.compute_loss(xc, yc)
-
-                loss = lossd['loss']
-
-                results['loss'] += loss.item() / ld
-                for ln, l in lossd['losses'].items():
-                    if ln not in results['losses']:
-                        results['losses'][ln] = 0
-                    results['losses'][ln] += l.item() / ld
-
-                q_out, ang_out, w_out = post_process_output(lossd['pred']['pos'], lossd['pred']['cos'],
-                                                            lossd['pred']['sin'], lossd['pred']['width'])
-
-                s = evaluation.calculate_iou_match(q_out, ang_out,
-                                                   val_data.dataset.get_gtbb(didx, rot, zoom_factor),
-                                                   no_grasps=1,
-                                                   grasp_width=w_out,
-                                                   )
-
-                if s:
-                    results['correct'] += 1
-                else:
-                    results['failed'] += 1
+        for x, y, didx, rot, zoom_factor in val_data:
+            xc = x.to(device)
+            yc = [yy.to(device) for yy in y]
+            lossd = net.compute_loss(xc, yc)
+
+            loss = lossd['loss']
+
+            results['loss'] += loss.item() / ld
+            for ln, l in lossd['losses'].items():
+                if ln not in results['losses']:
+                    results['losses'][ln] = 0
+                results['losses'][ln] += l.item() / ld
+
+            q_out, ang_out, w_out = post_process_output(lossd['pred']['pos'], lossd['pred']['cos'],
+                                                        lossd['pred']['sin'], lossd['pred']['width'])
+
+            s = evaluation.calculate_iou_match(q_out, ang_out,
+                                               val_data.dataset.get_gtbb(didx, rot, zoom_factor),
+                                               no_grasps=1,
+                                               grasp_width=w_out,
+                                               )
+
+            if s:
+                results['correct'] += 1
+            else:
+                results['failed'] += 1
 
     return results

@@ -242,34 +239,39 @@ def run():
     # Load Dataset
     logging.info('Loading {} Dataset...'.format(args.dataset.title()))
     Dataset = get_dataset(args.dataset)
+    dataset = Dataset(args.dataset_path,
+                      ds_rotate=args.ds_rotate,
+                      random_rotate=True,
+                      random_zoom=True,
+                      include_depth=args.use_depth,
+                      include_rgb=args.use_rgb)
+    logging.info('Dataset size is {}'.format(dataset.length))
+
+    # Creating data indices for training and validation splits
+    indices = list(range(dataset.length))
+    split = int(np.floor(args.split * dataset.length))
+    if args.ds_shuffle:
+        np.random.seed(args.random_seed)
+        np.random.shuffle(indices)
+    train_indices, val_indices = indices[:split], indices[split:]
+    logging.info('Training size: {}'.format(len(train_indices)))
+    logging.info('Validation size: {}'.format(len(val_indices)))
+
+    # Creating data samplers and loaders
+    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indices)
+    val_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_indices)
 
-    train_dataset = Dataset(args.dataset_path,
-                            start=0.0,
-                            end=args.split,
-                            ds_rotate=args.ds_rotate,
-                            random_rotate=True,
-                            random_zoom=True,
-                            include_depth=args.use_depth,
-                            include_rgb=args.use_rgb)
     train_data = torch.utils.data.DataLoader(
-        train_dataset,
+        dataset,
         batch_size=args.batch_size,
-        shuffle=True,
-        num_workers=args.num_workers
+        num_workers=args.num_workers,
+        sampler=train_sampler
     )
-    val_dataset = Dataset(args.dataset_path,
-                          start=args.split,
-                          end=1.0,
-                          ds_rotate=args.ds_rotate,
-                          random_rotate=True,
-                          random_zoom=True,
-                          include_depth=args.use_depth,
-                          include_rgb=args.use_rgb)
     val_data = torch.utils.data.DataLoader(
-        val_dataset,
+        dataset,
         batch_size=1,
-        shuffle=False,
-        num_workers=args.num_workers
+        num_workers=args.num_workers,
+        sampler=val_sampler
     )
     logging.info('Done')
 
@@ -314,7 +316,7 @@ def run():

         # Run Validation
         logging.info('Validating...')
-        test_results = validate(net, device, val_data, args.val_batches)
+        test_results = validate(net, device, val_data)
         logging.info('%d/%d = %f' % (test_results['correct'], test_results['correct'] + test_results['failed'],
                                      test_results['correct'] / (test_results['correct'] + test_results['failed'])))

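The substance of this file's change: instead of building two Dataset objects sliced by start/end fractions, run() now builds a single dataset and splits it by index with SubsetRandomSampler. Below is a minimal, self-contained sketch of the same pattern, assuming PyTorch and NumPy; the TensorDataset and the sizes are stand-ins, not the repository's grasp datasets:

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.sampler import SubsetRandomSampler

    # Stand-in for CornellDataset/JacquardDataset: 100 fake samples.
    dataset = TensorDataset(torch.randn(100, 3), torch.randn(100, 1))

    split_fraction = 0.9  # mirrors --split
    seed = 123            # mirrors --random-seed

    # Shuffle the index list once, then cut it at the split point (mirrors --ds-shuffle).
    indices = list(range(len(dataset)))
    split = int(np.floor(split_fraction * len(dataset)))
    np.random.seed(seed)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[:split], indices[split:]

    # Both loaders share one dataset object; the samplers keep the subsets disjoint.
    train_loader = DataLoader(dataset, batch_size=8, sampler=SubsetRandomSampler(train_indices))
    val_loader = DataLoader(dataset, batch_size=1, sampler=SubsetRandomSampler(val_indices))

DataLoader treats shuffle=True and an explicit sampler as mutually exclusive, which is why the shuffle arguments disappear from both loaders above; SubsetRandomSampler still draws its indices in a fresh random order every epoch. And since validate() now makes exactly one pass over the full validation loader, the --val-batches option and the batches_per_epoch parameter are gone.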
2 changes: 1 addition & 1 deletion utils/data/__init__.py
@@ -6,4 +6,4 @@ def get_dataset(dataset_name):
         from .jacquard_data import JacquardDataset
         return JacquardDataset
     else:
-        raise NotImplementedError('Dataset Type {} is Not implemented'.format(dataset_name))
\ No newline at end of file
+        raise NotImplementedError('Dataset Type {} is Not implemented'.format(dataset_name))
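The only change here is the missing newline at the end of the file. For orientation, a short usage sketch of this factory; it assumes, based on the --dataset help text, that the accepted name strings are 'cornell' and 'jacquard' (the cornell branch sits above this hunk):

    from utils.data import get_dataset

    Dataset = get_dataset('jacquard')  # returns the JacquardDataset class
    # Any unrecognised name raises NotImplementedError.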
24 changes: 10 additions & 14 deletions utils/data/cornell_data.py
@@ -10,31 +10,27 @@ class CornellDataset(GraspDatasetBase):
     Dataset wrapper for the Cornell dataset.
     """
 
-    def __init__(self, file_path, start=0.0, end=1.0, ds_rotate=0, **kwargs):
+    def __init__(self, file_path, ds_rotate=0, **kwargs):
         """
         :param file_path: Cornell Dataset directory.
-        :param start: If splitting the dataset, start at this fraction [0,1]
-        :param end: If splitting the dataset, finish at this fraction
         :param ds_rotate: If splitting the dataset, rotate the list of items by this fraction first
         :param kwargs: kwargs for GraspDatasetBase
         """
         super(CornellDataset, self).__init__(**kwargs)
 
-        graspf = glob.glob(os.path.join(file_path, '*', 'pcd*cpos.txt'))
-        graspf.sort()
-        l = len(graspf)
-        if l == 0:
+        self.grasp_files = glob.glob(os.path.join(file_path, '*', 'pcd*cpos.txt'))
+        self.grasp_files.sort()
+        self.length = len(self.grasp_files)
+
+        if self.length == 0:
             raise FileNotFoundError('No dataset files found. Check path: {}'.format(file_path))
 
         if ds_rotate:
-            graspf = graspf[int(l * ds_rotate):] + graspf[:int(l * ds_rotate)]
-
-        depthf = [f.replace('cpos.txt', 'd.tiff') for f in graspf]
-        rgbf = [f.replace('d.tiff', 'r.png') for f in depthf]
+            self.grasp_files = self.grasp_files[int(self.length * ds_rotate):] + self.grasp_files[
+                                                                                 :int(self.length * ds_rotate)]
 
-        self.grasp_files = graspf[int(l * start):int(l * end)]
-        self.depth_files = depthf[int(l * start):int(l * end)]
-        self.rgb_files = rgbf[int(l * start):int(l * end)]
+        self.depth_files = [f.replace('cpos.txt', 'd.tiff') for f in self.grasp_files]
+        self.rgb_files = [f.replace('d.tiff', 'r.png') for f in self.depth_files]
 
     def _get_crop_attrs(self, idx):
         gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.grasp_files[idx])
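With the start/end slicing gone, ds_rotate is the only list manipulation left. It is applied in place to self.grasp_files before the depth and RGB paths are derived, so all three file lists stay aligned; JacquardDataset below uses the same idiom. A toy sketch of the rotation:

    files = ['a', 'b', 'c', 'd', 'e']
    ds_rotate = 0.4  # rotate the list by 40% of its length

    n = int(len(files) * ds_rotate)  # n = 2
    files = files[n:] + files[:n]    # ['c', 'd', 'e', 'a', 'b']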
23 changes: 9 additions & 14 deletions utils/data/jacquard_data.py
@@ -10,32 +10,27 @@ class JacquardDataset(GraspDatasetBase):
     Dataset wrapper for the Jacquard dataset.
     """
 
-    def __init__(self, file_path, start=0.0, end=1.0, ds_rotate=0, **kwargs):
+    def __init__(self, file_path, ds_rotate=0, **kwargs):
         """
         :param file_path: Jacquard Dataset directory.
-        :param start: If splitting the dataset, start at this fraction [0,1]
-        :param end: If splitting the dataset, finish at this fraction
        :param ds_rotate: If splitting the dataset, rotate the list of items by this fraction first
         :param kwargs: kwargs for GraspDatasetBase
         """
         super(JacquardDataset, self).__init__(**kwargs)
 
-        graspf = glob.glob(os.path.join(file_path, '*', '*_grasps.txt'))
-        graspf.sort()
-        l = len(graspf)
+        self.grasp_files = glob.glob(os.path.join(file_path, '*', '*_grasps.txt'))
+        self.grasp_files.sort()
+        self.length = len(self.grasp_files)
 
-        if l == 0:
+        if self.length == 0:
             raise FileNotFoundError('No dataset files found. Check path: {}'.format(file_path))
 
         if ds_rotate:
-            graspf = graspf[int(l * ds_rotate):] + graspf[:int(l * ds_rotate)]
+            self.grasp_files = self.grasp_files[int(self.length * ds_rotate):] + self.grasp_files[
+                                                                                 :int(self.length * ds_rotate)]
 
-        depthf = [f.replace('grasps.txt', 'perfect_depth.tiff') for f in graspf]
-        rgbf = [f.replace('perfect_depth.tiff', 'RGB.png') for f in depthf]
-
-        self.grasp_files = graspf[int(l * start):int(l * end)]
-        self.depth_files = depthf[int(l * start):int(l * end)]
-        self.rgb_files = rgbf[int(l * start):int(l * end)]
+        self.depth_files = [f.replace('grasps.txt', 'perfect_depth.tiff') for f in self.grasp_files]
+        self.rgb_files = [f.replace('perfect_depth.tiff', 'RGB.png') for f in self.depth_files]
 
     def get_gtbb(self, idx, rot=0, zoom=1.0):
         gtbbs = grasp.GraspRectangles.load_from_jacquard_file(self.grasp_files[idx], scale=self.output_size / 1024.0)
