Demo code #124

Open
yaoliUoA opened this issue Apr 13, 2020 · 8 comments

Comments

@yaoliUoA

I think it would be better to have some demo code in this package to visualize the segmentation output from HRNet.
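
Even a minimal sketch along these lines would help (a hypothetical helper, not from this repo: it assumes the model output is a score tensor of shape (1, num_classes, H, W), and it borrows the PASCAL-VOC-style bit-shift palette to give each class index a distinct color):

import numpy as np
import torch
from PIL import Image

def save_colorized_mask(logits, out_path="pred.png"):
    # hypothetical helper: argmax the per-class scores and save a
    # palette-indexed PNG so each class gets its own color
    mask = logits.argmax(dim=1).squeeze(0).byte().cpu().numpy()  # (H, W) uint8
    palette = []
    for j in range(256):  # PASCAL-VOC-style bit-shift palette
        lab, r, g, b, i = j, 0, 0, 0, 0
        while lab:
            r |= ((lab >> 0) & 1) << (7 - i)
            g |= ((lab >> 1) & 1) << (7 - i)
            b |= ((lab >> 2) & 1) << (7 - i)
            i += 1
            lab >>= 3
        palette.extend([r, g, b])
    img = Image.fromarray(mask)  # grayscale "L" image
    img.putpalette(palette)      # converts it to a palettized "P" image
    img.save(out_path)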

@jaintarun
Copy link

Hi Yao,
Were you able to write some demo code that we can use to visualize the predictions?

I am trying to do the same, but the code is quite dense. I have made some progress, but it would help if you already have something working.

If you are still working on this, would you like to collaborate on creating some demo code, which we could then submit as a PR?

@umairanis03
Copy link

Hi @yaoliUoA @jaintarun, were you able to write the inference code?

@Linda-L
Copy link

Linda-L commented Dec 16, 2020

I think it would be better to have some demo code in this package to visualize the segmentation output from HRNet, too.

@MyHubTo
Copy link

MyHubTo commented Apr 23, 2021

I think it is necessary too.

@dreamlychina

from lib.config import config
from lib.config import update_config_demo
import lib.models.seg_hrnet as seg_models

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import cv2
from PIL import Image
import numpy as np
from torch.nn import functional as F

# ImageNet normalization statistics
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

ignore_label = 255  # Cityscapes convention for void classes


class FaceSeg:
    def __init__(self,
                 cfg_file='./experiments/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml',
                 weights='best.pt',
                 device='cpu',
                 imgsz=(700, 700),
                 num_classes=4):  # num_classes includes the background

        # update_config_demo is a local helper that loads cfg_file into config
        update_config_demo(config)
        if weights:
            # let the weights argument override the checkpoint path from the yaml
            config.defrost()
            config.TEST.MODEL_FILE = weights
            config.freeze()

        # cudnn related settings
        cudnn.benchmark = config.CUDNN.BENCHMARK
        cudnn.deterministic = config.CUDNN.DETERMINISTIC
        cudnn.enabled = config.CUDNN.ENABLED

        # build the model
        module = seg_models
        if torch.__version__.startswith('1'):
            module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
        model = module.get_seg_model(config)

        if config.TEST.MODEL_FILE:
            model_state_file = config.TEST.MODEL_FILE
        else:
            print("can't find model file:", config.TEST.MODEL_FILE)
            exit()

        # load the checkpoint; training prefixes every key with 'model.',
        # so strip the first six characters before matching
        pretrained_dict = torch.load(model_state_file, map_location='cpu')
        if 'state_dict' in pretrained_dict:
            pretrained_dict = pretrained_dict['state_dict']
        model_dict = model.state_dict()
        pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
                           if k[6:] in model_dict.keys()}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

        if device != 'cpu':
            gpus = list(config.GPUS)
            model = nn.DataParallel(model, device_ids=gpus).cuda()
        else:
            print("running segmentation on CPU")

        model.eval()
        self.model = model
        self.crop_size = imgsz
        self.num_classes = num_classes
        # standard Cityscapes id -> trainId mapping
        self.label_mapping = {-1: ignore_label, 0: ignore_label,
                              1: ignore_label, 2: ignore_label,
                              3: ignore_label, 4: ignore_label,
                              5: ignore_label, 6: ignore_label,
                              7: 0, 8: 1, 9: ignore_label,
                              10: ignore_label, 11: 2, 12: 3,
                              13: 4, 14: ignore_label, 15: ignore_label,
                              16: ignore_label, 17: 5, 18: ignore_label,
                              19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11,
                              25: 12, 26: 13, 27: 14, 28: 15,
                              29: ignore_label, 30: ignore_label,
                              31: 16, 32: 17, 33: 18}

    @torch.no_grad()
    def run(self, img0, visual=False):
        # BGR -> RGB, then normalize with the ImageNet statistics
        image_nor = img0.astype(np.float32)[:, :, ::-1]
        image_nor = image_nor / 255.0
        image_nor -= mean
        image_nor /= std
        ori_height, ori_width, _ = img0.shape

        final_pred = torch.zeros([1, self.num_classes, ori_height, ori_width])
        new_img = cv2.resize(image_nor, (self.crop_size[0], self.crop_size[1]),
                             interpolation=cv2.INTER_LINEAR)
        height, width = new_img.shape[:-1]

        # HWC -> NCHW tensor
        new_img = new_img.transpose((2, 0, 1))
        new_img = np.expand_dims(new_img, axis=0)
        new_img = torch.from_numpy(new_img)

        preds = self.model(new_img)
        new_size = new_img.size()
        preds = F.interpolate(
            input=preds, size=new_size[-2:],
            mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
        )
        preds = preds.exp()

        preds = preds[:, :, 0:height, 0:width]

        # resize the class scores back to the original image size
        preds = F.interpolate(
            preds, (ori_height, ori_width),
            mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
        )
        final_pred += preds.detach().cpu()

        if visual:
            # save a colorized mask, mapped back to Cityscapes ids
            palette = self.get_palette(256)
            masks = np.asarray(np.argmax(final_pred, axis=1), dtype=np.uint8)
            for i in range(masks.shape[0]):
                pred = self.convert_label(masks[i], inverse=True)
                save_img = Image.fromarray(pred)
                save_img.putpalette(palette)
                save_img.save('test.png')

        return final_pred

    def get_palette(self, n):
        # PASCAL-VOC-style palette: a deterministic, distinct color per index
        palette = [0] * (n * 3)
        for j in range(0, n):
            lab = j
            palette[j * 3 + 0] = 0
            palette[j * 3 + 1] = 0
            palette[j * 3 + 2] = 0
            i = 0
            while lab:
                palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
                palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
                palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
                i += 1
                lab >>= 3
        return palette

    def convert_label(self, label, inverse=False):
        # map Cityscapes ids <-> trainIds in place
        temp = label.copy()
        if inverse:
            for v, k in self.label_mapping.items():
                label[temp == k] = v
        else:
            for k, v in self.label_mapping.items():
                label[temp == k] = v
        return label


if __name__ == "__main__":
    face_segt = FaceSeg(weights="your_path/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484/300_checkpoint.pth.tar")
    img = cv2.imread("yours.png")
    face_segt.run(img, visual=True)  # writes test.png
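
One caveat: update_config_demo is a helper from my local fork (the upstream repo only ships update_config), so adapt that line to however you load the yaml. ignore_label = 255 and the label_mapping above follow the standard Cityscapes conventions.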

@alexanderuo

My friend, thank you for the code you provided, but there seem to be some problems with the formatting. Could you please post the code in the correct format?

@dreamlychina

Just fix the indentation and it will work; it got mangled like this when I pasted it over.

@alexanderuo

Just fix the indentation and it will work; it got mangled like this when I pasted it over.

OK, thanks brother. You found out I'm Chinese, haha.
