diff --git a/data/__init__.py b/data/__init__.py index 56fe212..a86a981 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -59,7 +59,8 @@ def create_dataset(opt, rank=0): dataset = data_loader.load_data() return dataset -class CustomDatasetDataLoader(): + +class CustomDatasetDataLoader: """Wrapper class of Dataset class that performs multi-threaded data loading""" def __init__(self, opt, rank=0): diff --git a/data/base_dataset.py b/data/base_dataset.py index 1275d60..9fd296f 100644 --- a/data/base_dataset.py +++ b/data/base_dataset.py @@ -74,6 +74,7 @@ def get_transform(grayscale=False): transform_list += [transforms.ToTensor()] return transforms.Compose(transform_list) + def get_affine_mat(opt, size): shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False w, h = size @@ -101,9 +102,11 @@ def get_affine_mat(opt, size): affine_inv = np.linalg.inv(affine) return affine, affine_inv, flip + def apply_img_affine(img, affine_inv, method=RESAMPLING_METHOD): return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=method) + def apply_lm_affine(landmark, affine, flip, size): _, h = size lm = landmark.copy() diff --git a/data/flist_dataset.py b/data/flist_dataset.py index c0b6945..c676135 100644 --- a/data/flist_dataset.py +++ b/data/flist_dataset.py @@ -28,11 +28,13 @@ def default_flist_reader(flist): return imlist + def jason_flist_reader(flist): with open(flist, 'r') as fp: info = json.load(fp) return info + def parse_label(label): return torch.tensor(np.array(label).astype(np.float32)) @@ -62,7 +64,6 @@ def __init__(self, opt): self.name = 'train' if opt.isTrain else 'val' if '_' in opt.flist: self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0] - def __getitem__(self, index): """Return a data point and its metadata information. @@ -99,7 +100,6 @@ def __getitem__(self, index): lm_tensor = parse_label(lm) M_tensor = parse_label(M) - return {'imgs': img_tensor, 'lms': lm_tensor, 'msks': msk_tensor, @@ -115,9 +115,6 @@ def _augmentation(self, img, lm, opt, msk=None): if msk is not None: msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR) return img, lm, msk - - - def __len__(self): """Return the total number of images in the dataset.
diff --git a/data_preparation.py b/data_preparation.py index 6ffc79d..b127d9e 100644 --- a/data_preparation.py +++ b/data_preparation.py @@ -1,14 +1,14 @@ """This script is the data preparation script for Deep3DFaceRecon_pytorch """ -import os +import os import numpy as np import argparse -from util.detect_lm68 import detect_68p,load_lm_graph +from util.detect_lm68 import detect_68p, load_lm_graph from util.skin_mask import get_skin_mask from util.generate_list import check_list, write_list import warnings -warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore") parser = argparse.ArgumentParser() parser.add_argument('--data_root', type=str, default='datasets', help='root directory for training data') @@ -18,28 +18,30 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0' -def data_prepare(folder_list,mode): - lm_sess,input_op,output_op = load_lm_graph('./checkpoints/lm_model/68lm_detector.pb') # load a tensorflow version 68-landmark detector +def data_prepare(folder_list, mode): + + lm_sess, input_op, output_op = load_lm_graph('./checkpoints/lm_model/68lm_detector.pb') # load a tensorflow version 68-landmark detector for img_folder in folder_list: - detect_68p(img_folder,lm_sess,input_op,output_op) # detect landmarks for images - get_skin_mask(img_folder) # generate skin attention mask for images + detect_68p(img_folder, lm_sess, input_op, output_op) # detect landmarks for images + get_skin_mask(img_folder) # generate skin attention mask for images # create files that record path to all training data msks_list = [] for img_folder in folder_list: path = os.path.join(img_folder, 'mask') - msks_list += ['/'.join([img_folder, 'mask', i]) for i in sorted(os.listdir(path)) if 'jpg' in i or + msks_list += ['/'.join([img_folder, 'mask', i]) for i in sorted(os.listdir(path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] imgs_list = [i.replace('mask/', '') for i in msks_list] lms_list = [i.replace('mask', 'landmarks') for i in msks_list] lms_list = ['.'.join(i.split('.')[:-1]) + '.txt' for i in lms_list] - lms_list_final, imgs_list_final, msks_list_final = check_list(lms_list, imgs_list, msks_list) # check if the path is valid - write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode) # save files + lms_list_final, imgs_list_final, msks_list_final = check_list(lms_list, imgs_list, msks_list) # check if the path is valid + write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode) # save files + if __name__ == '__main__': - print('Datasets:',opt.img_folder) - data_prepare([os.path.join(opt.data_root,folder) for folder in opt.img_folder],opt.mode) + print('Datasets:', opt.img_folder) + data_prepare([os.path.join(opt.data_root, folder) for folder in opt.img_folder], opt.mode) diff --git a/models/base_model.py b/models/base_model.py index 2a05d3a..4915338 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -100,8 +100,7 @@ def setup(self, opt): if not self.isTrain or opt.continue_train: load_suffix = opt.epoch self.load_networks(load_suffix) - - + # self.print_networks(opt.verbose) def parallelize(self, convert_sync_batchnorm=True): @@ -117,8 +116,8 @@ def parallelize(self, convert_sync_batchnorm=True): if convert_sync_batchnorm: module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module) setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device), - device_ids=[self.device.index], - find_unused_parameters=True, broadcast_buffers=True)) + device_ids=[self.device.index], + find_unused_parameters=True, 
broadcast_buffers=True)) # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient. for name in self.parallel_names: @@ -168,7 +167,7 @@ def compute_visuals(self): def get_image_paths(self, name='A'): """ Return image paths that are used to load current data""" - return self.image_paths if name =='A' else self.image_paths_B + return self.image_paths if name == 'A' else self.image_paths_B def update_learning_rate(self): """Update learning rates for all the networks; called at the end of every epoch""" @@ -213,17 +212,16 @@ def save_networks(self, epoch): for name in self.model_names: if isinstance(name, str): net = getattr(self, name) - if isinstance(net, torch.nn.DataParallel) or isinstance(net, - torch.nn.parallel.DistributedDataParallel): + if (isinstance(net, torch.nn.DataParallel) + or isinstance(net, torch.nn.parallel.DistributedDataParallel)): net = net.module save_dict[name] = net.state_dict() - for i, optim in enumerate(self.optimizers): - save_dict['opt_%02d'%i] = optim.state_dict() + save_dict['opt_%02d' % i] = optim.state_dict() for i, sched in enumerate(self.schedulers): - save_dict['sched_%02d'%i] = sched.state_dict() + save_dict['sched_%02d' % i] = sched.state_dict() torch.save(save_dict, save_path) @@ -267,19 +265,16 @@ def load_networks(self, epoch): if self.opt.continue_train: print('loading the optim from %s' % load_path) for i, optim in enumerate(self.optimizers): - optim.load_state_dict(state_dict['opt_%02d'%i]) + optim.load_state_dict(state_dict['opt_%02d' % i]) try: print('loading the sched from %s' % load_path) for i, sched in enumerate(self.schedulers): - sched.load_state_dict(state_dict['sched_%02d'%i]) + sched.load_state_dict(state_dict['sched_%02d' % i]) except: print('Failed to load schedulers, set schedulers according to epoch count manually') for i, sched in enumerate(self.schedulers): sched.last_epoch = self.opt.epoch_count - 1 - - - def print_networks(self, verbose): """Print the total number of parameters in the network and (if verbose) network architecture diff --git a/models/bfm.py b/models/bfm.py index 33a27e7..d581c40 100644 --- a/models/bfm.py +++ b/models/bfm.py @@ -2,12 +2,13 @@ """ import numpy as np -import torch +import torch import torch.nn.functional as F from scipy.io import loadmat from util.load_mats import transferBFM09 import os + def perspective_projection(focal, center): # return p.T (N, 3) @ (3, 3) return np.array([ @@ -16,25 +17,25 @@ def perspective_projection(focal, center): 0, 0, 1 ]).reshape([3, 3]).astype(np.float32).transpose() + class SH: def __init__(self): self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)] self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) 
/ np.sqrt(12 * np.pi)] - class ParametricFaceModel: def __init__(self, - bfm_folder='./BFM', - recenter=True, - camera_distance=10., - init_lit=np.array([ + bfm_folder='./BFM', + recenter=True, + camera_distance=10., + init_lit=np.array([ 0.8, 0, 0, 0, 0, 0, 0, 0, 0 ]), - focal=1015., - center=112., - is_train=True, - default_name='BFM_model_front.mat'): + focal=1015., + center=112., + is_train=True, + default_name='BFM_model_front.mat'): if not os.path.isfile(os.path.join(bfm_folder, default_name)): transferBFM09(bfm_folder) @@ -74,7 +75,6 @@ def __init__(self, self.camera_distance = camera_distance self.SH = SH() self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32) - def to(self, device): self.device = device @@ -82,7 +82,6 @@ def to(self, device): if type(value).__module__ == np.__name__: setattr(self, key, torch.tensor(value).to(device)) - def compute_shape(self, id_coeff, exp_coeff): """ Return: @@ -97,7 +96,6 @@ def compute_shape(self, id_coeff, exp_coeff): exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff) face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1]) return face_shape.reshape([batch_size, -1, 3]) - def compute_texture(self, tex_coeff, normalize=True): """ @@ -113,7 +111,6 @@ def compute_texture(self, tex_coeff, normalize=True): face_texture = face_texture / 255. return face_texture.reshape([batch_size, -1, 3]) - def compute_norm(self, face_shape): """ Return: @@ -136,7 +133,6 @@ def compute_norm(self, face_shape): vertex_norm = F.normalize(vertex_norm, dim=-1, p=2) return vertex_norm - def compute_color(self, face_texture, face_norm, gamma): """ Return: @@ -170,7 +166,6 @@ def compute_color(self, face_texture, face_norm, gamma): face_color = torch.cat([r, g, b], dim=-1) * face_texture return face_color - def compute_rotation(self, angles): """ Return: @@ -206,7 +201,6 @@ def compute_rotation(self, angles): rot = rot_z @ rot_y @ rot_x return rot.permute(0, 2, 1) - def to_camera(self, face_shape): face_shape[..., -1] = self.camera_distance - face_shape[..., -1] return face_shape @@ -225,7 +219,6 @@ def to_image(self, face_shape): return face_proj - def transform(self, face_shape, rot, trans): """ Return: @@ -238,7 +231,6 @@ def transform(self, face_shape, rot, trans): """ return face_shape @ rot + trans.unsqueeze(1) - def get_landmarks(self, face_proj): """ Return: @@ -271,6 +263,7 @@ def split_coeff(self, coeffs): 'gamma': gammas, 'trans': translations } + def compute_for_render(self, coeffs): """ Return: @@ -284,7 +277,6 @@ def compute_for_render(self, coeffs): face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) rotation = self.compute_rotation(coef_dict['angle']) - face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans']) face_vertex = self.to_camera(face_shape_transformed) diff --git a/models/facerecon_model.py b/models/facerecon_model.py index dfaaea9..1ebc269 100644 --- a/models/facerecon_model.py +++ b/models/facerecon_model.py @@ -7,13 +7,14 @@ from . 
import networks from .bfm import ParametricFaceModel from .losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss -from util import util +from util import util from util.nvdiffrast import MeshRenderer from util.preprocess import estimate_norm_torch import trimesh from scipy.io import savemat + class FaceReconModel(BaseModel): @staticmethod @@ -42,7 +43,6 @@ def modify_commandline_options(parser, is_train=True): parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss') parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face') - # augmentation parameters parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels') parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor') @@ -59,8 +59,6 @@ def modify_commandline_options(parser, is_train=True): parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss') parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss') - - opt, _ = parser.parse_known_args() parser.set_defaults( focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15. @@ -125,9 +123,9 @@ def set_input(self, input): Parameters: input: a dictionary that contains the data itself and its metadata information. """ - self.input_img = input['imgs'].to(self.device) + self.input_img = input['imgs'].to(self.device) self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None - self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None + self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None self.trans_m = input['M'].to(self.device) if 'M' in input else None self.image_paths = input['im_paths'] if 'im_paths' in input else None @@ -141,11 +139,10 @@ def forward(self): self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff) - def compute_losses(self): """Calculate losses, gradients, and update network weights; called in every training iteration""" - assert self.net_recog.training == False + assert not self.net_recog.training trans_m = self.trans_m if not self.opt.use_predef_M: trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2]) @@ -172,16 +169,15 @@ def compute_losses(self): self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \ + self.loss_lm + self.loss_reflc - def optimize_parameters(self, isTrain=True): - self.forward() + self.forward() self.compute_losses() """Update network weights; it will be called in every training iteration.""" if isTrain: - self.optimizer.zero_grad() - self.loss_all.backward() - self.optimizer.step() + self.optimizer.zero_grad() + self.loss_all.backward() + self.optimizer.step() def compute_visuals(self): with torch.no_grad(): @@ -195,10 +191,10 @@ def compute_visuals(self): output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b') output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r') - output_vis_numpy = np.concatenate((input_img_numpy, + output_vis_numpy = np.concatenate((input_img_numpy, output_vis_numpy_raw, output_vis_numpy), axis=-2) else: - output_vis_numpy = np.concatenate((input_img_numpy, + output_vis_numpy = np.concatenate((input_img_numpy, output_vis_numpy_raw), axis=-2) self.output_vis = torch.tensor( @@ -208,7 +204,7 @@ def compute_visuals(self): def save_mesh(self, name): recon_shape = self.pred_vertex # 
get reconstructed shape - recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space + recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space recon_shape = recon_shape.cpu().numpy()[0] recon_color = self.pred_color recon_color = recon_color.cpu().numpy()[0] @@ -216,13 +212,10 @@ mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8), process=False) mesh.export(name) - def save_coeff(self,name): + def save_coeff(self, name): - pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict} + pred_coeffs = {key: self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict} pred_lm = self.pred_lm.cpu().numpy() - pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate + pred_lm = np.stack([pred_lm[:, :, 0], self.input_img.shape[2]-1-pred_lm[:, :, 1]], axis=2) # convert to image coordinates + pred_coeffs['lm68'] = pred_lm - savemat(name,pred_coeffs) - - - + savemat(name, pred_coeffs) diff --git a/models/losses.py b/models/losses.py index fbacb63..f707919 100644 --- a/models/losses.py +++ b/models/losses.py @@ -4,18 +4,21 @@ from kornia.geometry import warp_affine import torch.nn.functional as F + def resize_n_crop(image, M, dsize=112): # image: (b, c, h, w) # M : (b, 2, 3) return warp_affine(image, M, dsize=(dsize, dsize)) -### perceptual level loss + +# perceptual level loss class PerceptualLoss(nn.Module): def __init__(self, recog_net, input_size=112): super(PerceptualLoss, self).__init__() self.recog_net = recog_net self.preprocess = lambda x: 2 * x - 1 - self.input_size=input_size + self.input_size = input_size + def forward(imageA, imageB, M): """ 1 - cosine distance @@ -36,12 +39,14 @@ def forward(imageA, imageB, M): # assert torch.sum((cosine_d > 1).float()) == 0 return torch.sum(1 - cosine_d) / cosine_d.shape[0] + def perceptual_loss(id_featureA, id_featureB): cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) - # assert torch.sum((cosine_d > 1).float()) == 0 + # assert torch.sum((cosine_d > 1).float()) == 0 return torch.sum(1 - cosine_d) / cosine_d.shape[0] -### image level loss + +# image level loss def photo_loss(imageA, imageB, mask, eps=1e-6): """ l2 norm (with sqrt, to ensure backward stability, use eps, otherwise NaN may occur) @@ -53,6 +58,7 @@ loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device)) return loss + def landmark_loss(predict_lm, gt_lm, weight=None): """ weighted mse loss @@ -72,7 +78,7 @@ return loss -### regulization +# regularization def reg_loss(coeffs_dict, opt=None): """ l2 norm without the sqrt, from yu's implementation (mse) @@ -98,6 +104,7 @@ return creg_loss, gamma_loss + def reflectance_loss(texture, mask): """ minimize texture variance (mse), albedo regularization to ensure a uniform skin albedo @@ -110,4 +117,3 @@ texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask) loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask)) return loss - diff --git a/models/networks.py b/models/networks.py index 40ce9f9..6328b57 100644 --- a/models/networks.py +++ b/models/networks.py @@ -18,11 +18,13 @@ from .arcface_torch.backbones import get_model from
kornia.geometry import warp_affine + def resize_n_crop(image, M, dsize=112): # image: (b, c, h, w) # M : (b, 2, 3) return warp_affine(image, M, dsize=(dsize, dsize)) + def filter_state_dict(state_dict, remove_name='fc'): new_state_dict = {} for key in state_dict: @@ -31,6 +33,7 @@ new_state_dict[key] = state_dict[key] return new_state_dict + def get_scheduler(optimizer, opt): """Return a learning rate scheduler @@ -61,34 +64,37 @@ def lambda_rule(epoch): def define_net_recon(net_recon, use_last_fc=False, init_path=None): return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path) + def define_net_recog(net_recog, pretrained_path=None): net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path) net.eval() return net + class ReconNetWrapper(nn.Module): - fc_dim=257 + fc_dim = 257 + def __init__(self, net_recon, use_last_fc=False, init_path=None): super(ReconNetWrapper, self).__init__() self.use_last_fc = use_last_fc if net_recon not in func_dict: - return NotImplementedError('network [%s] is not implemented', net_recon) + raise NotImplementedError('network [%s] is not implemented' % net_recon) func, last_dim = func_dict[net_recon] backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim) if init_path and os.path.isfile(init_path): state_dict = filter_state_dict(torch.load(init_path, map_location='cpu')) backbone.load_state_dict(state_dict) - print("loading init net_recon %s from %s" %(net_recon, init_path)) + print("loading init net_recon %s from %s" % (net_recon, init_path)) self.backbone = backbone if not use_last_fc: self.final_layers = nn.ModuleList([ - conv1x1(last_dim, 80, bias=True), # id layer - conv1x1(last_dim, 64, bias=True), # exp layer - conv1x1(last_dim, 80, bias=True), # tex layer - conv1x1(last_dim, 3, bias=True), # angle layer - conv1x1(last_dim, 27, bias=True), # gamma layer - conv1x1(last_dim, 2, bias=True), # tx, ty - conv1x1(last_dim, 1, bias=True) # tz + conv1x1(last_dim, 80, bias=True), # id layer + conv1x1(last_dim, 64, bias=True), # exp layer + conv1x1(last_dim, 80, bias=True), # tex layer + conv1x1(last_dim, 3, bias=True), # angle layer + conv1x1(last_dim, 27, bias=True), # gamma layer + conv1x1(last_dim, 2, bias=True), # tx, ty + conv1x1(last_dim, 1, bias=True) # tz ]) for m in self.final_layers: nn.init.constant_(m.weight, 0.) @@ -111,12 +117,12 @@ def __init__(self, net_recog, pretrained_path=None, input_size=112): if pretrained_path: state_dict = torch.load(pretrained_path, map_location='cpu') net.load_state_dict(state_dict) - print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path)) + print("loading pretrained net_recog %s from %s" % (net_recog, pretrained_path)) for param in net.parameters(): param.requires_grad = False self.net = net self.preprocess = lambda x: 2 * x - 1 - self.input_size=input_size + self.input_size = input_size def forward(self, image, M): image = self.preprocess(resize_n_crop(image, M, self.input_size)) @@ -316,8 +322,6 @@ nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) - - # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 diff --git a/options/base_options.py b/options/base_options.py index 67375d0..5172b34 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -129,7 +129,6 @@ def parse(self): suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' opt.name = opt.name + suffix - # set gpu ids str_ids = opt.gpu_ids.split(',') gpu_ids = [] @@ -152,7 +151,7 @@ def parse(self): if os.path.isdir(model_dir): model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')] if os.path.isdir(model_dir) and len(model_pths) != 0: - opt.continue_train= True + opt.continue_train = True # update the latest epoch count if opt.continue_train: @@ -162,7 +161,6 @@ def parse(self): opt.epoch_count = max(epoch_counts) + 1 else: opt.epoch_count = int(opt.epoch) + 1 - self.print_options(opt) self.opt = opt diff --git a/options/train_options.py b/options/train_options.py index 1337bfd..8b6c896 100644 --- a/options/train_options.py +++ b/options/train_options.py @@ -4,6 +4,7 @@ from .base_options import BaseOptions from util import util + class TrainOptions(BaseOptions): """This class includes training options. @@ -28,7 +29,6 @@ def initialize(self, parser): parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set') parser.add_argument('--batch_size_val', type=int, default=32) - # visualization parameters parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen') parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') diff --git a/test.py b/test.py index 13e1a7d..cd925b7 100644 --- a/test.py +++ b/test.py @@ -14,18 +14,20 @@ from data.flist_dataset import default_flist_reader from scipy.io import loadmat, savemat + def get_data_path(root='examples'): im_path = [os.path.join(root, i) for i in sorted(os.listdir(root)) if i.endswith('png') or i.endswith('jpg')] lm_path = [i.replace('png', 'txt').replace('jpg', 'txt') for i in im_path] - lm_path = [os.path.join(i.replace(i.split(os.path.sep)[-1],''),'detections',i.split(os.path.sep)[-1]) for i in lm_path] + lm_path = [os.path.join(i.replace(i.split(os.path.sep)[-1], ''), 'detections', i.split(os.path.sep)[-1]) for i in lm_path] return im_path, lm_path + def read_data(im_path, lm_path, lm3d_std, to_tensor=True): # to RGB im = Image.open(im_path).convert('RGB') - W,H = im.size + W, H = im.size lm = np.loadtxt(lm_path).astype(np.float32) lm = lm.reshape([-1, 2]) lm[:, -1] = H - 1 - lm[:, -1] @@ -35,6 +37,7 @@ def read_data(im_path, lm_path, lm3d_std, to_tensor=True): lm = torch.tensor(lm).unsqueeze(0) return im, lm + def main(rank, opt, name='examples'): device = torch.device(rank) torch.cuda.set_device(device) @@ -50,9 +53,9 @@ def main(rank, opt, name='examples'): for i in range(len(im_path)): print(i, im_path[i]) - img_name = im_path[i].split(os.path.sep)[-1].replace('.png','').replace('.jpg','') + img_name = im_path[i].split(os.path.sep)[-1].replace('.png', '').replace('.jpg', '') if not os.path.isfile(lm_path[i]): - print("%s is not found !!!"%lm_path[i]) + print("%s is not found !!!" 
% lm_path[i]) continue im_tensor, lm_tensor = read_data(im_path[i], lm_path[i], lm3d_std) data = { @@ -65,10 +68,10 @@ def main(rank, opt, name='examples'): visualizer.display_current_results(visuals, 0, opt.epoch, dataset=name.split(os.path.sep)[-1], save_results=True, count=i, name=img_name, add_image=False) - model.save_mesh(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d'%(opt.epoch, 0),img_name+'.obj')) # save reconstruction meshes - model.save_coeff(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d'%(opt.epoch, 0),img_name+'.mat')) # save predicted coefficients + model.save_mesh(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d' % (opt.epoch, 0), img_name+'.obj')) # save reconstruction meshes + model.save_coeff(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d' % (opt.epoch, 0), img_name+'.mat')) # save predicted coefficients + if __name__ == '__main__': opt = TestOptions().parse() # get test options - main(0, opt,opt.img_folder) - + main(0, opt, opt.img_folder) diff --git a/train.py b/train.py index 26e856f..e40fe6c 100644 --- a/train.py +++ b/train.py @@ -21,9 +21,11 @@ def setup(rank, world_size, port): # initialize the process group dist.init_process_group("gloo", rank=rank, world_size=world_size) + def cleanup(): dist.destroy_process_group() + def main(rank, world_size, train_opt): val_opt = genvalconf(train_opt, isTrain=False) @@ -45,7 +47,7 @@ def main(rank, world_size, train_opt): if rank == 0: print('The batch number of training images = %d\n, \ - the batch number of validation images = %d'% (train_dataset_batches, val_dataset_batches)) + the batch number of validation images = %d' % (train_dataset_batches, val_dataset_batches)) model.print_networks(train_opt.verbose) visualizer = MyVisualizer(train_opt) # create a visualizer that display/save images and plots @@ -123,7 +125,7 @@ def main(rank, world_size, train_opt): eval_time = time.time() - val_start_time if rank == 0: - visualizer.print_current_losses(epoch, epoch_iter, losses_avg, eval_time, t_data, dataset='val') # visualize training results + visualizer.print_current_losses(epoch, epoch_iter, losses_avg, eval_time, t_data, dataset='val') # visualize training results visualizer.plot_current_losses(total_iters, losses_avg, dataset='val') model.train() @@ -152,6 +154,7 @@ def main(rank, world_size, train_opt): if use_ddp: dist.barrier() + if __name__ == '__main__': import warnings diff --git a/util/detect_lm68.py b/util/detect_lm68.py index b7e4099..9dbf602 100644 --- a/util/detect_lm68.py +++ b/util/detect_lm68.py @@ -9,9 +9,11 @@ mean_face = np.loadtxt('util/test_mean_face.txt') mean_face = mean_face.reshape([68, 2]) + def save_label(labels, save_path): np.savetxt(save_path, labels) + def draw_landmarks(img, landmark, save_name): landmark = landmark lm_img = np.zeros([img.shape[0], img.shape[1], 3]) @@ -21,10 +23,8 @@ def draw_landmarks(img, landmark, save_name): for i in range(len(landmark)): for j in range(-1, 1): for k in range(-1, 1): - if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \ - img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \ - landmark[i, 0]+k > 0 and \ - landmark[i, 0]+k < img.shape[1]: + if (0 < img.shape[0] - 1 - landmark[i, 1] + j < img.shape[0] + and 0 < landmark[i, 0] + k < img.shape[1]): lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k, :] = np.array([0, 0, 255]) lm_img = lm_img.astype(np.uint8) @@ -35,6 +35,7 @@ def draw_landmarks(img, landmark, save_name): def 
load_data(img_name, txt_name): return cv2.imread(img_name), np.loadtxt(txt_name) + # create tensorflow graph for landmark detector def load_lm_graph(graph_filename): with tf.gfile.GFile(graph_filename, 'rb') as f: @@ -47,10 +48,11 @@ def load_lm_graph(graph_filename): output_lm = graph.get_tensor_by_name('net/lm:0') lm_sess = tf.Session(graph=graph) - return lm_sess,img_224,output_lm + return lm_sess, img_224, output_lm + # landmark detection -def detect_68p(img_path,sess,input_op,output_op): +def detect_68p(img_path, sess, input_op, output_op): print('detecting landmarks......') names = [i for i in sorted(os.listdir( img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] @@ -66,10 +68,10 @@ def detect_68p(img_path,sess,input_op,output_op): for i in range(0, len(names)): name = names[i] - print('%05d' % (i), ' ', name) + print('%05d' % i, ' ', name) full_image_name = os.path.join(img_path, name) txt_name = '.'.join(name.split('.')[:-1]) + '.txt' - full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image + full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image # if an image does not have detected 5 facial landmarks, remove it from the training list if not os.path.isfile(full_txt_name): @@ -78,7 +80,7 @@ def detect_68p(img_path,sess,input_op,output_op): # load data img, five_points = load_data(full_image_name, full_txt_name) - input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection + input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection # if the alignment fails, remove corresponding image from the training list if scale == 0: diff --git a/util/generate_list.py b/util/generate_list.py index 943d906..1086a3d 100644 --- a/util/generate_list.py +++ b/util/generate_list.py @@ -3,8 +3,9 @@ import os + # save path to training data -def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''): +def write_list(lms_list, imgs_list, msks_list, mode='train', save_folder='datalist', save_name=''): save_path = os.path.join(save_folder, mode) if not os.path.isdir(save_path): os.makedirs(save_path) @@ -17,6 +18,7 @@ def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalis with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd: fd.writelines([i + '\n' for i in msks_list]) + # check if the path is valid def check_list(rlms_list, rimgs_list, rmsks_list): lms_list, imgs_list, msks_list = [], [], [] diff --git a/util/load_mats.py b/util/load_mats.py index 5b1f4a7..d22c690 100644 --- a/util/load_mats.py +++ b/util/load_mats.py @@ -7,6 +7,7 @@ from array import array import os.path as osp + # load expression basis def LoadExpBasis(bfm_folder='BFM'): n_vertex = 53215 @@ -114,4 +115,3 @@ def load_lm3d(bfm_folder): Lm3D = Lm3D[[1, 2, 0, 3, 4], :] return Lm3D - diff --git a/util/nvdiffrast.py b/util/nvdiffrast.py index f634d66..37e2896 100644 --- a/util/nvdiffrast.py +++ b/util/nvdiffrast.py @@ -12,19 +12,21 @@ from scipy.io import loadmat from torch import nn + def ndc_projection(x=0.1, n=1.0, f=50.0): return np.array([[n/x, 0, 0, 0], [ 0, n/-x, 0, 0], [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], [ 0, 0, -1, 0]]).astype(np.float32) + class MeshRenderer(nn.Module): def __init__(self, - rasterize_fov, - znear=0.1, - zfar=10, - rasterize_size=224, - use_opengl=True): + rasterize_fov, + znear=0.1, + zfar=10, + rasterize_size=224, + use_opengl=True): super(MeshRenderer, 
self).__init__() x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear @@ -54,7 +56,6 @@ def forward(self, vertex, tri, feat=None): vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) vertex[..., 1] = -vertex[..., 1] - vertex_ndc = vertex @ ndc_proj.t() if self.ctx is None: if self.use_opengl: @@ -63,7 +64,7 @@ def forward(self, vertex, tri, feat=None): else: self.ctx = dr.RasterizeCudaContext(device=device) ctx_str = "cuda" - print("create %s ctx on device cuda:%d"%(ctx_str, device.index)) + print("create %s ctx on device cuda:%d" % (ctx_str, device.index)) ranges = None if isinstance(tri, List) or len(tri.shape) == 3: @@ -80,11 +81,10 @@ def forward(self, vertex, tri, feat=None): tri = tri.type(torch.int32).contiguous() rast_out, _ = dr.rasterize(self.ctx, vertex_ndc.contiguous(), tri, resolution=[rsize, rsize], ranges=ranges) - depth, _ = dr.interpolate(vertex.reshape([-1,4])[...,2].unsqueeze(1).contiguous(), rast_out, tri) + depth, _ = dr.interpolate(vertex.reshape([-1, 4])[..., 2].unsqueeze(1).contiguous(), rast_out, tri) depth = depth.permute(0, 3, 1, 2) mask = (rast_out[..., 3] > 0).float().unsqueeze(1) depth = mask * depth - image = None if feat is not None: @@ -93,4 +93,3 @@ def forward(self, vertex, tri, feat=None): image = mask * image return mask, depth, image - diff --git a/util/preprocess.py b/util/preprocess.py index c516f45..d07eecf 100644 --- a/util/preprocess.py +++ b/util/preprocess.py @@ -45,6 +45,7 @@ def POS(xp, x): return t, s + # bounding box for 68 landmark detection def BBRegression(points, params): @@ -70,10 +71,11 @@ def BBRegression(points, params): inputs = np.transpose(inputs) x = inputs[:, 0] * rms + x_mean y = inputs[:, 1] * rms + y_mean - w = 224/inputs[:, 2] * rms + w = 224 / inputs[:, 2] * rms rects = [x, y, w, w] return np.array(rects).reshape([4]) + # utils for landmark detection def img_padding(img, box): success = True @@ -88,19 +90,21 @@ def img_padding(img, box): success = False return res, bbox, success + # utils for landmark detection def crop(img, bbox): padded_img, padded_bbox, flag = img_padding(img, bbox) if flag: crop_img = padded_img[padded_bbox[1]: padded_bbox[1] + - padded_bbox[3], padded_bbox[0]: padded_bbox[0] + padded_bbox[2]] + padded_bbox[3], padded_bbox[0]: padded_bbox[0] + padded_bbox[2]] crop_img = cv2.resize(crop_img.astype(np.uint8), - (224, 224), interpolation=cv2.INTER_CUBIC) + (224, 224), interpolation=cv2.INTER_CUBIC) scale = 224 / padded_bbox[3] return crop_img, scale else: return padded_img, 0 + # utils for landmark detection def scale_trans(img, lm, t, s): imgw = img.shape[1] @@ -118,7 +122,7 @@ def scale_trans(img, lm, t, s): up = h//2 - 112 bbox = [left, up, 224, 224] cropped_img, scale2 = crop(img, bbox) - assert(scale2!=0) + assert(scale2 != 0) t1 = np.array([bbox[0], bbox[1]]) # back to raw img s * crop + s * t1 + t2 @@ -128,6 +132,7 @@ def scale_trans(img, lm, t, s): inv = (scale/scale2, scale * t1 + t2.reshape([2])) return cropped_img, inv + # utils for landmark detection def align_for_lm(img, five_points): five_points = np.array(five_points).reshape([1, 10]) @@ -163,6 +168,7 @@ def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None): return img, lm, mask + # utils for face reconstruction def extract_5p(lm): lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 @@ -171,6 +177,7 @@ def extract_5p(lm): lm5p = lm5p[[1, 2, 0, 3, 4], :] return lm5p + # utils for face reconstruction def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.): """ @@ -203,6 +210,7 
@@ def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.): return trans_params, img_new, lm_new, mask_new + # utils for face recognition model def estimate_norm(lm_68p, H): # from https://github.com/deepinsight/insightface/blob/c61d3cd208a603dfa4a338bd743b320ce3e94730/recognition/common/face_align.py#L68 @@ -227,6 +235,7 @@ return M[0:2, :] + def estimate_norm_torch(lm_68p, H): lm_68p_ = lm_68p.detach().cpu().numpy() M = [] diff --git a/util/skin_mask.py b/util/skin_mask.py index a8a74e4..b31f9f0 100644 --- a/util/skin_mask.py +++ b/util/skin_mask.py @@ -6,15 +6,16 @@ import os import cv2 + class GMM: def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv): - self.dim = dim # feature dimension - self.num = num # number of Gaussian components - self.w = w # weights of Gaussian components (a list of scalars) - self.mu= mu # mean of Gaussian components (a list of 1xdim vectors) - self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices) - self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars) - self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices) + self.dim = dim # feature dimension + self.num = num # number of Gaussian components + self.w = w # weights of Gaussian components (a list of scalars) + self.mu = mu # mean of Gaussian components (a list of 1xdim vectors) + self.cov = cov # covariance matrices of Gaussian components (a list of dimxdim matrices) + self.cov_det = cov_det # pre-computed determinant of covariance matrices (a list of scalars) + self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices) self.factor = [0]*num for i in range(self.num): @@ -28,13 +29,13 @@ def likelihood(self, data): for i in range(self.num): data_ = data - self.mu[i] - tmp = np.matmul(data_,self.cov_inv[i]) * data_ - tmp = np.sum(tmp,axis=1) + tmp = np.matmul(data_, self.cov_inv[i]) * data_ + tmp = np.sum(tmp, axis=1) power = -0.5 * tmp p = np.array([math.exp(power[j]) for j in range(N)]) - p = p/self.factor[i] - lh += p*self.w[i] + p = p / self.factor[i] + lh += p * self.w[i] return lh @@ -58,27 +59,27 @@ def _bgr2ycbcr(bgr): gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415] gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]), - np.array([150.19858, 105.18467, 155.51428]), - np.array([183.92976, 107.62468, 152.71820]), - np.array([114.90524, 113.59782, 151.38217])] + np.array([150.19858, 105.18467, 155.51428]), + np.array([183.92976, 107.62468, 152.71820]), + np.array([114.90524, 113.59782, 151.38217])] gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.]
-gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]), - np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]), - np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]), - np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])] +gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998], [0.0020450759, 0.017700525, 0.0051420014], [-0.00060243998, 0.0051420014, 0.0081308950]]), + np.array([[0.0027110141, 0.0011036990, 0.0023122299], [0.0011036990, 0.010707724, 0.010742856], [0.0023122299, 0.010742856, 0.017481629]]), + np.array([[0.0048026871, 0.00022935172, 0.0077668377], [0.00022935172, 0.011729696, 0.0081661865], [0.0077668377, 0.0081661865, 0.025374353]]), + np.array([[0.0011989699, 0.0022453172, -0.0010748957], [0.0022453172, 0.047758564, 0.020332102], [-0.0010748957, 0.020332102, 0.024502251]])] gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv) gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393] gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]), - np.array([110.91392, 125.52969, 130.19237]), - np.array([129.75864, 129.96107, 126.96808]), - np.array([112.29587, 128.85121, 129.05431])] + np.array([110.91392, 125.52969, 130.19237]), + np.array([129.75864, 129.96107, 126.96808]), + np.array([112.29587, 128.85121, 129.05431])] gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63] -gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]), - np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]), - np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]), - np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])] +gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916], [0.00071197288, 0.0025935620, 0.00076557708], [0.00023958916, 0.00076557708, 0.0015042332]]), + np.array([[0.00024650150, 0.00045542428, 0.00015019422], [0.00045542428, 0.026412144, 0.018419769], [0.00015019422, 0.018419769, 0.037497383]]), + np.array([[0.00037054974, 0.00038146760, 0.00040408765], [0.00038146760, 0.0085505722, 0.0079136286], [0.00040408765, 0.0079136286, 0.010982352]]), + np.array([[0.00013709733, 0.00051228428, 0.00012777430], [0.00051228428, 0.28237113, 0.10528370], [0.00012777430, 0.10528370, 0.23468947]])] gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv) @@ -90,20 +91,20 @@ def _bgr2ycbcr(bgr): def skinmask(imbgr): im = _bgr2ycbcr(imbgr) - data = im.reshape((-1,3)) + data = im.reshape((-1, 3)) lh_skin = gmm_skin.likelihood(data) lh_nonskin = gmm_nonskin.likelihood(data) tmp1 = prior_skin * lh_skin tmp2 = prior_nonskin * lh_nonskin - post_skin = tmp1 / (tmp1+tmp2) # posterior probability + post_skin = tmp1 / (tmp1+tmp2) # posterior probability - post_skin = post_skin.reshape((im.shape[0],im.shape[1])) + post_skin = 
post_skin.reshape((im.shape[0], im.shape[1])) post_skin = np.round(post_skin*255) post_skin = post_skin.astype(np.uint8) - post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3 + post_skin = np.tile(np.expand_dims(post_skin, 2), [1, 1, 3]) # reshape to H*W*3 return post_skin @@ -118,7 +119,7 @@ def get_skin_mask(img_path): for i in range(0, len(names)): name = names[i] - print('%05d' % (i), ' ', name) + print('%05d' % i, ' ', name) full_image_name = os.path.join(img_path, name) img = cv2.imread(full_image_name).astype(np.float32) skin_img = skinmask(img) diff --git a/util/util.py b/util/util.py index 0db5ec9..e716437 100644 --- a/util/util.py +++ b/util/util.py @@ -34,6 +34,7 @@ def copyconf(default_opt, **kwargs): setattr(conf, key, kwargs[key]) return conf + def genvalconf(train_opt, **kwargs): conf = Namespace(**vars(train_opt)) attr_dict = train_opt.__dict__ @@ -45,7 +46,8 @@ setattr(conf, key, kwargs[key]) return conf - + + def find_class_in_module(target_cls_name, module): target_cls_name = target_cls_name.replace('_', '').lower() clslib = importlib.import_module(module) @@ -183,6 +185,7 @@ resized.append(resized_t) return torch.stack(resized, dim=0).to(device) + def draw_landmarks(img, landmark, color='r', step=2): """ Return: @@ -194,7 +197,7 @@ landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction color -- str, 'r' or 'b' (red or blue) """ - if color =='r': + if color == 'r': c = np.array([255., 0, 0]) else: c = np.array([0, 0, 255.]) diff --git a/util/visualizer.py b/util/visualizer.py index 4023a6d..c3dff2d 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -10,6 +10,7 @@ from subprocess import Popen, PIPE from torch.utils.tensorboard import SummaryWriter + def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): """Save images to the disk. @@ -41,7 +42,7 @@ webpage.add_images(ims, txts, links, width=width) -class Visualizer(): +class Visualizer: """This class includes several functions that can display/save images and print/save logging information. It uses a Python library tensorboardX for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. @@ -78,7 +79,6 @@ def reset(self): """Reset the self.saved status""" self.saved = False - def display_current_results(self, visuals, total_iters, epoch, save_result): """Display current results on tensorboard; save current results to an HTML file. @@ -170,9 +170,8 @@ def __init__(self, opt): now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) - def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None, - add_image=True): + add_image=True): """Display current results on tensorboard; save current results to an HTML file.
Parameters: @@ -187,11 +186,11 @@ def display_current_results(self, visuals, total_iters, epoch, dataset='train', for i in range(image.shape[0]): image_numpy = util.tensor2im(image[i]) if add_image: - self.writer.add_image(label + '%s_%02d'%(dataset, i + count), - image_numpy, total_iters, dataformats='HWC') + self.writer.add_image(label + '%s_%02d' % (dataset, i + count), + image_numpy, total_iters, dataformats='HWC') if save_results: - save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters)) + save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d' % (epoch, total_iters)) if not os.path.isdir(save_path): os.makedirs(save_path) @@ -201,10 +200,9 @@ def display_current_results(self, visuals, total_iters, epoch, dataset='train', img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count)) util.save_image(image_numpy, img_path) - def plot_current_losses(self, total_iters, losses, dataset='train'): for name, value in losses.items(): - self.writer.add_scalar(name + '/%s'%dataset, value, total_iters) + self.writer.add_scalar(name + '/%s' % dataset, value, total_iters) # losses: same format as |losses| of plot_current_losses def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'):