Commit

fixed some bugs in workflow
kirtyvedula committed Apr 9, 2020
1 parent 1a5f3db commit c7b3ed4
Showing 14 changed files with 78 additions and 196 deletions.
Binary file removed ae_mfbank_AWGN_BPSK_74.mat
32 changes: 26 additions & 6 deletions awgn_train_test.py
@@ -9,16 +9,32 @@

 from models import FC_Autoencoder
 from tools import EarlyStopping
-from utils import generate_encoded_sym_dict, get_args
+from utils import generate_encoded_sym_dict
 from datasets import prepare_data
 from trainer import train, validate, test
+from get_args import get_args

-def awgn_train(trainloader, valloader, val_set_size, epochs, net, optimizer, early_stopping, loss_func, device, loss_vec, batch_size, EbN0_dB_train, args, log_writer_train, log_writer_val):
+def awgn_train(trainloader, valloader, val_set_size, device, args):

+    # Define loggers
+    log_writer_train = SummaryWriter('logs/train/')
+    log_writer_val = SummaryWriter('logs/val/')
+
+    # Setup the model and move it to GPU
+    net = FC_Autoencoder(args.k, args.n_channel)
+    net = net.to(device)
+
+    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)  # optimize all network parameters
+    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.01)  # decay LR by a factor of 0.01 every 7 epochs
+    loss_func = nn.CrossEntropyLoss()  # CrossEntropyLoss expects class-index targets, not one-hot vectors
+    patience = 10  # early-stopping patience: how long to wait after the last improvement in validation loss
+    early_stopping = EarlyStopping(patience=patience, verbose=True)  # initialize the early_stopping object
+    loss_vec = []
+
     start = time.time()
-    for epoch in range(epochs):
-        train_epoch_loss, train_epoch_acc = train(trainloader, net, optimizer, loss_func, device, loss_vec, batch_size, EbN0_dB_train, args)
-        val_loss, val_accuracy = validate(net, valloader, loss_func, val_set_size, device, EbN0_dB_train, args)
+    for epoch in range(args.epochs):
+        train_epoch_loss, train_epoch_acc = train(trainloader, net, optimizer, loss_func, device, loss_vec, args.batch_size, args.EbN0_dB_train, args)
+        val_loss, val_accuracy = validate(net, valloader, loss_func, val_set_size, device, args.EbN0_dB_train, args)
         print('Epoch: ', epoch + 1, '| train loss: %.4f' % train_epoch_loss, '| train acc: %.4f' % (train_epoch_acc * 100), '%', '| val loss: %.4f' % val_loss, '| val acc: %.4f' % (val_accuracy * 100), '%')
         log_writer_train.add_scalar('Train/Loss', train_epoch_loss, epoch)
         log_writer_train.add_scalar('Train/Accuracy', train_epoch_acc, epoch)
@@ -36,7 +52,11 @@ def awgn_train(trainloader, valloader, val_set_size, device, args):
     torch.save(net.state_dict(), 'trained_net_74AE.ckpt')  # Save trained net
     generate_encoded_sym_dict(args.n_channel, args.k, net, device)  # Generate encoded symbols

+    return net
+
-def awgn_test(testloader, net, device, EbN0_test, test_BLER, args):
+def awgn_test(testloader, net, device, args):
+    EbN0_test = torch.arange(args.EbN0dB_test_start, args.EbN0dB_test_end, args.EbN0dB_precision)  # Test parameters
+    test_BLER = torch.zeros((len(EbN0_test), 1))
     for p in range(len(EbN0_test)):
         test_BLER[p] = test(net, args, testloader, device, EbN0_test[p])
         print('Eb/N0:', EbN0_test[p].numpy(), '| test BLER: %.4f' % test_BLER[p])
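tools.EarlyStopping itself is not part of this diff. For readers, here is a minimal sketch of the pattern such a helper typically implements, assuming it tracks validation loss and writes its best weights to checkpoint.pt (plausible, since this commit also modifies that binary); all names below are assumptions, not the repository's actual API:

import torch

class EarlyStopping:
    # Hypothetical stand-in for tools.EarlyStopping; not the repo's implementation.
    def __init__(self, patience=10, verbose=False):
        self.patience = patience          # epochs to wait after the last improvement
        self.verbose = verbose
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, net):
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(net.state_dict(), 'checkpoint.pt')  # keep the best model so far
            if self.verbose:
                print('Validation loss improved; checkpoint saved.')
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True    # caller breaks out of the epoch loop

In awgn_train above, such an object would be invoked once per epoch with the current validation loss, and training would stop once early_stop is set; the call presumably sits in the lines the diff viewer has collapsed.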
33 changes: 16 additions & 17 deletions channels.py
@@ -37,20 +37,19 @@ def bgin(x, EbN0_dB_1, EbN0_dB_2, R, prob, device):
     x += noise
     return x

-# def interference( x, noise_shape, amp, omega, phase, type):
-#     interference = torch.zeros(np.shape(noise_shape))
-#     indices = torch.transpose(np.tile(np.arange(np.size(noise_shape, 0)), (np.size(noise_shape, 1), 1)))
-#     if type == 'sin':
-#         for i in range(np.size(indices, 1)):
-#             interference[:, i] = amp * np.sin(omega * indices[:, i] + phase[:, i])
-#     elif type == 'bpsk':
-#         random_seq = np.random.randint(low=0, high=1, size=np.shape(noise_shape))
-#         constellation = amp * 2 * (random_seq - 0.5)  # bpsk
-#
-#         indices = np.transpose(np.tile(np.arange(np.size(noise_shape, 0)), (np.size(noise_shape, 1), 1)))
-#         interference = constellation * (np.exp(1j * omega * indices) + phase)
-#     else:
-#         print('Type not specified.')
-#     x += interference
-#     return x
-#
+def interference(x, noise_shape, amp, omega, phase, type):
+    # Build the interference in numpy and convert once at the end.
+    # np.transpose (not torch.transpose, which needs explicit dims) gives a
+    # (rows x cols) grid whose column i holds the time index.
+    interference = np.zeros(np.shape(noise_shape))
+    indices = np.transpose(np.tile(np.arange(np.size(noise_shape, 0)), (np.size(noise_shape, 1), 1)))
+    if type == 'sin':
+        for i in range(np.size(indices, 1)):
+            interference[:, i] = amp * np.sin(omega * indices[:, i] + phase[:, i])
+    elif type == 'bpsk':
+        random_seq = np.random.randint(low=0, high=2, size=np.shape(noise_shape))  # high=2 so both bits {0, 1} occur (high is exclusive)
+        constellation = amp * 2 * (random_seq - 0.5)  # map bits {0, 1} to BPSK symbols {-amp, +amp}
+        interference = constellation * (np.exp(1j * omega * indices) + phase)  # complex-valued: assumes x is complex baseband here
+    else:
+        print('Type not specified.')
+    x += torch.as_tensor(interference, dtype=x.dtype)  # cast so the in-place add avoids a dtype-promotion error
+    return x
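For reference, an illustrative call of the restored interference() helper; all shapes and parameter values below are assumptions made for this sketch, not values taken from the repository:

import numpy as np
import torch
from channels import interference

x = torch.zeros(100, 7)                   # e.g. a batch of 100 codewords with n_channel = 7
noise_shape = np.zeros((100, 7))          # only its shape is used, to size the interference grid
phase = np.random.uniform(0.0, 2 * np.pi, size=(100, 7))
x_out = interference(x, noise_shape, amp=0.1, omega=0.5, phase=phase, type='sin')
print(x_out.shape)                        # torch.Size([100, 7])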

Binary file modified checkpoint.pt
16 changes: 9 additions & 7 deletions get_args.py
@@ -4,7 +4,7 @@

 def get_args():
     ################################
-    # Setup Parameters and get args
+    # Setup all user parameters and get args
     ################################
     parser = argparse.ArgumentParser()
     parser.add_argument('-n_channel', type=int, default=7)  # n is used as a layer dimension, so it must be an int (was type=float)
@@ -13,12 +13,14 @@ def get_args():
     parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
     parser.add_argument('-modulation', choices=['BPSK', 'QPSK', '8PSK', '16QAM', '64QAM', '256QAM'], default='BPSK')
     parser.add_argument('-coding', choices=['SingleParity_4_3', 'Hamming_7_4', 'Hamming_15_11', 'Polar_16_4', 'EGolay_24_12'], default='Hamming_7_4')
-    # parser.add_argument('-dropout',type=float, default=0.0)
-    # parser.add_argument('-EbN0dB_test_start', type=float, default=-0.0)
-    # parser.add_argument('-EbN0dB_test_end', type=float, default=11.0)
-    # parser.add_argument('-EbN0dB_points', type=int, default=23)
-    # parser.add_argument('-batch_size', type=int, default=100)
-    # parser.add_argument('-num_epoch', type=int, default=1)
+    parser.add_argument('-learning_rate', type=float, default=0.001)
+    parser.add_argument('-batch_size', type=int, default=64)
+    parser.add_argument('-dropout', type=float, default=0.0)
+    parser.add_argument('-EbN0_dB_train', type=float, default=3.0)
+    parser.add_argument('-EbN0dB_test_start', type=float, default=0.0)
+    parser.add_argument('-EbN0dB_test_end', type=float, default=11.5)
+    parser.add_argument('-EbN0dB_precision', type=float, default=0.5)  # step size for torch.arange, so float, not int (was type=int)
+    parser.add_argument('-epochs', type=int, default=10)
     args = parser.parse_args()

     return args
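These arguments feed awgn_test above. The EbN0dB_precision type fix matters because argparse applies `type` only to command-line strings, so the 0.5 default happened to survive type=int, but an explicit '-EbN0dB_precision 0.5' on the command line would have crashed on int('0.5'). A quick sketch of how the defaults are consumed:

import torch
from get_args import get_args

args = get_args()  # defaults: test grid from 0.0 to 11.5 dB in 0.5 dB steps
EbN0_test = torch.arange(args.EbN0dB_test_start, args.EbN0dB_test_end, args.EbN0dB_precision)
print(len(EbN0_test))  # 23 points, matching the old commented-out EbN0dB_points default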
104 changes: 0 additions & 104 deletions main.py

This file was deleted.

38 changes: 6 additions & 32 deletions main1.py
@@ -9,7 +9,7 @@

 from models import FC_Autoencoder
 from tools import EarlyStopping
-from utils import generate_encoded_sym_dict, get_args
+from utils import generate_encoded_sym_dict
 from datasets import prepare_data
 from trainer import train, validate, test
+from get_args import get_args
@@ -18,16 +18,6 @@
 import numpy as np

 # User parameters
-EbN0_dB_train = 3.0
-epochs = 100  # train the training data 'epoch' times
-batch_size = 64
-learning_rate = 0.001  # learning rate
-EbN0_test = torch.arange(0, 11.5, 0.5)  # Test parameters
-
-patience = 10  # early stopping patience; how long to wait after last time validation loss improved
-early_stopping = EarlyStopping(patience=patience, verbose=True)  # initialize the early_stopping object
-
 # Set sizes
 train_set_size = 10 ** 5
 val_set_size = 10 ** 5
 test_set_size = 10 ** 5
@@ -36,38 +26,22 @@
 use_cuda = torch.cuda.is_available()
 device = torch.device("cuda" if use_cuda else "cpu")

-# Define loggers
-log_writer_train = SummaryWriter('logs/train/')
-log_writer_val = SummaryWriter('logs/val/')
-

 def run():
     torch.backends.cudnn.benchmark = True  # Make sure torch is accessing CuDNN libraries
     args = get_args()  # Get arguments - go with default (Hamming (7,4) BPSK) if not provided

     R = args.k / args.n_channel
     class_num = 2 ** args.k  # (n=7,k=4) m=16

-    # Setup the model and move it to GPU
-    net = FC_Autoencoder(args.k, args.n_channel)
-    net = net.to(device)
-
     # Prepare data
-    traindataset, trainloader, train_labels = prepare_data(train_set_size, class_num, batch_size)
+    traindataset, trainloader, train_labels = prepare_data(train_set_size, class_num, args.batch_size)
     valdataset, valloader, val_labels = prepare_data(val_set_size, class_num, val_set_size)  # Validation data
     testdataset, testloader, test_labels = prepare_data(test_set_size, class_num, test_set_size)
-    test_BLER = torch.zeros((len(EbN0_test), 1))
-
-    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)  # optimize all network parameters
-    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.01)  # Decay LR by a factor of 0.1 every 7 epochs
-    loss_func = nn.CrossEntropyLoss()  # the target label is not one-hotted
-    loss_vec = []
-
     # Training
-    awgn_train(trainloader, valloader, val_set_size, epochs, net,
-               optimizer, early_stopping, loss_func, device, loss_vec,
-               batch_size, EbN0_dB_train, args, log_writer_train, log_writer_val)
+    trained_net = awgn_train(trainloader, valloader, val_set_size, device, args)

     # TESTING
-    awgn_test(testloader, net, device, EbN0_test, test_BLER, args)
+    awgn_test(testloader, trained_net, device, args)

 if __name__ == '__main__':
     run()
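With the hyperparameters now routed through get_args.py, a typical run looks like this (flag names taken from the parser above; the values shown are just the defaults):

python main1.py -epochs 10 -batch_size 64 -learning_rate 0.001 -EbN0_dB_train 3.0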
Binary file added mfbanks/ae_mfbank_AWGN_BPSK_74.mat
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion models.py
@@ -7,7 +7,7 @@
 from tensorboardX import SummaryWriter
 from torch.optim import lr_scheduler
 from tools import EarlyStopping
-from utils import generate_encoded_sym_dict, get_args
+from utils import generate_encoded_sym_dict
 from datasets import prepare_data
 from trainer import train, validate, test
 import numpy as np
1 change: 0 additions & 1 deletion next_steps.md

This file was deleted.

Binary file modified trained_net_74AE.ckpt
48 changes: 20 additions & 28 deletions utils.py
@@ -5,14 +5,6 @@
 import pandas as pd
 import matplotlib.pyplot as plt

-def get_args():
-    parser = argparse.ArgumentParser()
-
-    args = parser.parse_args()
-    return args
-
-

 def d2b(d, n):
     d = np.array(d)
     d = np.reshape(d, (1, -1))
@@ -32,26 +24,26 @@ def generate_encoded_sym_dict(n_channel,k,net, device):
     S_encoded_syms = (enc_output.cpu()).detach().numpy()

     dict1 = {'S_encoded_syms': S_encoded_syms, 'bit_dict': bit_dict.astype(np.int8)}
-    savemat('ae_mfbank_AWGN_BPSK_' + str(n_channel) + str(k) + '.mat', dict1)
+    savemat('mfbanks/ae_mfbank_AWGN_BPSK_' + str(n_channel) + str(k) + '.mat', dict1)
     print('Generated dictionaries and encoded symbols')


-def get_plots():
-    # Plot 1 -
-    plt.plot(train_acc_store, 'r-o')
-    plt.plot(test_acc_store, 'b-o')
-    plt.xlabel('number of epochs')
-    plt.ylabel('accuracy')
-    plt.ylim(0.85, 1)
-    plt.legend(('training', 'validation'), loc='upper left')
-    plt.title('train and test accuracy w.r.t epochs')
-    plt.show()
-
-    # Plot 2 -
-    plt.plot(train_loss_store, 'r-o')
-    plt.plot(test_loss_store, 'b-o')
-    plt.xlabel('number of epochs')
-    plt.ylabel('loss')
-    plt.legend(('training', 'validation'), loc='upper right')
-    plt.title('train and test loss w.r.t epochs')
-    plt.show()
+# def get_plots():
+#     # Plot 1 -
+#     plt.plot(train_acc_store, 'r-o')
+#     plt.plot(test_acc_store, 'b-o')
+#     plt.xlabel('number of epochs')
+#     plt.ylabel('accuracy')
+#     plt.ylim(0.85, 1)
+#     plt.legend(('training', 'validation'), loc='upper left')
+#     plt.title('train and test accuracy w.r.t epochs')
+#     plt.show()
+#
+#     # Plot 2 -
+#     plt.plot(train_loss_store, 'r-o')
+#     plt.plot(test_loss_store, 'b-o')
+#     plt.xlabel('number of epochs')
+#     plt.ylabel('loss')
+#     plt.legend(('training', 'validation'), loc='upper right')
+#     plt.title('train and test loss w.r.t epochs')
+#     plt.show()
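Since the matched-filter bank now lives under mfbanks/, downstream consumers must load it from the new path. A minimal sketch using scipy, assuming only what the savemat call above guarantees (the file path and key names); the array semantics in the comments are inferred:

from scipy.io import loadmat

# Reload the matched-filter bank written by generate_encoded_sym_dict (n_channel=7, k=4).
mat = loadmat('mfbanks/ae_mfbank_AWGN_BPSK_74.mat')
S_encoded_syms = mat['S_encoded_syms']  # learned encoder outputs, one row per message (inferred)
bit_dict = mat['bit_dict']              # int8 bit patterns for each of the 2**k messages (inferred)
print(S_encoded_syms.shape, bit_dict.shape)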
