clean stuff + remove redundant arg batch_size (issue #9)
aRI0U committed Jul 27, 2020
1 parent 3eae0d5 commit 057035a
Showing 3 changed files with 16 additions and 36 deletions.
data.py (14 additions & 4 deletions)
@@ -136,9 +136,10 @@ def __len__(self):
 
 class ActiveLearningSampler(IterableDataset):
 
-    def __init__(self, dataset, split='training'):
+    def __init__(self, dataset, batch_size=6, split='training'):
         self.dataset = dataset
         self.split = split
+        self.batch_size = batch_size
         self.possibility = {}
         self.min_possibility = {}
@@ -163,7 +164,7 @@ def __len__(self):
     def spatially_regular_gen(self):
         # Choosing the least known point as center of a new cloud each time.
 
-        for i in range(self.n_samples * cfg.batch_size): # num_per_epoch
+        for i in range(self.n_samples * self.batch_size): # num_per_epoch
             # t0 = time.time()
             if cfg.sampling_type=='active_learning':
                 # Generator loop
@@ -229,8 +230,17 @@ def spatially_regular_gen(self):
 def data_loaders(dir, sampling_method='active_learning', **kwargs):
     if sampling_method == 'active_learning':
         dataset = CloudsDataset(dir / 'train')
-        val_sampler = ActiveLearningSampler(dataset, split='validation')
-        train_sampler = ActiveLearningSampler(dataset, split='training')
+        batch_size = kwargs.get('batch_size', 6)
+        val_sampler = ActiveLearningSampler(
+            dataset,
+            batch_size=batch_size,
+            split='validation'
+        )
+        train_sampler = ActiveLearningSampler(
+            dataset,
+            batch_size=batch_size,
+            split='training'
+        )
         return DataLoader(train_sampler, **kwargs), DataLoader(val_sampler, **kwargs)
 
     if sampling_method == 'naive':
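With this change, batch_size is no longer read from the global config: data_loaders picks it out of its keyword arguments (defaulting to 6), hands it to the samplers, which use it to size an epoch, and still forwards it unchanged inside **kwargs to both DataLoaders. A minimal usage sketch; the dataset path and values here are illustrative, not from the repository:

    from pathlib import Path
    from data import data_loaders

    train_loader, val_loader = data_loaders(
        Path('datasets/s3dis'),             # hypothetical dataset root
        sampling_method='active_learning',
        batch_size=4,                       # consumed by the samplers and the DataLoaders
    )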
train.py (1 addition & 17 deletions)
@@ -68,18 +68,10 @@ def train(args):
 
     print('Computing weights...', end='\t')
     samples_per_class = np.array(cfg.class_weights)
-    # weight = samples_per_class / float(sum(samples_per_class))
-    # class_weights = 1 / (weight + 0.02)
-    # effective = 1.0 - np.power(0.99, samples_per_class)
-    # class_weights = (1 - 0.99) / effective
-    # class_weights = class_weights / (np.sum(class_weights) * num_classes)
-    # class_weights = class_weights / float(sum(class_weights))
-    # weights = torch.tensor(class_weights).float().to(args.gpu)
+
     n_samples = torch.tensor(cfg.class_weights, dtype=torch.float, device=args.gpu)
     ratio_samples = n_samples / n_samples.sum()
     weights = 1 / (ratio_samples + 0.02)
-    #weights = F.softmin(n_samples)
-    # weights = (1/ratio_samples) / (1/ratio_samples).sum()
 
     print('Done.')
     print('Weights:', weights)
@@ -160,14 +152,6 @@ def train(args):
         } for iou, val_iou in zip(ious, val_ious)
     ]
 
-    # acc_dicts = [
-    #     {
-    #         f'{i:02d}_train_acc': acc,
-    #         f'{}': val_acc
-    #     }
-    #     for i, (acc, val_accs) in enumerate(zip(accs, val_accs))
-    # ]
-
     t1 = time.time()
     d = t1 - t0
     # Display results
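The weighting that survives the cleanup is plain inverse frequency with an additive 0.02 floor. A toy sketch with assumed numbers (the first three entries of cfg.class_weights) to make the effect concrete:

    import torch

    counts = torch.tensor([1938651., 1242339., 608870.])
    ratio = counts / counts.sum()   # per-class share: ~[0.512, 0.328, 0.161]
    weights = 1 / (ratio + 0.02)    # ~[1.88, 2.88, 5.54]

Rarer classes get larger weights, and the floor bounds any weight by 1/0.02 = 50, so a vanishingly rare class cannot dominate the loss.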
utils/tools.py (1 addition & 15 deletions)
@@ -12,25 +12,11 @@
 
 class Config:
     num_points = 40960 # Number of input points
-    # num_classes = 14 # Number of valid classes
     sub_grid_size = 0.04 # preprocess_parameter
 
-    batch_size = 6
     train_steps = 200 # Number of steps per epochs
     val_steps = 100 # Number of validation steps per epoch
-    #
-    # sub_sampling_ratio = [4, 4, 4, 4, 2] # sampling ratio of random sampling at each layer
-    # d_out = [16, 64, 128, 256, 512] # feature dimension
-    #
-    # noise_init = 3.5 # noise initial parameter
-    # max_epoch = 100 # maximum epoch during training
-    # learning_rate = 1e-2 # initial learning rate
-    # lr_decays = {i: 0.95 for i in range(0, 500)} # decay rate of learning rate
-    #
-    # train_sum_dir = 'train_log'
-    # saving = True
-    # saving_path = None
-    #
+
     sampling_type = 'active_learning'
     class_weights = [1938651, 1242339, 608870, 1699694, 2794560, 195000, 115990, 549838, 531470, 292971, 196633, 59032, 209046, 39321]
 
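With batch_size gone, Config is reduced to data and schedule constants that the rest of the code reads as attributes, and class_weights holds the per-class point counts that feed the weighting above. A hedged sketch of how such weights would typically reach a criterion, assuming a standard weighted cross-entropy and that Config is imported as cfg (how the repository actually builds its loss is not shown in this diff):

    import torch
    import torch.nn as nn
    from utils.tools import Config as cfg  # assumed import path

    n_samples = torch.tensor(cfg.class_weights, dtype=torch.float)
    weights = 1 / (n_samples / n_samples.sum() + 0.02)
    criterion = nn.CrossEntropyLoss(weight=weights)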
