Skip to content
Open

pep8 #153

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def create_dataset(opt, rank=0):
dataset = data_loader.load_data()
return dataset

class CustomDatasetDataLoader():

class CustomDatasetDataLoader:
"""Wrapper class of Dataset class that performs multi-threaded data loading"""

def __init__(self, opt, rank=0):
Expand Down
3 changes: 3 additions & 0 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def get_transform(grayscale=False):
transform_list += [transforms.ToTensor()]
return transforms.Compose(transform_list)


def get_affine_mat(opt, size):
shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
w, h = size
Expand Down Expand Up @@ -101,9 +102,11 @@ def get_affine_mat(opt, size):
affine_inv = np.linalg.inv(affine)
return affine, affine_inv, flip


def apply_img_affine(img, affine_inv, method=RESAMPLING_METHOD):
return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=RESAMPLING_METHOD)


def apply_lm_affine(landmark, affine, flip, size):
_, h = size
lm = landmark.copy()
Expand Down
7 changes: 2 additions & 5 deletions data/flist_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ def default_flist_reader(flist):

return imlist


def jason_flist_reader(flist):
with open(flist, 'r') as fp:
info = json.load(fp)
return info


def parse_label(label):
return torch.tensor(np.array(label).astype(np.float32))

Expand Down Expand Up @@ -62,7 +64,6 @@ def __init__(self, opt):
self.name = 'train' if opt.isTrain else 'val'
if '_' in opt.flist:
self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]


def __getitem__(self, index):
"""Return a data point and its metadata information.
Expand Down Expand Up @@ -99,7 +100,6 @@ def __getitem__(self, index):
lm_tensor = parse_label(lm)
M_tensor = parse_label(M)


return {'imgs': img_tensor,
'lms': lm_tensor,
'msks': msk_tensor,
Expand All @@ -115,9 +115,6 @@ def _augmentation(self, img, lm, opt, msk=None):
if msk is not None:
msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
return img, lm, msk




def __len__(self):
"""Return the total number of images in the dataset.
Expand Down
26 changes: 14 additions & 12 deletions data_preparation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""This script is the data preparation script for Deep3DFaceRecon_pytorch
"""

import os
import os
import numpy as np
import argparse
from util.detect_lm68 import detect_68p,load_lm_graph
from util.detect_lm68 import detect_68p, load_lm_graph
from util.skin_mask import get_skin_mask
from util.generate_list import check_list, write_list
import warnings
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore")

parser = argparse.ArgumentParser()
parser.add_argument('--data_root', type=str, default='datasets', help='root directory for training data')
Expand All @@ -18,28 +18,30 @@

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def data_prepare(folder_list,mode):

lm_sess,input_op,output_op = load_lm_graph('./checkpoints/lm_model/68lm_detector.pb') # load a tensorflow version 68-landmark detector
def data_prepare(folder_list, mode):

lm_sess, input_op, output_op = load_lm_graph('./checkpoints/lm_model/68lm_detector.pb') # load a tensorflow version 68-landmark detector

for img_folder in folder_list:
detect_68p(img_folder,lm_sess,input_op,output_op) # detect landmarks for images
get_skin_mask(img_folder) # generate skin attention mask for images
detect_68p(img_folder, lm_sess, input_op, output_op) # detect landmarks for images
get_skin_mask(img_folder) # generate skin attention mask for images

# create files that record path to all training data
msks_list = []
for img_folder in folder_list:
path = os.path.join(img_folder, 'mask')
msks_list += ['/'.join([img_folder, 'mask', i]) for i in sorted(os.listdir(path)) if 'jpg' in i or
msks_list += ['/'.join([img_folder, 'mask', i]) for i in sorted(os.listdir(path)) if 'jpg' in i or
'png' in i or 'jpeg' in i or 'PNG' in i]

imgs_list = [i.replace('mask/', '') for i in msks_list]
lms_list = [i.replace('mask', 'landmarks') for i in msks_list]
lms_list = ['.'.join(i.split('.')[:-1]) + '.txt' for i in lms_list]

lms_list_final, imgs_list_final, msks_list_final = check_list(lms_list, imgs_list, msks_list) # check if the path is valid
write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode) # save files
lms_list_final, imgs_list_final, msks_list_final = check_list(lms_list, imgs_list, msks_list) # check if the path is valid
write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode) # save files


if __name__ == '__main__':
print('Datasets:',opt.img_folder)
data_prepare([os.path.join(opt.data_root,folder) for folder in opt.img_folder],opt.mode)
print('Datasets:', opt.img_folder)
data_prepare([os.path.join(opt.data_root, folder) for folder in opt.img_folder], opt.mode)
25 changes: 10 additions & 15 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def setup(self, opt):
if not self.isTrain or opt.continue_train:
load_suffix = opt.epoch
self.load_networks(load_suffix)



# self.print_networks(opt.verbose)

def parallelize(self, convert_sync_batchnorm=True):
Expand All @@ -117,8 +116,8 @@ def parallelize(self, convert_sync_batchnorm=True):
if convert_sync_batchnorm:
module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device),
device_ids=[self.device.index],
find_unused_parameters=True, broadcast_buffers=True))
device_ids=[self.device.index],
find_unused_parameters=True, broadcast_buffers=True))

# DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
for name in self.parallel_names:
Expand Down Expand Up @@ -168,7 +167,7 @@ def compute_visuals(self):

def get_image_paths(self, name='A'):
""" Return image paths that are used to load current data"""
return self.image_paths if name =='A' else self.image_paths_B
return self.image_paths if name == 'A' else self.image_paths_B

def update_learning_rate(self):
"""Update learning rates for all the networks; called at the end of every epoch"""
Expand Down Expand Up @@ -213,17 +212,16 @@ def save_networks(self, epoch):
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, name)
if isinstance(net, torch.nn.DataParallel) or isinstance(net,
torch.nn.parallel.DistributedDataParallel):
if (isinstance(net, torch.nn.DataParallel)
or isinstance(net, torch.nn.parallel.DistributedDataParallel)):
net = net.module
save_dict[name] = net.state_dict()


for i, optim in enumerate(self.optimizers):
save_dict['opt_%02d'%i] = optim.state_dict()
save_dict['opt_%02d' % i] = optim.state_dict()

for i, sched in enumerate(self.schedulers):
save_dict['sched_%02d'%i] = sched.state_dict()
save_dict['sched_%02d' % i] = sched.state_dict()

torch.save(save_dict, save_path)

Expand Down Expand Up @@ -267,19 +265,16 @@ def load_networks(self, epoch):
if self.opt.continue_train:
print('loading the optim from %s' % load_path)
for i, optim in enumerate(self.optimizers):
optim.load_state_dict(state_dict['opt_%02d'%i])
optim.load_state_dict(state_dict['opt_%02d' % i])

try:
print('loading the sched from %s' % load_path)
for i, sched in enumerate(self.schedulers):
sched.load_state_dict(state_dict['sched_%02d'%i])
sched.load_state_dict(state_dict['sched_%02d' % i])
except:
print('Failed to load schedulers, set schedulers according to epoch count manually')
for i, sched in enumerate(self.schedulers):
sched.last_epoch = self.opt.epoch_count - 1




def print_networks(self, verbose):
"""Print the total number of parameters in the network and (if verbose) network architecture
Expand Down
32 changes: 12 additions & 20 deletions models/bfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
"""

import numpy as np
import torch
import torch
import torch.nn.functional as F
from scipy.io import loadmat
from util.load_mats import transferBFM09
import os


def perspective_projection(focal, center):
# return p.T (N, 3) @ (3, 3)
return np.array([
Expand All @@ -16,25 +17,25 @@ def perspective_projection(focal, center):
0, 0, 1
]).reshape([3, 3]).astype(np.float32).transpose()


class SH:
def __init__(self):
self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]



class ParametricFaceModel:
def __init__(self,
bfm_folder='./BFM',
recenter=True,
camera_distance=10.,
init_lit=np.array([
bfm_folder='./BFM',
recenter=True,
camera_distance=10.,
init_lit=np.array([
0.8, 0, 0, 0, 0, 0, 0, 0, 0
]),
focal=1015.,
center=112.,
is_train=True,
default_name='BFM_model_front.mat'):
focal=1015.,
center=112.,
is_train=True,
default_name='BFM_model_front.mat'):

if not os.path.isfile(os.path.join(bfm_folder, default_name)):
transferBFM09(bfm_folder)
Expand Down Expand Up @@ -74,15 +75,13 @@ def __init__(self,
self.camera_distance = camera_distance
self.SH = SH()
self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)


def to(self, device):
self.device = device
for key, value in self.__dict__.items():
if type(value).__module__ == np.__name__:
setattr(self, key, torch.tensor(value).to(device))


def compute_shape(self, id_coeff, exp_coeff):
"""
Return:
Expand All @@ -97,7 +96,6 @@ def compute_shape(self, id_coeff, exp_coeff):
exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
return face_shape.reshape([batch_size, -1, 3])


def compute_texture(self, tex_coeff, normalize=True):
"""
Expand All @@ -113,7 +111,6 @@ def compute_texture(self, tex_coeff, normalize=True):
face_texture = face_texture / 255.
return face_texture.reshape([batch_size, -1, 3])


def compute_norm(self, face_shape):
"""
Return:
Expand All @@ -136,7 +133,6 @@ def compute_norm(self, face_shape):
vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
return vertex_norm


def compute_color(self, face_texture, face_norm, gamma):
"""
Return:
Expand Down Expand Up @@ -170,7 +166,6 @@ def compute_color(self, face_texture, face_norm, gamma):
face_color = torch.cat([r, g, b], dim=-1) * face_texture
return face_color


def compute_rotation(self, angles):
"""
Return:
Expand Down Expand Up @@ -206,7 +201,6 @@ def compute_rotation(self, angles):
rot = rot_z @ rot_y @ rot_x
return rot.permute(0, 2, 1)


def to_camera(self, face_shape):
face_shape[..., -1] = self.camera_distance - face_shape[..., -1]
return face_shape
Expand All @@ -225,7 +219,6 @@ def to_image(self, face_shape):

return face_proj


def transform(self, face_shape, rot, trans):
"""
Return:
Expand All @@ -238,7 +231,6 @@ def transform(self, face_shape, rot, trans):
"""
return face_shape @ rot + trans.unsqueeze(1)


def get_landmarks(self, face_proj):
"""
Return:
Expand Down Expand Up @@ -271,6 +263,7 @@ def split_coeff(self, coeffs):
'gamma': gammas,
'trans': translations
}

def compute_for_render(self, coeffs):
"""
Return:
Expand All @@ -284,7 +277,6 @@ def compute_for_render(self, coeffs):
face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
rotation = self.compute_rotation(coef_dict['angle'])


face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
face_vertex = self.to_camera(face_shape_transformed)

Expand Down
Loading