diff --git a/rope/Coordinator.py b/rope/Coordinator.py new file mode 100644 index 0000000..3bdd8fe --- /dev/null +++ b/rope/Coordinator.py @@ -0,0 +1,248 @@ +# #!/usr/bin/env python3 +import os +import time + +import rope.GUI as GUI +import rope.VideoManager as VM + +from insightface.app import FaceAnalysis +import onnxruntime +import onnx + +import torch +from rope.external.clipseg import CLIPDensePredT + +import segmentation_models_pytorch as smp +from collections import OrderedDict +from torchvision import transforms + +from rope.external.model import BiSeNet + + + +def coordinator(): + global gui, vm, action, frame, r_frame + start = time.time() + + if gui.get_action_length() > 0: + action.append(gui.get_action()) + if vm.get_action_length() > 0: + action.append(vm.get_action()) +################## + if vm.get_frame_length() > 0: + frame.append(vm.get_frame()) + + if len(frame) > 0: + gui.set_image(frame[0], False) + gui.display_image_in_video_frame() + frame.pop(0) + #################### + if vm.get_requested_frame_length() > 0: + r_frame.append(vm.get_requested_frame()) + if len(r_frame) > 0: + # print ("1:", time.time()) + gui.set_image(r_frame[0], True) + gui.display_image_in_video_frame() + r_frame=[] + #################### + if len(action) > 0: + if action[0][0] == "load_target_video": + vm.load_target_video(action[0][1]) + #gui.set_slider_position(0) + action.pop(0) + elif action[0][0] == "play_video": + vm.play_video(action[0][1]) + action.pop(0) + elif action[0][0] == "set_video_position": + vm.get_requested_video_frame(action[0][1]) + action.pop(0) + elif action[0][0] == "find_faces": + gui.find_faces(action[0][1]) + action.pop(0) + elif action[0][0] == "clear_faces": + gui.clear_faces() + action.pop(0) + elif action[0][0] == "swap": + if not vm.swapper_model: + gui.set_status("loading Swapper") + swapper, emap = load_swapper_model() + vm.set_swapper_model(swapper, emap) + gui.set_status("Swapper loaded!") + vm.swap_set(action[0][1]) + action.pop(0) + elif action[0][0] == "source_embeddings": + vm.load_source_embeddings(action[0][1]) + action.pop(0) + elif action[0][0] == "target_faces": + vm.found_faces_assignments = action[0][1] + action.pop(0) + + + + + elif action [0][0] == "num_threads": + vm.num_threads = action[0][1] + action.pop(0) + + + + elif action [0][0] == "pos_thresh": + vm.pos_thresh = action[0][1] + action.pop(0) + elif action [0][0] == "neg_thresh": + vm.neg_thresh = action[0][1] + action.pop(0) + elif action [0][0] == "saved_video_path": + vm.saved_video_path = action[0][1] + action.pop(0) + + + + elif action [0][0] == "vid_qual": + vm.vid_qual = int(action[0][1]) + action.pop(0) + + elif action [0][0] == "parameters": + if action[0][1]["GFPGANState"]: + if not vm.GFPGAN_model: + gui.set_status("loading GFPGAN...") + vm.GFPGAN_model = load_GFPGAN_model() + gui.set_status("GFPGAN loaded!") + if action[0][1]["CLIPState"]: + if not vm.clip_session: + gui.set_status("loading CLIP..") + vm.clip_session, vm.cuda_device = load_clip_model() + gui.set_status("CLIP loaded!") + if action[0][1]["OccluderState"]: + if not vm.occluder_model: + gui.set_status("loading Occluder.") + vm.occluder_model, vm.occluder_tensor = load_occluder_model() + gui.set_status("Occluder loaded!") + if action[0][1]["FaceParserState"]: + if not vm.face_parsing_model: + gui.set_status("loading FaceParser") + vm.face_parsing_model, vm.face_parsing_tensor = load_face_parser_model() + gui.set_status("FaceParser loaded!") + + + vm.parameters = action[0][1] + action.pop(0) + + elif action [0][0] == 
"load_models": + gui.set_status("loading Faceapp...") + faceapp = load_faceapp_model() + gui.set_faceapp_model(faceapp) + vm.set_faceapp_model(faceapp) + gui.set_status("loading Target Videos...") + gui.populate_target_videos() + gui.set_status("loading Source Faces...") + gui.load_source_faces() + gui.set_status("Done...") + action.pop(0) + + + # From VM + elif action[0][0] == "stop_play": + gui.toggle_play_video() + action.pop(0) + + elif action[0][0] == "set_slider_length": + gui.set_video_slider_length(action[0][1]) + action.pop(0) + + elif action[0][0] == "send_msg": + gui.set_status(action[0][1]) + action.pop(0) + + else: + print("Action not found: "+action[0][0]+" "+str(action[0][1])) + action.pop(0) + + # start = time.time() + + + gui.check_for_video_resize() + vm.process() + gui.after(1, coordinator) + # print(time.time() - start) + +def load_faceapp_model(): + app = FaceAnalysis(name='buffalo_l') + app.prepare(ctx_id=0, det_size=(640, 640)) + return app + +def load_swapper_model(): + # Load Swapper model and get graph param + model = onnx.load("./models/inswapper_128.fp16.onnx") + graph = model.graph + + emap = onnx.numpy_helper.to_array(graph.initializer[-1]) + + # Create Swapper model session + opts = onnxruntime.SessionOptions() + # opts.enable_profiling = True + opts.enable_cpu_mem_arena = False + return onnxruntime.InferenceSession( "./models/inswapper_128.fp16.onnx", opts, providers=["CUDAExecutionProvider"]), emap + +def load_clip_model(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + clip_session = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True) + clip_session.eval(); + clip_session.load_state_dict(torch.load('./models/rd64-uni-refined.pth', map_location=torch.device('cuda')), strict=False) + clip_session.to(device) + return clip_session, device + +def load_GFPGAN_model(): + + opts = onnxruntime.SessionOptions() + # opts.enable_profiling = True + opts.enable_cpu_mem_arena = False + GFPGAN_session = onnxruntime.InferenceSession( "./models/GFPGANv1.4.onnx", providers=["CUDAExecutionProvider"]) + return GFPGAN_session + +def load_occluder_model(): + to_tensor = transforms.ToTensor() + model = smp.Unet(encoder_name='resnet18', encoder_weights='imagenet', classes=1, activation=None) + + weights = torch.load('./models/occluder.ckpt') + new_weights = OrderedDict() + for key in weights.keys(): + new_key = '.'.join(key.split('.')[1:]) + new_weights[new_key] = weights[key] + + model.load_state_dict(new_weights) + model.to('cuda') + model.eval() + return model, to_tensor + +def load_face_parser_model(): + n_classes = 19 + model = BiSeNet(n_classes=n_classes) + model.cuda() + model.load_state_dict(torch.load("./models/79999_iter.pth")) + model.eval() + + to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + + return model, to_tensor + +def run(): + global gui, vm, action, frame, r_frame + gui = GUI.GUI() + vm = VM.VideoManager() + + action = [] + frame = [] + r_frame = [] + + + gui.initialize_gui() + coordinator() + + gui.mainloop() + + diff --git a/rope/GUI.py b/rope/GUI.py new file mode 100644 index 0000000..9894b0f --- /dev/null +++ b/rope/GUI.py @@ -0,0 +1,1489 @@ +import os +import cv2 +import tkinter as tk +from tkinter import filedialog +import numpy as np +from PIL import Image, ImageTk +import json + + + +class GUI(tk.Tk): + def __init__( self ): + super().__init__() + # Adding a title to the self + # 
self.call('tk', 'scaling', 0.5) + self.title("Test Application") + self.pixel = [] + self.target_face = { + "TKButton": [], + "ButtonState": "off", + "Image": [], + "Embedding": [], + "SourceFaceAssignments": [], + "EmbeddingNumber": 0 + } + self.target_faces = [] + + self.source_face = { + "TKButton": [], + "ButtonState": "off", + "Image": [], + "Embedding": [] + } + self.source_faces = [] + + self.parameters = { + "GFPGANState": False, + "GFPGANAmount": 100, + "DiffState": False, + "DiffAmount": 4, + "ThreshholdState": False, + "Threshhold": 0.85, + "MaskTop": 20, + "MaskSide": 30, + "MaskBlur": 15, + "OccluderState": False, + "CLIPState": False, + "CLIPText": tk.StringVar(value=""), + "CLIPAmount": 0.5, + "FaceParserState": False, + "BlurAmount": 5 + } + + + + + self.num_threads = 1 + self.video_quality = 18 + self.target_videos = [] + self.target_video_file = [] + self.action_q = [] + self.video_image = [] + self.x1 = [] + self.y1 = [] + self.found_faces_assignments = [] + self.play_video = False + self.rec_video = False + self.swap = False + self.faceapp_model = [] + # self.GFPGAN_int = tk.IntVar() + # self.fake_diff_int = tk.IntVar() + # self.CLIP_int = tk.IntVar() + self.video_loaded = False + # self.occluder_int = tk.IntVar() + self.dock = True + self.undock = [] + + self.save_file = [] + self.json_dict = {"source videos":None, "source faces":None, "saved videos":None, "threads":1} + + self.new_int = tk.IntVar() + + self.button1 = "gray25" + self.button_1_text = "light goldenrod" + self.button1_active = "gray50" + + self.button_highlight_style = { + 'bg': 'light goldenrod', + 'fg': 'gray20', + 'activebackground': 'gray75', + 'activeforeground': 'light goldenrod', + 'relief': 'flat', + 'border': '0', + 'font': ("Arial", 9) + } + self.inactive_button_style = { + 'bg': 'gray20', + 'fg': 'white', + 'activebackground': 'gray10', + 'activeforeground': 'white', + 'relief': 'flat', + 'border': '0', + 'font': ("Arial", 9) + } + self.active_button_style = { + 'bg': 'black', + 'fg': 'white', + 'activebackground': 'gray10', + 'activeforeground': 'white', + 'relief': 'flat', + 'border': '0', + 'font': ("Arial", 9) + } + self.need_button_style = { + 'bg': 'gray30', + 'fg': 'white', + 'relief': 'flat', + 'border': '0', + 'font': ("Arial italic", 9) + } + + self.canvas_label_style = { + 'bg': 'gray20', + 'relief': 'flat', + 'bd': '0', + 'highlightthickness': '0' + } + + self.canvas_style1 = { + 'bg': 'gray20', + 'relief': 'flat', + 'bd': '0', + 'highlightthickness': '0' + } + self.button_style2 = { + 'bg': 'light goldenrod', + } + self.button_style3 = { + 'bg': 'gray25', + 'relief': 'flat', + 'border': '1' + } + + self.spinbox_style = { + 'width': '5', + 'bg': 'gray40', + 'fg': 'white', + 'relief': 'flat', + 'width': '5' + } + + + self.frame_style = { + 'bg': 'gray20', + 'relief': 'flat', + 'bd': '0' + } + + self.checkbox_style = { + 'bg': 'gray40', + 'fg': 'white', + 'relief': 'flat', + 'bd': '0', + 'anchor': 'w', + 'selectcolor': 'gray40' + } + self.label_style = { + 'bg': 'gray20', + 'fg': 'white', + 'relief': 'flat', + 'bd': '0', + 'anchor': 'w' + } + self.slider_style = { + 'bg': 'gray20', + 'fg': 'white', + 'activebackground': 'black', + 'highlightthickness':'0', + 'relief': 'flat', + 'sliderrelief': 'flat', + 'border': '0', + 'width': '10', + 'troughcolor': 'gray40', + 'font': ("Arial", 9) + } + + # Video frame + self.video_frame = tk.Frame( self, self.frame_style) + self.video_frame.grid( row = 0, column = 0, sticky='NEWS', pady = 2 ) + + self.video_frame.grid_columnconfigure(0, minsize = 
10) + self.video_frame.grid_columnconfigure(1, weight = 10) + self.video_frame.grid_rowconfigure(0, weight = 1) + self.video_frame.grid_rowconfigure(1, weight = 0) + + # Video [0,0] + self.video = tk.Label( self.video_frame, self.label_style, bg='black') + self.video.grid( row = 0, column = 0, columnspan = 3, sticky='NEWS', pady =0 ) + self.video.bind("", self.iterate_through_merged_embeddings) + + # Video button canvas + self.video_button_canvas = tk.Canvas( self.video_frame, self.canvas_style1, width = 112, height = 40) + self.video_button_canvas.grid( row = 1, column = 0, sticky='NEWS', pady = 0) + + # Dock + self.video_dock = tk.Button( self.video_button_canvas, self.inactive_button_style, text="^^", wraplength=1, command=lambda: self.toggle_dock()) + self.video_dock.place(x=8, y=2, width = 15, height = 36) + + # Video Play + img = Image.open('./rope/media/play.png') + resized_image= img.resize((30,30), Image.ANTIALIAS) + self.play_icon = ImageTk.PhotoImage(resized_image) + self.video_play = tk.Button( self.video_button_canvas, self.inactive_button_style, image=self.play_icon, command=lambda: self.toggle_play_video()) + self.video_play.place(x=31, y=2, width = 36, height = 36) + + # Video Record + img = Image.open('./rope/media/rec.png') + resized_image= img.resize((30,30), Image.ANTIALIAS) + self.rec_icon = ImageTk.PhotoImage(resized_image) + + self.video_record = tk.Button( self.video_button_canvas, self.inactive_button_style, image=self.rec_icon, command=lambda: self.toggle_rec_video()) + self.video_record.place(x=69, y=2, width = 36, height = 36) + + # Video Slider + self.video_slider = tk.Scale( self.video_frame, self.slider_style, orient='horizontal') + self.video_slider.bind("", lambda event:self.add_action_and_update_frame("set_video_position", self.video_slider.get(), False)) + self.video_slider.bind("", lambda event:self.add_action_and_update_frame("set_video_position", self.video_slider.get(), False)) + self.video_slider.bind("", lambda event:self.add_action_and_update_frame("set_video_position", self.video_slider.get(), False)) + self.video_slider.bind("", self.mouse_wheel) + self.video_slider.grid( row = 1, column = 1, sticky='NEWS', pady = 2 ) + + ######### Options + x_space = 40 + self.options_frame = tk.Frame( self, self.frame_style, height = 71) + self.options_frame.grid( row = 1, column = 0, sticky='NEWS', pady = 2 ) + + self.options_frame.grid_rowconfigure( 0, weight = 100 ) + self.options_frame.grid_columnconfigure( 0, weight = 1 ) + + # Left Canvas + self.options_frame_canvas1 = tk.Canvas( self.options_frame, self.canvas_style1, height = 71) + self.options_frame_canvas1.grid( row = 0, column = 0, sticky='NEWS', pady = 0 ) + + + # Label Frame 1 + self.label_frame1 = tk.LabelFrame( self.options_frame_canvas1, self.frame_style, height = 71, width = 800 ) + self.label_frame1.place(x=0, y=0) + + column1=8 + # GFPGAN + # GFPGAN-checkbox + img = Image.open('./rope/media/gfpgan_logo.png') + resized_image= img.resize((45,20), Image.ANTIALIAS) + self.GFPGAN_icon = ImageTk.PhotoImage(resized_image) + temp = ' ' + str(int(self.parameters["GFPGANAmount"])) + '%' + self.GFPGAN_button = tk.Button(self.label_frame1, self.inactive_button_style, compound='left', image=self.GFPGAN_icon, text=temp, anchor='w', command=lambda: self.toggle_GFPGAN()) + self.GFPGAN_button.place(x=column1, y=8, width = 125, height = 26) + self.GFPGAN_button.bind("", self.change_GFPGAN_amount) + + # Fake_diff + # Fake_diff-checkbox + img = Image.open('./rope/media/diff.png') + resized_image= 
img.resize((20,20), Image.ANTIALIAS) + self.diff_icon = ImageTk.PhotoImage(resized_image) + temp = ' Differ ' + str(int(self.parameters["DiffAmount"]*10)) + '%' + self.differ_button = tk.Button(self.label_frame1, self.inactive_button_style, compound='left', image=self.diff_icon, text=temp, anchor='w', command=lambda: self.toggle_differ()) + self.differ_button.place(x=column1, y=37, width = 125, height = 26) + self.differ_button.bind("", self.change_differ_amount) + + column2=column1+125+x_space + # Mask top + # Mask top-label + img = Image.open('./rope/media/maskup.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.masktop_icon = ImageTk.PhotoImage(resized_image) + temp = ' Top Mask ' + str(int(self.parameters["MaskTop"]*100.0/64.0)) + '%' + + self.top_blend_id = tk.Label(self.label_frame1, self.label_style, compound='left', image=self.masktop_icon, text=temp, anchor='w') + self.top_blend_id.place(x=column2, y=8, width = 125, height = 26) + self.top_blend_id.bind("", self.change_mask_top_amount) + + # # Mask sides + # # Mask sides-label + # img = Image.open('./rope/media/maskside.png') + # resized_image= img.resize((20,20), Image.ANTIALIAS) + # self.maskside_icon = ImageTk.PhotoImage(resized_image) + # temp = ' Side Mask ' + str(int(self.parameters["MaskSide"]*100.0/64.0)) + '%' + # self.side_blend_id = tk.Label(self.label_frame1, self.label_style, compound='left', image=self.maskside_icon, text=temp, anchor='w') + # self.side_blend_id.place(x=column2, y=37, width = 125, height = 26) + # self.side_blend_id.bind("", self.change_mask_side_amount) + + # # Mask blur + # # Mask blur-label + img = Image.open('./rope/media/maskblur.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.maskblur_icon = ImageTk.PhotoImage(resized_image) + temp = ' Mask Blur ' + str(int(self.parameters["MaskBlur"]*100.0/64.0)) + '%' + self.mask_blur_id = tk.Label(self.label_frame1, self.label_style, compound='left', image=self.maskblur_icon, text=temp, anchor='w') + self.mask_blur_id.place(x=column2, y=37, width = 125, height = 26) + self.mask_blur_id.bind("", self.change_mask_blur_amount) + + column3=column2+125+x_space + # CLIP + # CLIP-checkbox + img = Image.open('./rope/media/CLIP.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.CLIP_icon = ImageTk.PhotoImage(resized_image) + temp = ' CLIP ' + str(int(self.parameters["CLIPAmount"]*100)) + '%' + self.CLIP_button = tk.Button(self.label_frame1, self.inactive_button_style, compound='left', image=self.CLIP_icon, text=temp, anchor='w', command=lambda: self.toggle_CLIP()) + self.CLIP_button.place(x=column3, y=8, width=125, height=26) + self.CLIP_button.bind("", self.change_CLIP_amount) + + # CLIP-entry + self.CLIP_text = tk.Entry(self.label_frame1, relief='flat', bd=0, textvariable=self.parameters["CLIPText"]) + self.CLIP_text.place(x=column3, y=40, width = 125, height=20) + self.CLIP_text.bind("", lambda event: self.add_action_and_update_frame("parameters", self.parameters)) + + column4=column3+125+x_space + # # Occluder + # # Occluder-checkbox + img = Image.open('./rope/media/occluder.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.occluder_icon = ImageTk.PhotoImage(resized_image) + temp = ' Occluder' + self.occluder_button = tk.Button(self.label_frame1, self.inactive_button_style, compound='left', image=self.occluder_icon, text=temp, anchor='w', command=lambda: self.toggle_occluder()) + self.occluder_button.place(x=column4, y=8, width=125, height=26) + + # # Face Parser + # # Face Parser-checkbox + img 
= Image.open('./rope/media/parse.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.parser_icon = ImageTk.PhotoImage(resized_image) + temp = ' Mouth Parser' + self.parser_button = tk.Button(self.label_frame1, self.inactive_button_style, compound='left', image=self.parser_icon, text=temp, anchor='w', command=lambda: self.toggle_parser()) + self.parser_button.place(x=column4, y=37, width=125, height=26) + + column5=column4+125+x_space + # # Blur + # # Blur-label + img = Image.open('./rope/media/blur.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.blur_icon = ImageTk.PhotoImage(resized_image) + temp = ' Blur ' + str(int(self.parameters["BlurAmount"]*100.0/64.0)) + '%' + self.blur_id = tk.Label(self.label_frame1, self.label_style, compound='left', image=self.blur_icon, text=temp, anchor='w') + self.blur_id.place(x=column5, y=8, width = 125, height = 26) + self.blur_id.bind("", self.change_blur_amount) + + # # Face Threshhold + # # Face Threshhold-label + img = Image.open('./rope/media/thresh.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.threshhold_icon = ImageTk.PhotoImage(resized_image) + temp = ' Threshhold ' + str(int(self.parameters["Threshhold"]*100)) + '%' + self.threshhold_button = tk.Button(self.label_frame1, self.inactive_button_style, compound='left', image=self.threshhold_icon, text=temp, anchor='w', command=lambda: self.toggle_threshhold()) + self.threshhold_button.place(x=column5, y=37, width=125, height=26) + self.threshhold_button.bind("", self.change_threshhold_amount) + + ######## Target Faces + # Found Faces frame [1,0] + self.found_faces_frame = tk.Frame( self, self.frame_style) + self.found_faces_frame.grid( row = 2, column = 0, sticky='NEWS', pady = 2 ) + + self.found_faces_frame.grid_columnconfigure( 0, minsize = 10 ) + self.found_faces_frame.grid_columnconfigure( 1, weight = 1 ) + self.found_faces_frame.grid_rowconfigure( 0, weight = 0 ) + + # Button Canvas [0,0] + self.found_faces_buttons_canvas = tk.Canvas( self.found_faces_frame, self.canvas_style1, height = 100, width = 112) + self.found_faces_buttons_canvas.grid( row = 0, column = 0, ) + + # Faces Load + img = Image.open('./rope/media/tarface.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.target_faces_load_icon = ImageTk.PhotoImage(resized_image) + self.found_faces_load_button = tk.Button(self.found_faces_buttons_canvas, self.inactive_button_style, image=self.target_faces_load_icon, compound='left', anchor='w', text=" Find", command=lambda: self.add_action_and_update_frame("find_faces", "current", False)) + self.found_faces_load_button.place(x=8, y=8, width = 96, height = 26) + + # Faces Clear + img = Image.open('./rope/media/tarfacedel.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.target_faces_del_icon = ImageTk.PhotoImage(resized_image) + self.found_faces_clear_button = tk.Button(self.found_faces_buttons_canvas, self.inactive_button_style, image=self.target_faces_del_icon, compound='left', anchor='w', text=" Clear", command=lambda: self.add_action_and_update_frame("clear_faces", "current", False)) + self.found_faces_clear_button.place(x=8, y=37, width = 96, height = 26) + + # Video Swap + img = Image.open('./rope/media/swap.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.swap_icon = ImageTk.PhotoImage(resized_image) + self.video_swap = tk.Button( self.found_faces_buttons_canvas, self.inactive_button_style, image=self.swap_icon, compound='left', anchor='w', text=" Swap", command=lambda: 
self.toggle_swapper()) + self.video_swap.place(x=8, y=66, width = 96, height = 26) + + # Faces Canvas [0,1] + self.found_faces_canvas = tk.Canvas( self.found_faces_frame, self.canvas_style1, height = 100 ) + self.found_faces_canvas.grid( row = 0, column = 1, sticky='NEWS') + self.found_faces_canvas.bind("", self.target_faces_mouse_wheel) + self.found_faces_canvas.create_text(8, 45, anchor='w', fill='grey25', font=("Arial italic", 50), text=" Target Faces") + + # # Label + # self.target_faces_id = tk.Canvas(self.found_faces_buttons_canvas, self.canvas_label_style) + # self.target_faces_id.place(x=8, y=8, width = 20, height = 84) + # self.target_faces_id.create_text(8, 45, justify='center', fill='white', font=("Arial italic", 9), text="Target Faces", angle=90 ) + + + ######## Source Faces + # Source Faces frame [2,0] + self.source_faces_frame = tk.Frame( self, self.frame_style) + self.source_faces_frame.grid( row = 3, column = 0, sticky='NEWS', pady = 2 ) + + self.source_faces_frame.grid_columnconfigure( 0, minsize = 10 ) + self.source_faces_frame.grid_columnconfigure( 1, weight = 1 ) + self.source_faces_frame.grid_rowconfigure( 0, weight = 0 ) + + # Button Canvas [0,0] + self.source_faces_buttons = [] + self.source_button_canvas = tk.Canvas( self.source_faces_frame, self.canvas_style1, height = 100, width = 112) + self.source_button_canvas.grid( row = 0, column = 0, sticky='NEWS') + + # Load Source Faces + img = Image.open('./rope/media/save.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.save_icon = ImageTk.PhotoImage(resized_image) + self.faces_filepath_button = tk.Button(self.source_button_canvas, self.need_button_style, image=self.save_icon, compound='left', anchor='w', text="Source Faces", wraplength=120, command=lambda: self.select_faces_path()) + self.faces_filepath_button.place(x=8, y=8, width = 96, height = 26) + + # Merged Embeddings Text + self.merged_embedding_name = tk.StringVar() + self.merged_embeddings_text = tk.Entry(self.source_button_canvas, relief='flat', bd=0, textvariable=self.merged_embedding_name) + self.merged_embeddings_text.place(x=8, y=37, width = 96, height=20) + self.merged_embeddings_text.bind("", lambda event: self.save_selected_source_faces(self.merged_embedding_name)) + + # Embedding remove + img = Image.open('./rope/media/delemb.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.delemb_icon = ImageTk.PhotoImage(resized_image) + self.merged_embedding_remove_button = tk.Button(self.source_button_canvas, self.inactive_button_style, image=self.delemb_icon, compound='left', anchor='w', text=" Delete", command=lambda: self.delete_merged_embedding()) + self.merged_embedding_remove_button.place(x=8, y=66, width = 96, height = 26) + + # Faces Canvas [0,1] + self.source_faces_canvas = tk.Canvas( self.source_faces_frame, self.canvas_style1, height = 100) + self.source_faces_canvas.grid( row = 0, column = 1, sticky='NEWS') + self.source_faces_canvas.bind("", self.source_faces_mouse_wheel) + self.source_faces_canvas.create_text(8, 45, anchor='w', fill='grey25', font=("Arial italic", 50), text=' Source Faces') + + +######### Target Videos + # Target Video frame [3,0] + self.target_videos_frame = tk.Frame( self, self.frame_style) + self.target_videos_frame.grid( row = 4, column = 0, sticky='NEWS', pady = 2 ) + + self.target_videos_frame.grid_columnconfigure( 0, minsize = 10 ) + self.target_videos_frame.grid_columnconfigure( 1, weight = 1 ) + self.target_videos_frame.grid_rowconfigure( 0, weight = 0 ) + + # Button Canvas [0,0] + 
self.target_videos_buttons = [] + self.target_button_canvas = tk.Canvas(self.target_videos_frame, self.canvas_style1, height = 100, width = 112) + self.target_button_canvas.grid( row = 0, column = 0, sticky='NEWS') + + # # Videos Load + # self.target_video_load_button = tk.Button(self.target_button_canvas, self.button_style1, text="Reload videos", command=lambda: self.populate_target_videos()) + # self.target_video_load_button.place(x=8, y=8, width = 84, height = 20) + + # Target Videos Filepath + self.video_filepath_button = tk.Button(self.target_button_canvas, self.need_button_style, image=self.save_icon, compound='left', anchor='w', text="Target Videos", wraplength=115, command=lambda: self.select_video_path()) + self.video_filepath_button.place(x=8, y=8, width = 96, height = 26) + + # Video Canvas [0,1] + self.target_video_canvas = tk.Canvas( self.target_videos_frame, self.canvas_style1, height = 100) + self.target_video_canvas.grid( row = 0, column = 1, sticky='NEWS') + self.target_video_canvas.bind("", self.target_videos_mouse_wheel) + self.target_video_canvas.create_text(8, 45, anchor='w', fill='grey25', font=("Arial italic", 50), text=' Target Videos') + + column = 8 + + ######### Options + x_space = 40 + self.program_options_frame = tk.Frame( self, self.frame_style, height = 42) + self.program_options_frame.grid( row = 5, column = 0, sticky='NEWS', pady = 2 ) + + self.program_options_frame.grid_rowconfigure( 0, weight = 100 ) + self.program_options_frame.grid_columnconfigure( 0, weight = 1 ) + + # Left Canvas + self.program_options_frame_canvas = tk.Canvas( self.program_options_frame, self.canvas_style1, height = 42) + self.program_options_frame_canvas.grid( row = 0, column = 0, sticky='NEWS', pady = 0 ) + + # Label Frame 1 + self.program_options_label = tk.LabelFrame( self.program_options_frame_canvas, self.frame_style, height = 42, width = 800 ) + self.program_options_label.place(x=0, y=0) + + # Load Folders + img = Image.open('./rope/media/save.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.load_folders_icon = ImageTk.PhotoImage(resized_image) + self.load_folders_button = tk.Button(self.program_options_label, self.need_button_style, compound='left', image=self.load_folders_icon, text=" Load Folders", anchor='w', command=lambda: self.load_all()) + self.load_folders_button.place(x=column, y=8, width = 125, height = 26) + + column=column+125+x_space + # Save Videos Filepath + self.save_video_filepath_button = tk.Button(self.program_options_label, self.need_button_style, image=self.save_icon, compound='left', anchor='w', text="Saved Videos", wraplength=115, command=lambda: self.select_save_video_path()) + self.save_video_filepath_button.place(x=column, y=8, width = 96, height = 26) + + column=column+125+x_space + # Threads + img = Image.open('./rope/media/threads.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.threads_icon = ImageTk.PhotoImage(resized_image) + temp = ' Threads ' + str(self.num_threads) + self.num_threads_id = tk.Label(self.program_options_label, self.label_style, compound='left', image=self.threads_icon, text=temp, anchor='w') + self.num_threads_id.place(x=column, y=8, width = 125, height = 26) + self.num_threads_id.bind("", self.change_threads_amount) + + column=column+125+x_space + # Video Quality + img = Image.open('./rope/media/maskside.png') + resized_image= img.resize((20,20), Image.ANTIALIAS) + self.video_quality_icon = ImageTk.PhotoImage(resized_image) + temp = ' Video Quality ' + str(self.video_quality) + 
self.vid_qual_button = tk.Label(self.program_options_label, self.label_style, compound='left', image=self.video_quality_icon, text=temp, anchor='w') + self.vid_qual_button.place(x=column, y=8, width = 125, height=26) + self.vid_qual_button.bind("", self.change_video_quality) + + + + # Status + self.status_frame = tk.Frame( self, bg='grey20', height = 15) + self.status_frame.grid( row = 6, column = 0, sticky='NEWS', pady = 2 ) + + self.status_label = tk.Label(self.status_frame, fg="white", bg='grey20') + self.status_label.pack() + # self.status_label_text = tk.Label(self.status_frame, anchor="w", bg='grey75', text="Threads:") + # self.status_label_text.place(x=100, y=8, width = 50, height=17) + + def target_faces_mouse_wheel(self, event): + self.found_faces_canvas.xview_scroll(1*int(event.delta/120.0), "units") + + + def source_faces_mouse_wheel(self, event): + self.source_faces_canvas.xview_scroll(1*int(event.delta/120.0), "units") + + + def target_videos_mouse_wheel(self, event): + self.target_video_canvas.xview_scroll(1*int(event.delta/120.0), "units") + + + def initialize_gui( self ): + + self.title("Rope - Crystal") + # self.overrideredirect(True) + self.configure(bg='grey10') + self.resizable(width=True, height=True) + + self.geometry('%dx%d+%d+%d' % (800, 1020, self.winfo_screenwidth()/2-400, self.winfo_screenheight()/2-510)) + + self.grid_columnconfigure(0, weight = 1) + + self.grid_rowconfigure(0, weight = 10) + self.grid_rowconfigure(1, weight = 0) + self.grid_rowconfigure(2, weight = 0) + self.grid_rowconfigure(3, weight = 0) + self.grid_rowconfigure(4, weight = 0) + self.grid_rowconfigure(5, weight = 0) + self.grid_rowconfigure(6, weight = 0) + + + self.add_action_and_update_frame("vid_qual",int(self.video_quality), False) + self.add_action_and_update_frame("num_threads",int(self.num_threads), False) + self.add_action_and_update_frame("parameters", self.parameters, False) + + + try: + self.save_file = open("data.json", "r") + except: + with open("data.json", "w") as outfile: + json.dump(self.json_dict, outfile) + else: + jason_object = [] + with open('data.json', 'r') as openfile: + json_object = json.load(openfile) + + self.json_dict["source videos"] = json_object["source videos"] + if self.json_dict["source videos"]: + temp = self.json_dict["source videos"] + temp_len = len(temp) + temp = '...'+temp[temp_len-10:] + + self.video_filepath_button.configure(self.inactive_button_style, text=temp) + + self.json_dict["source faces"] = json_object["source faces"] + if self.json_dict["source faces"]: + temp = self.json_dict["source faces"] + temp_len = len(temp) + temp = '...'+temp[temp_len-10:] + + self.faces_filepath_button.configure(self.inactive_button_style, text=temp) + + self.json_dict["saved videos"] = json_object["saved videos"] + if self.json_dict["saved videos"]: + temp = self.json_dict["saved videos"] + temp_len = len(temp) + temp = '...'+temp[temp_len-10:] + + self.save_video_filepath_button.configure(self.inactive_button_style, text=temp) + self.add_action_and_update_frame("saved_video_path",self.json_dict["saved videos"], False) + + self.json_dict["threads"] = json_object["threads"] + if self.json_dict["threads"]: + temp = self.json_dict["threads"] + self.num_threads = int(temp) + + temp = ' Threads ' + str(self.num_threads) + self.num_threads_id.config(text=temp) + + self.add_action_and_update_frame("num_threads",int(self.num_threads), False) + + + + def load_all(self): + if not self.json_dict["source videos"] or not self.json_dict["source faces"]: + print("Please set 
faces and videos folders first!") + return + + self.add_action_and_update_frame("load_models", True, False) + self.load_folders_button.configure(self.inactive_button_style, text=" Folders loaded!") + + + def select_video_path(self): + + temp = self.json_dict["source videos"] + + self.json_dict["source videos"] = filedialog.askdirectory(title="Select Target Videos Folder", initialdir=temp) + + temp = self.json_dict["source videos"] + temp_len = len(temp) + temp = '...'+temp[temp_len-10:] + + self.video_filepath_button.configure(self.inactive_button_style, text=temp) + + with open("data.json", "w") as outfile: + json.dump(self.json_dict, outfile) + + self.populate_target_videos() + + def select_save_video_path(self): + temp = self.json_dict["saved videos"] + + self.json_dict["saved videos"] = filedialog.askdirectory(title="Select Save Video Folder", initialdir=temp) + + temp = self.json_dict["saved videos"] + temp_len = len(temp) + temp = '...'+temp[temp_len-10:] + + self.save_video_filepath_button.configure(self.inactive_button_style, text=temp) + + self.add_action_and_update_frame("saved_video_path",self.json_dict["saved videos"], False) + + with open("data.json", "w") as outfile: + json.dump(self.json_dict, outfile) + + def select_faces_path(self): + temp = self.json_dict["source faces"] + + self.json_dict["source faces"] = filedialog.askdirectory(title="Select Source Faces Folder", initialdir=temp) + + temp = self.json_dict["source faces"] + temp_len = len(temp) + temp = '...'+temp[temp_len-10:] + + self.faces_filepath_button.configure(self.inactive_button_style, text=temp) + + with open("data.json", "w") as outfile: + json.dump(self.json_dict, outfile) + + self.load_source_faces() + + def load_source_faces(self): + if not self.faceapp_model: + print("Load model first") + else: + + + self.source_faces = [] + self.source_faces_canvas.delete("all") + + + + + # First load merged embeddings + if os.path.exists("merged_embeddings.txt"): + + temp0 = [] + with open("merged_embeddings.txt", "r") as embedfile: + temp = embedfile.read().splitlines() + + for i in range(0, len(temp), 513): + to = [temp[i][6:], np.array(temp[i+1:i+513], dtype='float32')] + temp0.append(to) + + self.pixel = tk.PhotoImage(height=0, width=0) + + for j in range(len(temp0)): + + new_source_face = self.source_face.copy() + self.source_faces.append(new_source_face) + + self.source_faces[j]["ButtonState"] = False + self.source_faces[j]["Embedding"] = temp0[j][1] + self.source_faces[j]["TKButton"] = tk.Button(self.source_faces_canvas, self.inactive_button_style, image=self.pixel, text=temp0[j][0], height=14, width=84, compound='left') + + + + self.source_faces[j]["TKButton"].bind("", lambda event, arg=j: self.toggle_source_faces_buttons_state(event, arg)) + self.source_faces[j]["TKButton"].bind("", lambda event, arg=j: self.toggle_source_faces_buttons_state_shift(event, arg)) + self.source_faces[j]["TKButton"].bind("", self.source_faces_mouse_wheel) + + self.source_faces_canvas.create_window((j//4)*92,8+(22*(j%4)), window = self.source_faces[j]["TKButton"],anchor='nw') + # print((j//4)*92,8+(22*(j%4))) + + directory = self.json_dict["source faces"] + + if directory == None: + print("No directory assigned") + else: + + filenames = os.listdir(directory) + + faces = [] + + # Find all faces and ad to faces[] + for name in filenames: #should check if is an image + temp_file = os.path.join(directory, name) + temp_file = cv2.imread(temp_file) + ret = self.faceapp_model.get(temp_file, max_num=1) + if ret: + bbox = ret[0].bbox + 
y_diff = bbox[3] - bbox[1]
+                        x_diff = bbox[2] - bbox[0]
+
+                        crop = temp_file[int(bbox[1]):int(bbox[3]),int(bbox[0]):int(bbox[2])] #y,x
+                        if y_diff > x_diff:
+                            padding = int((y_diff - x_diff) / 2)
+                            crop = cv2.copyMakeBorder( crop, 0, 0, padding, padding, cv2.BORDER_CONSTANT)
+                        else:
+                            padding = int((x_diff - y_diff) / 2)
+                            crop = cv2.copyMakeBorder( crop, padding, padding, 0, 0, cv2.BORDER_CONSTANT )
+
+                        crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
+                        crop = cv2.resize( crop, (82, 82))
+                        temp = [crop, ret[0].embedding]
+                        faces.append(temp)
+
+                shift_i_len = len(self.source_faces)
+
+                # Add faces[] images to buttons
+                for i in range(len(faces)):
+                    new_source_face = self.source_face.copy()
+                    self.source_faces.append(new_source_face)
+
+                    shift_i = i + shift_i_len
+
+                    self.source_faces[shift_i]["Image"] = ImageTk.PhotoImage(image=Image.fromarray(faces[i][0]))
+                    self.source_faces[shift_i]["Embedding"] = faces[i][1]
+                    self.source_faces[shift_i]["TKButton"] = tk.Button(self.source_faces_canvas, self.inactive_button_style, image=self.source_faces[shift_i]["Image"], height = 86, width = 86)
+                    self.source_faces[shift_i]["ButtonState"] = False
+
+                    # assumed event strings: click selects, shift-click multi-selects, wheel scrolls the canvas
+                    self.source_faces[shift_i]["TKButton"].bind("<Button-1>", lambda event, arg=shift_i: self.toggle_source_faces_buttons_state(event, arg))
+                    self.source_faces[shift_i]["TKButton"].bind("<Shift-Button-1>", lambda event, arg=shift_i: self.toggle_source_faces_buttons_state_shift(event, arg))
+                    self.source_faces[shift_i]["TKButton"].bind("<MouseWheel>", self.source_faces_mouse_wheel)
+
+                    self.source_faces_canvas.create_window(((shift_i_len//4)+i+1)*92, 8, window = self.source_faces[shift_i]["TKButton"], anchor='nw')
+
+            self.source_faces_canvas.configure(scrollregion = self.source_faces_canvas.bbox("all"))
+            self.source_faces_canvas.xview_moveto(0)
+
+            # send over source faces embeddings
+            self.add_action_and_update_frame("source_embeddings", self.source_faces, False)
+
+    def find_faces(self, scope):
+        try:
+            ret = self.faceapp_model.get(self.video_image, max_num=10)
+        except Exception:
+            print(" No video selected")
+        else:
+            # Find all faces and add to faces[]
+            if ret:
+                # Loop through all faces in the video frame
+                for i in range(len(ret)):
+                    # Create a frame for each face
+                    bbox = ret[i].bbox
+
+                    # Clamp the bounding box to the frame
+                    if bbox[0] < 0:
+                        bbox[0] = 0
+                    if bbox[1] < 0:
+                        bbox[1] = 0
+                    if bbox[2] > self.video_image.shape[1]:
+                        bbox[2] = self.video_image.shape[1]
+                    if bbox[3] > self.video_image.shape[0]:
+                        bbox[3] = self.video_image.shape[0]
+
+                    y_diff = bbox[3] - bbox[1]
+                    x_diff = bbox[2] - bbox[0]
+
+                    crop = self.video_image[int(bbox[1]):int(bbox[3]),int(bbox[0]):int(bbox[2])] #y,x
+
+                    if y_diff > x_diff:
+                        padding = int((y_diff - x_diff) / 2)
+                        crop = cv2.copyMakeBorder( crop, 0, 0, padding, padding, cv2.BORDER_CONSTANT)
+                    else:
+                        padding = int((x_diff - y_diff) / 2)
+                        crop = cv2.copyMakeBorder( crop, padding, padding, 0, 0, cv2.BORDER_CONSTANT )
+                    crop = cv2.resize( crop, (82, 82))
+
+                    found = False
+                    # Test for existing similarities: the face is considered already found
+                    # when it is within the face threshold of an existing target face
+                    for j in range(len(self.target_faces)):
+                        sim = self.findCosineDistance(ret[i].embedding, self.target_faces[j]["Embedding"])
+                        if sim < float(self.parameters["Threshhold"]):
+                            found = True
+
+                    # Otherwise add this face as a new target face
+                    if not found:
+                        new_target_face = self.target_face.copy()
+                        self.target_faces.append(new_target_face)
+                        last_index = len(self.target_faces)-1
+
+                        self.target_faces[last_index]["TKButton"] = tk.Button(self.found_faces_canvas, self.inactive_button_style, height = 86, width = 86)
+                        self.target_faces[last_index]["TKButton"].bind("<MouseWheel>", self.target_faces_mouse_wheel)
+                        self.target_faces[last_index]["ButtonState"] = False
+                        self.target_faces[last_index]["Image"] = ImageTk.PhotoImage(image=Image.fromarray(crop))
+                        self.target_faces[last_index]["Embedding"] = ret[i].embedding
+                        self.target_faces[last_index]["EmbeddingNumber"] = 1
+
+                        # Add image to button
+                        self.target_faces[-1]["TKButton"].config( pady = 10, image = self.target_faces[last_index]["Image"], command=lambda k=last_index:
self.toggle_found_faces_buttons_state(k)) + + # Add button to canvas + self.found_faces_canvas.create_window((last_index)*92, 8, window=self.target_faces[last_index]["TKButton"], anchor='nw') + + self.found_faces_canvas.configure(scrollregion = self.found_faces_canvas.bbox("all")) + + + def clear_faces(self): + self.target_faces = [] + + + + self.found_faces_canvas.delete("all") + + + # toggle the target faces button and make assignments + def toggle_found_faces_buttons_state(self, button): + # Turn all Target faces off + for i in range(len(self.target_faces)): + self.target_faces[i]["ButtonState"] = False + self.target_faces[i]["TKButton"].config(self.inactive_button_style) + + # Set only the selected target face to on + self.target_faces[button]["ButtonState"] = True + self.target_faces[button]["TKButton"].config(self.button_highlight_style) + + # set all source face buttons to off + for i in range(len(self.source_faces)): + self.source_faces[i]["ButtonState"] = False + self.source_faces[i]["TKButton"].config(self.inactive_button_style) + + # turn back on the ones that are assigned to the curent target face + for i in range(len(self.target_faces[button]["SourceFaceAssignments"])): + self.source_faces[self.target_faces[button]["SourceFaceAssignments"][i]]["ButtonState"] = True + self.source_faces[self.target_faces[button]["SourceFaceAssignments"][i]]["TKButton"].config(self.button_highlight_style) + + + + + def toggle_source_faces_buttons_state(self, event, button): + + # Set all other Source Face buttons to False + for i in range(len(self.source_faces)): + self.source_faces[i]["TKButton"].config(self.inactive_button_style) + if i != button: + self.source_faces[i]["ButtonState"] = False + + # Toggle the selected Source Face + self.source_faces[button]["ButtonState"] = not self.source_faces[button]["ButtonState"] + + # Determine which target face is selected + if self.target_faces: + for i in range(len(self.target_faces)): + if self.target_faces[i]["ButtonState"]: + + # Clear the assignments + self.target_faces[i]["SourceFaceAssignments"] = [] + + # Append new assignment if new state is True + if self.source_faces[button]["ButtonState"]: + self.target_faces[i]["SourceFaceAssignments"].append(button) + self.source_faces[button]["TKButton"].config(self.button_highlight_style) + + break + + + self.add_action_and_update_frame("target_faces", self.target_faces) + + def toggle_source_faces_buttons_state_shift(self, event, button): + + # Toggle the selected Source Face + self.source_faces[button]["ButtonState"] = not self.source_faces[button]["ButtonState"] + + if self.source_faces[button]["ButtonState"]: + self.source_faces[button]["TKButton"].config(self.button_highlight_style) + else: + self.source_faces[button]["TKButton"].config(self.inactive_button_style) + + # If a target face is selected + for i in range(len(self.target_faces)): + if self.target_faces[i]["ButtonState"]: + + # Clear all of the assignments + self.target_faces[i]["SourceFaceAssignments"] = [] + + # Iterate through all Source faces + for j in range(len(self.source_faces)): + + # If the source face is active + if self.source_faces[j]["ButtonState"]: + self.target_faces[i]["SourceFaceAssignments"].append(j) + + break + + self.add_action_and_update_frame("target_faces", self.target_faces) + + def populate_target_videos(self): + + self.target_videos_buttons = [] + self.target_videos = [] + self.target_video_canvas.delete("all") + + directory = self.json_dict["source videos"] + + filenames = os.listdir(directory) + + videos = [] + 
self.target_videos = [] + self.target_videos_buttons = [] + self.target_video_canvas.delete("all") + + for name in filenames: #should check if is an image + video_file = os.path.join(directory, name) + vidcap = cv2.VideoCapture(video_file) + vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)/2)) + success, image = vidcap.read() + if success: + crop = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + crop = cv2.resize( crop, (82, 82)) + temp = [crop, video_file] + videos.append(temp) + + for i in range(len(videos)): + self.target_videos_buttons.append(tk.Button(self.target_video_canvas, self.inactive_button_style, height = 86, width = 86)) + + for i in range(len(videos)): + rgb_video = Image.fromarray(videos[i][0]) + self.target_videos.append(ImageTk.PhotoImage(image=rgb_video)) + self.target_videos_buttons[i].config( image = self.target_videos[i], command=lambda i=i: self.load_target_video(i, videos[i][1])) + self.target_videos_buttons[i].bind("", self.target_videos_mouse_wheel) + self.target_video_canvas.create_window(i*92, 8, window = self.target_videos_buttons[i], anchor='nw') + + self.target_video_canvas.configure(scrollregion = self.target_video_canvas.bbox("all")) + + def load_target_video(self, button, video_file): + self.video_loaded = True + self.add_action_and_update_frame("load_target_video", video_file, False) + for i in range(len(self.target_videos_buttons)): + self.target_videos_buttons[i].config(self.inactive_button_style) + self.target_videos_buttons[button].config(self.button_highlight_style) + + if self.swap == True: + self.toggle_swapper() + if self.play_video == True: + self.toggle_play_video() + + self.clear_faces() + + + def set_image(self, image, requested): + self.video_image = image[0] + if not requested: + self.set_slider_position(image[1]) + # @profile + def display_image_in_video_frame(self): + + image = self.video_image + + x1 = float(self.x1) + y1 = float(self.y1 ) + x2 = float(image.shape[1]) + y2 = float(image.shape[0]) + + m1 = x1/y1 + m2 = x2/y2 + + if m2>m1: + x2 = x1 + y2 = x1/m2 + image = cv2.resize(image, (int(x2), int(y2))) + padding = int((y1-y2)/2.0) + image = cv2.copyMakeBorder( image, padding, padding, 0, 0, cv2.BORDER_CONSTANT) + else: + y2=y1 + x2=y2*m2 + image = cv2.resize(image, (int(x2), int(y2))) + padding=int((x1-x2)/2.0) + image = cv2.copyMakeBorder( image, 0, 0, padding, padding, cv2.BORDER_CONSTANT) + + image = Image.fromarray(image) + image = ImageTk.PhotoImage(image) + self.video.configure(image=image) + self.video.image = image + + def check_for_video_resize(self): + if self.x1 != self.video.winfo_width() or self.y1 != self.video.winfo_height(): + self.x1 = self.video.winfo_width() + self.y1 = self.video.winfo_height() + if np.any(self.video_image): + self.display_image_in_video_frame() + + def get_action(self): + action = self.action_q[0] + self.action_q.pop(0) + return action + + def get_action_length(self): + return len(self.action_q) + + + + def set_video_slider_length(self, video_length): + self.video_slider.configure(to=video_length) + + def set_slider_position(self, position): + self.video_slider.set(position) + + def findCosineDistance(self, vector1, vector2): + + vec1 = vector1.flatten() + vec2 = vector2.flatten() + + a = np.dot(vec1.T, vec2) + b = np.dot(vec1.T, vec1) + c = np.dot(vec2.T, vec2) + return 1 - (a/(np.sqrt(b)*np.sqrt(c))) + + def CosineSimilarity(self, test_vec, source_vecs): + + cos_dist = 0 + for source_vec in source_vecs: + cos_dist += self.findCosineDistance(test_vec, source_vec) + return 
cos_dist/len(source_vecs) + + + + def toggle_play_video(self): + if not self.video_loaded: + print("Please select video first!") + return + self.play_video = not self.play_video + + if self.play_video: + if self.rec_video: + if not self.json_dict["saved videos"]: + print("Set saved video folder first!") + self.play_video = False + self.add_action_and_update_frame("play_video", "stop", False) + self.video_play.config(self.inactive_button_style) + else: + self.add_action_and_update_frame("play_video", "record", False) + self.video_play.config(self.active_button_style) + else: + self.add_action_and_update_frame("play_video", "play", False) + self.video_play.config(self.active_button_style) + + else: + self.add_action_and_update_frame("play_video", "stop", False) + self.video_play.config(self.inactive_button_style) + if self.rec_video: + self.toggle_rec_video() + + + + def toggle_swapper(self): + self.swap = not self.swap + + if not self.swap: + self.video_swap.config(self.inactive_button_style) + else: + self.video_swap.config(self.active_button_style) + + if not self.play_video: + self.add_action_and_update_frame("swap", self.swap) + else: + self.add_action_and_update_frame("swap", self.swap, False) + + def toggle_rec_video(self): + if not self.play_video: + self.rec_video = not self.rec_video + + if self.rec_video == False: + self.video_record.config(self.inactive_button_style) + else: + self.video_record.config(self.active_button_style, bg='red') + + + + + + + def set_faceapp_model(self, faceapp): + self.faceapp_model = faceapp + + + + + def add_action_and_update_frame(self, action, parameter, update_frame=True): + + # Get values for self.parameters + if action == "parameters": + parameter = { + "GFPGANState": parameter["GFPGANState"], + "GFPGANAmount": parameter["GFPGANAmount"], + "DiffState": parameter["DiffState"], + "DiffAmount": parameter["DiffAmount"], + "Threshhold": parameter["Threshhold"], + "ThreshholdState": parameter["ThreshholdState"], + "MaskTop": parameter["MaskTop"], + "MaskSide": parameter["MaskSide"], + "MaskBlur": parameter["MaskBlur"], + "OccluderState": parameter["OccluderState"], + "CLIPState": parameter["CLIPState"], + "CLIPText": parameter["CLIPText"].get(), + "CLIPAmount": parameter["CLIPAmount"], + "FaceParserState": parameter["FaceParserState"], + "BlurAmount": parameter["BlurAmount"] + } + + # Send over action/parmeters tuple + temp = [action, parameter] + self.action_q.append(temp) + + # If the video is not playing and update_frame is true + if not self.play_video and update_frame: + temp = ["set_video_position", self.video_slider.get()] + self.action_q.append(temp) + + + + def toggle_dock(self): + self.dock = False + if not self.dock: + # self.video_frame.winfo_width() + self.grid_rowconfigure(0, weight = 0) + # self.geometry('%dx%d+%d+%d' % (800, 800, self.winfo_screenwidth()/2-400, self.winfo_screenheight()/2-400)) + self.geometry('%dx%d' % (self.winfo_width(), 458)) + self.resizable(width=True, height=False) + + + self.undock = self.wm_manage(self.video_frame) + + self.video_frame.config(width=1024, height=768) + self.video_frame.grid_propagate(0) + def set_status(self, msg): + self.status_label.configure(text=str(msg)) + self.status_label.pack() + + def mouse_wheel(self, event): + if event.delta > 0: + self.video_slider.set(self.video_slider.get()+1) + self.add_action_and_update_frame("set_video_position", self.video_slider.get(), False) + else: + self.video_slider.set(self.video_slider.get()-1) + self.add_action_and_update_frame("set_video_position", 
self.video_slider.get(), False)
+
+    def save_selected_source_faces(self, text):
+        # Average the embeddings of all currently selected source faces and append
+        # the result to merged_embeddings.txt under the entered name
+        temp = 0
+        temp_len = 0
+        temp_data = False
+        if text.get() != "":
+            for i in range(len(self.source_faces)):
+                if self.source_faces[i]["ButtonState"]:
+                    temp_data = True
+                    temp = temp + self.source_faces[i]["Embedding"]
+                    temp_len += 1
+
+            if temp_data:
+                temp = temp / temp_len
+
+        if temp_data:
+            with open("merged_embeddings.txt", "a") as embedfile:
+                identifier = "Name: "+text.get()
+                embedfile.write("%s\n" % identifier)
+                for number in temp:
+                    embedfile.write("%s\n" % number)
+
+            self.load_source_faces()
+
+
+    def delete_merged_embedding(self): #add multi select
+
+        # get selected button
+        sel = -1
+        for j in range(len(self.source_faces)):
+            if self.source_faces[j]["ButtonState"]:
+                sel = j
+                break
+
+        # check if it is a merged embedding
+        # if so, read the saved embeddings into a list (name line + 512 value lines per entry)
+        temp0 = []
+        if os.path.exists("merged_embeddings.txt"):
+
+            with open("merged_embeddings.txt", "r") as embedfile:
+                temp = embedfile.read().splitlines()
+
+            for i in range(0, len(temp), 513):
+                to = [temp[i], np.array(temp[i+1:i+513], dtype='float32')]
+                temp0.append(to)
+
+        if 0 <= sel < len(temp0):
+            temp0.pop(sel)
+
+            with open("merged_embeddings.txt", "w") as embedfile:
+                for line in temp0:
+                    embedfile.write("%s\n" % line[0])
+                    for i in range(512):
+                        embedfile.write("%s\n" % line[1][i])
+
+            self.load_source_faces()
+
+    def iterate_through_merged_embeddings(self, event):
+        # wheel-up selects the next source face, wheel-down the previous one
+        if event.delta > 0:
+            for i in range(len(self.source_faces)):
+                if self.source_faces[i]["ButtonState"] and i < len(self.source_faces)-1:
+                    self.toggle_source_faces_buttons_state(None, i+1)
+                    break
+        else:
+            for i in range(len(self.source_faces)):
+                if self.source_faces[i]["ButtonState"] and i > 0:
+                    self.toggle_source_faces_buttons_state(None, i-1)
+                    break
+
+    def toggle_GFPGAN(self):
+        self.parameters["GFPGANState"] = not self.parameters["GFPGANState"]
+
+        if self.parameters["GFPGANState"]:
+            self.GFPGAN_button.config(self.active_button_style)
+        else:
+            self.GFPGAN_button.config(self.inactive_button_style)
+
+        self.add_action_and_update_frame("parameters", self.parameters)
+
+    def change_GFPGAN_amount(self, event):
+        self.parameters["GFPGANAmount"] += (5*int(event.delta/120.0))
+        if self.parameters["GFPGANAmount"] > 100:
+            self.parameters["GFPGANAmount"] = 100
+        if self.parameters["GFPGANAmount"] < 0:
+            self.parameters["GFPGANAmount"] = 0
+
+        temp = ' ' + str(int(self.parameters["GFPGANAmount"])) + '%'
+        self.GFPGAN_button.config(text=temp)
+
+        self.add_action_and_update_frame("parameters", self.parameters)
+
+    def toggle_differ(self):
+        self.parameters["DiffState"] = not self.parameters["DiffState"]
+
+        if self.parameters["DiffState"]:
+            self.differ_button.config(self.active_button_style)
+        else:
+            self.differ_button.config(self.inactive_button_style)
+
+        self.add_action_and_update_frame("parameters", self.parameters)
+
+    def change_differ_amount(self, event):
+        self.parameters["DiffAmount"] += (0.5*int(event.delta/120.0))
+        if self.parameters["DiffAmount"] > 10:
+            self.parameters["DiffAmount"] = 10
+        if self.parameters["DiffAmount"] < 0:
+            self.parameters["DiffAmount"] = 0
+
+        temp = ' Differ ' + str(int(self.parameters["DiffAmount"]*10)) + '%'
+        self.differ_button.config(text=temp)
+
+        self.add_action_and_update_frame("parameters", self.parameters)
+
+    def toggle_threshhold(self):
+        self.parameters["ThreshholdState"] = not self.parameters["ThreshholdState"]
+
+        if
self.parameters["ThreshholdState"]: + self.threshhold_button.config(self.active_button_style) + else: + self.threshhold_button.config(self.inactive_button_style) + + self.add_action_and_update_frame("parameters", self.parameters) + + def change_threshhold_amount(self, event): + self.parameters["Threshhold"] += (0.01*int(event.delta/120.0)) + if self.parameters["Threshhold"] > 1: + self.parameters["Threshhold"] = 1 + if self.parameters["Threshhold"] < 0 : + self.parameters["Threshhold"] = 0 + + if self.parameters["Threshhold"] >= 1: + temp = ' Threshhold' + str(int(self.parameters["Threshhold"]*100)) + '%' + else: + temp = ' Threshhold ' + str(int(self.parameters["Threshhold"]*100)) + '%' + + self.threshhold_button.config(text=temp) + + self.add_action_and_update_frame("parameters", self.parameters) + + def change_mask_top_amount(self, event): + + self.parameters["MaskTop"] += (1*int(event.delta/120.0)) + if self.parameters["MaskTop"] > 64: + self.parameters["MaskTop"] = 64 + if self.parameters["MaskTop"] < 0 : + self.parameters["MaskTop"] = 0 + + if self.parameters["MaskTop"] >= 64: + temp = ' Top Mask ' + str(int(self.parameters["MaskTop"]*100.0/64.0)) + '%' + else: + temp = ' Top Mask ' + str(int(self.parameters["MaskTop"]*100.0/64.0)) + '%' + + self.top_blend_id.config(text=temp) + + self.add_action_and_update_frame("parameters", self.parameters) + + def change_mask_side_amount(self, event): + + self.parameters["MaskSide"] += (1*int(event.delta/120.0)) + if self.parameters["MaskSide"] > 64: + self.parameters["MaskSide"] = 64 + if self.parameters["MaskSide"] < 0 : + self.parameters["MaskSide"] = 0 + + if self.parameters["MaskSide"] >= 64: + temp = ' Side Mask ' + str(int(self.parameters["MaskSide"]*100.0/64.0)) + '%' + else: + temp = ' Side Mask ' + str(int(self.parameters["MaskSide"]*100.0/64.0)) + '%' + + self.side_blend_id.config(text=temp) + + self.add_action_and_update_frame("parameters", self.parameters) + + def change_mask_blur_amount(self, event): + + self.parameters["MaskBlur"] += (1*int(event.delta/120.0)) + if self.parameters["MaskBlur"] > 30: + self.parameters["MaskBlur"] = 30 + if self.parameters["MaskBlur"] < 0 : + self.parameters["MaskBlur"] = 0 + + temp_num = str(int(self.parameters["MaskBlur"]*100.0/30.0)) + temp_num_len = 4-len(temp_num) + + temp = ' Mask Blur ' + ' '*temp_num_len + temp_num + '%' + + self.mask_blur_id.config(text=temp) + + self.add_action_and_update_frame("parameters", self.parameters) + + def toggle_CLIP(self): + self.parameters["CLIPState"] = not self.parameters["CLIPState"] + + if self.parameters["CLIPState"]: + self.CLIP_button.config(self.active_button_style) + else: + self.CLIP_button.config(self.inactive_button_style) + + self.add_action_and_update_frame("parameters", self.parameters) + + def change_CLIP_amount(self, event): + self.parameters["CLIPAmount"] += (0.01*int(event.delta/120.0)) + if self.parameters["CLIPAmount"] > 1: + self.parameters["CLIPAmount"] = 1 + if self.parameters["CLIPAmount"] < 0 : + self.parameters["CLIPAmount"] = 0 + + if self.parameters["CLIPAmount"] >= 1: + temp = ' CLIP ' + str(int(self.parameters["CLIPAmount"]*100)) + '%' + else: + temp = ' CLIP ' + str(int(self.parameters["CLIPAmount"]*100)) + '%' + + self.CLIP_button.config(text=temp) + + self.add_action_and_update_frame("parameters", self.parameters) + + def toggle_occluder(self): + self.parameters["OccluderState"] = not self.parameters["OccluderState"] + + if self.parameters["OccluderState"]: + self.occluder_button.config(self.active_button_style) + else: + 
self.occluder_button.config(self.inactive_button_style) + + self.add_action_and_update_frame("parameters", self.parameters) + + def toggle_parser(self): + self.parameters["FaceParserState"] = not self.parameters["FaceParserState"] + + if self.parameters["FaceParserState"]: + self.parser_button.config(self.active_button_style) + else: + self.parser_button.config(self.inactive_button_style) + + self.add_action_and_update_frame("parameters", self.parameters) + + def change_blur_amount(self, event): + + self.parameters["BlurAmount"] += (1*int(event.delta/120.0)) + if self.parameters["BlurAmount"] > 64: + self.parameters["BlurAmount"] = 64 + if self.parameters["BlurAmount"] < 0 : + self.parameters["BlurAmount"] = 0 + + if self.parameters["BlurAmount"] >= 10: + temp = ' Blur ' + str(int(self.parameters["BlurAmount"]*100.0/64.0)) + '%' + else: + temp = ' Blur ' + str(int(self.parameters["BlurAmount"]*100.0/64.0)) + '%' + + self.blur_id.config(text=temp) + + self.add_action_and_update_frame("parameters", self.parameters) + + def change_video_quality(self, event): + self.video_quality += (1*int(event.delta/120.0)) + + if self.video_quality > 50: + self.video_quality = 50 + if self.video_quality < 0 : + self.video_quality = 0 + + temp = ' Video Quality ' + str(self.video_quality) + + self.vid_qual_button.config(text=temp) + + self.add_action_and_update_frame("vid_qual",int(self.video_quality), False) + + def change_threads_amount(self, event): + self.num_threads += (1*int(event.delta/120.0)) + + if self.num_threads > 10: + self.num_threads = 10 + if self.num_threads < 1: + self.num_threads = 1 + + temp = ' Threads ' + str(self.num_threads) + + self.num_threads_id.config(text=temp) + + self.add_action_and_update_frame("num_threads",int(self.num_threads), False) + + self.json_dict["threads"] = self.num_threads + with open("data.json", "w") as outfile: + json.dump(self.json_dict, outfile) + + + # https://discord.gg/EcdVAFJzqp diff --git a/rope/VideoManager.py b/rope/VideoManager.py new file mode 100644 index 0000000..f81fe11 --- /dev/null +++ b/rope/VideoManager.py @@ -0,0 +1,766 @@ +import os +import cv2 +import tkinter as tk +from PIL import Image, ImageTk +import threading +import time +import numpy as np +from numpy.linalg import norm as l2norm +from skimage import transform as trans +from insightface.utils.face_align import norm_crop2 +import subprocess +from math import floor, ceil + +import torch +import requests +from PIL import Image +from torchvision import transforms +import json +import math + + + +# from itertools import combinations + +lock=threading.Lock() + +class VideoManager(): + def __init__( self ): + # Model related + self.swapper_model = [] # insightface swapper model + self.faceapp_model = [] # insight faceapp model + self.input_names = [] # names of the inswapper.onnx inputs + self.input_size = [] # size of the inswapper.onnx inputs + self.emap = [] # comes from loading the inswapper model. 
not sure of data + self.output_names = [] # names of the inswapper.onnx outputs + self.arcface_dst = np.array( [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32) + self.GFPGAN_model = [] + self.occluder_model = [] + self.occluder_tensor = [] + self.face_parsing_model = [] + self.face_parsing_tensor = [] + + #Video related + self.capture = [] # cv2 video + self.is_video_loaded = False # flag for video loaded state + self.video_frame_total = None # length of currently loaded video + self.play = False # flag for the play button toggle + self.current_frame = 0 # the current frame of the video + self.create_video = False + self.output_video = [] + self.file_name = [] + self.vid_qual = [] + + # Play related + self.set_read_threads = [] # Name of threaded function + self.frame_timer = time.time() # used to set the framerate during playing + self.play_frame_tracker = -1 # tracks the next frame during playing in case the threads return out of order + + # Queues + self.action_q = [] # queue for sending to the coordinator + self.frame_q = [] # queue for frames that are ready for coordinator + self.frame_q2 = [] # queue for frames created by thread and ready to be added to frame_q + self.r_frame_q = [] # queue for frames that are requested by the GUI + self.read_video_frame_q = [] + + # swapping related + self.source_embedding = [] # array with indexed source embeddings + self.swap = False # flag for the swap enabled toggle + self.found_faces_assignments = [] # array that maps the found faces to source faces + + self.parameters = [] + + self.num_threads = 0 + self.target_video = [] + + self.fps = 1.0 + self.temp_file = [] + + self.i_image = [] + self.io_binding = False + self.video_read_success = False + self.clip_session = [] + self.cuda_device = [] + + self.start_time = [] + self.record = False + self.output = [] + + self.saved_video_path = [] + self.sp = [] + self.timer = [] + + self.clip_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.Resize((256, 256))]) + + self.arcface_dst_max = [] + self.arcface_dst_max.append( math.sqrt(( self.arcface_dst[0][0]- self.arcface_dst[1][0])*( self.arcface_dst[0][0]- self.arcface_dst[1][0]) + ( self.arcface_dst[0][1]- self.arcface_dst[1][1])*( self.arcface_dst[0][1]- self.arcface_dst[1][1])) ) + self.arcface_dst_max.append( math.sqrt(( self.arcface_dst[1][0]- self.arcface_dst[4][0])*( self.arcface_dst[1][0]- self.arcface_dst[4][0]) + ( self.arcface_dst[1][1]- self.arcface_dst[4][1])*( self.arcface_dst[1][1]- self.arcface_dst[4][1])) ) + self.arcface_dst_max.append( math.sqrt(( self.arcface_dst[3][0]- self.arcface_dst[4][0])*( self.arcface_dst[3][0]- self.arcface_dst[4][0]) + ( self.arcface_dst[3][1]- self.arcface_dst[4][1])*( self.arcface_dst[3][1]- self.arcface_dst[4][1])) ) + self.arcface_dst_max.append( math.sqrt(( self.arcface_dst[0][0]- self.arcface_dst[3][0])*( self.arcface_dst[0][0]- self.arcface_dst[3][0]) + ( self.arcface_dst[0][1]- self.arcface_dst[3][1])*( self.arcface_dst[0][1]- self.arcface_dst[3][1])) ) + self.arcface_dst_max.append( math.sqrt(( self.arcface_dst[0][0]- self.arcface_dst[4][0])*( self.arcface_dst[0][0]- self.arcface_dst[4][0]) + ( self.arcface_dst[0][1]- self.arcface_dst[4][1])*( self.arcface_dst[0][1]- self.arcface_dst[4][1])) ) + self.arcface_dst_max.append( math.sqrt(( self.arcface_dst[1][0]- self.arcface_dst[3][0])*( self.arcface_dst[1][0]- self.arcface_dst[3][0]) + ( 
self.arcface_dst[1][1]- self.arcface_dst[3][1])*( self.arcface_dst[1][1]- self.arcface_dst[3][1])) ) + + def load_target_video( self, file ): + # If we already have a video loaded, release it + if self.capture: + self.capture.release() + + # Open file + self.capture = cv2.VideoCapture(file) + self.fps = self.capture.get(cv2.CAP_PROP_FPS) + # print(self.fps) + + + if not self.capture.isOpened(): + print("Cannot open file: ", file) + exit() + else: + self.target_video = file + self.is_video_loaded = True + self.video_frame_total = int(self.capture.get(cv2.CAP_PROP_FRAME_COUNT)) + self.play = False + self.current_frame = 0 + + self.set_read_threads = [] + self.frame_timer = time.time() + self.play_frame_tracker = 0 + + self.frame_q = [] + self.frame_q2 = [] + self.r_frame_q = [] + + self.swap = False + self.found_faces_assignments = [] + + self.add_action("set_slider_length",self.video_frame_total-1) + + self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame) + + success, image = self.capture.read() + if success: + crop = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + temp = [crop, 0] + self.frame_q.append(temp) + self.capture.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame) + + ## Action queue + def add_action(self, action, param): + temp = [action, param] + self.action_q.append(temp) + + def get_action_length(self): + return len(self.action_q) + + def get_action(self): + action = self.action_q[0] + self.action_q.pop(0) + return action + + ## Queues for the Coordinator + def get_frame(self): + frame = self.frame_q[0] + self.frame_q.pop(0) + return frame + + def get_frame_length(self): + return len(self.frame_q) + + def get_requested_frame(self): + frame = self.r_frame_q[0] + self.r_frame_q.pop(0) + return frame + + def get_requested_frame_length(self): + return len(self.r_frame_q) + + + def get_requested_video_frame(self, frame): + if self.is_video_loaded == True: + self.play_video(False) + self.current_frame = int(frame) + self.capture.set(cv2.CAP_PROP_POS_FRAMES, min(self.video_frame_total, self.current_frame)) + success, target_image = self.capture.read() + self.capture.set(cv2.CAP_PROP_POS_FRAMES, min(self.video_frame_total, self.current_frame)) + if success: + # target_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + if not self.swap: + temp = [target_image, self.current_frame] + else: + temp = [self.swap_video(target_image), self.current_frame] + temp[0] = cv2.cvtColor(temp[0], cv2.COLOR_BGR2RGB) + self.r_frame_q.append(temp) + + def play_video(self, command): #"record", "play", "stop" + if command == "play": + self.play = True + self.play_frame_tracker = self.current_frame + # self.capture.set(cv2.CAP_PROP_POS_FRAMES, min(self.video_frame_total, self.current_frame)) + + if command == "stop": + self.play = False + if command == "record": + self.record = True + self.play = True + # Initialize + self.timer = time.time() + frame_width = int(self.capture.get(3)) + frame_height = int(self.capture.get(4)) + frame_size = (frame_width,frame_height) + + + self.play_frame_tracker = self.current_frame + # self.start_time = self.capture.get(cv2.CAP_PROP_POS_MSEC)/1000.0 + self.start_time = float(self.capture.get(cv2.CAP_PROP_POS_FRAMES) / float(self.fps)) + + + self.file_name = os.path.splitext(os.path.basename(self.target_video)) + + base_filename = self.file_name[0]+"_"+str(time.time())[:10] + + self.output = os.path.join(self.saved_video_path, base_filename) + + self.temp_file = self.output+"_temp"+self.file_name[1] + + data = subprocess.run(['ffprobe', '-loglevel', 'error', '-show_streams', 
'-of', 'json', f'{self.target_video}'], capture_output=True).stdout + d = json.loads(data) + + # if d['streams'][0]['codec_name'] =='vp9': + # args = ["ffmpeg", + # "-an", + # "-r", str(self.fps), + # "-i", "pipe:", + # "-vf", "format="+d['streams'][0]['pix_fmt'], + # "-vcodec", d['streams'][0]['codec_name'], + # "-r", str(self.fps), + # "-s", str(frame_width)+"x"+str(frame_height), + # final_file] + + args = ["ffmpeg", + "-an", + "-r", str(self.fps), + "-i", "pipe:", + "-vf", "format=yuvj420p", + "-c:v", "libx264", + "-crf", str(self.vid_qual), + "-r", str(self.fps), + "-s", str(frame_width)+"x"+str(frame_height), + self.temp_file] + + + + self.sp = subprocess.Popen(args, stdin=subprocess.PIPE) + + + + + + def process(self): + + if len(self.set_read_threads) != self.num_threads: + self.set_read_threads = [[0] * 4 for i in range(self.num_threads)] + + # Add threads to Queue + if self.play == True and self.is_video_loaded == True: + for i in range(self.num_threads): + if self.set_read_threads[i][3] == 0: + self.set_read_threads [i] = [threading.Thread(target=self.thread_video_read, args = [self.current_frame]).start(), 0, self.current_frame, 1] + self.current_frame += 1 + break + + else: + self.play == False + + # Always be emptying the queues + time_diff = time.time() - self.frame_timer + + if not self.record and time_diff >= 1.0/float(self.fps): + for i in range(self.num_threads): + if self.set_read_threads[i][3] == 2 and self.set_read_threads[i][2] == self.play_frame_tracker: + + temp = [self.set_read_threads[i][1], self.set_read_threads[i][2]] + self.frame_q.append(temp) + fps = round(1.0/time_diff, 1) + msg = "Playing at %s fps" %fps + self.add_action("send_msg", msg) + self.play_frame_tracker += 1 + self.set_read_threads[i][3] = 0 + self.frame_timer = time.time() + break + + elif self.record: + empty_count = 0 + for i in range(self.num_threads): + # print(self.set_read_threads[i][3]) + if self.set_read_threads[i][3] == 2 and self.set_read_threads[i][2] == self.play_frame_tracker: + + temp = [self.set_read_threads[i][1], self.set_read_threads[i][2]] + self.frame_q.append(temp) + + image = self.set_read_threads[i][1] + pil_image = Image.fromarray(image) + pil_image.save(self.sp.stdin, 'JPEG') + framen = self.play_frame_tracker + msg = "Rendering frame %s/%s" %(framen, self.video_frame_total-1) + + self.play_frame_tracker += 1 + self.set_read_threads[i][3] = 0 + self.frame_timer = time.time() + break + elif self.set_read_threads[i][3] == 0: + empty_count = empty_count + 1 + + if empty_count == self.num_threads: + # Close video and process + + stop_time = float(self.capture.get(cv2.CAP_PROP_POS_FRAMES) / float(self.fps)) + if stop_time == 0: + stop_time = float(self.video_frame_total) / float(self.fps) + + self.sp.stdin.close() + self.sp.wait() + + orig_file = self.target_video + final_file = self.output+self.file_name[1] + self.add_action("send_msg", "adding audio...") + args = ["ffmpeg", + "-i", self.temp_file, + "-ss", str(self.start_time), "-to", str(stop_time), "-i", orig_file, + "-c", "copy", # may be c:v + "-map", "0:v:0", "-map", "1:a:0?", + "-shortest", + final_file] + + four = subprocess.run(args) + + os.remove(self.temp_file) + + self.record = False + timef= time.time() - self.timer + msg = "done...total rendering time: %s seconds" % round(timef,1) + self.add_action("send_msg", msg) + + def thread_video_read(self, frame_number): + # frame_timer = time.time() + + with lock: + success, target_image = self.capture.read() + + if success: + if not self.swap: + temp = [target_image, 
frame_number] + else: + temp = [self.swap_video(target_image), frame_number] + temp[0] = cv2.cvtColor(temp[0], cv2.COLOR_BGR2RGB) + for i in range(len(self.set_read_threads)): + if self.set_read_threads[i][2] == frame_number: + self.set_read_threads[i][1] = temp[0] + self.set_read_threads[i][3] = 2 + break + + else: + for i in range(len(self.set_read_threads)): + if self.set_read_threads[i][2] == frame_number: + self.set_read_threads[i][3] = 0 + break + + self.play = False + self.add_action("stop_play", True) + # time_diff = time.time() - frame_timer + # print( time_diff) + + def load_source_embeddings(self, source_embeddings): + self.source_embedding = [] + for i in range(len(source_embeddings)): + self.source_embedding.append(source_embeddings[i]["Embedding"]) + + def swap_set(self, swap): + self.swap = swap + # self.get_video_frame(self.current_frame) + + def set_swapper_model(self, swapper, emap): + self.swapper_model = swapper + self.emap = emap + + # Get in/out size and create some data + inputs = self.swapper_model.get_inputs() + for inp in inputs: + self.input_names.append(inp.name) + input_cfg = inputs[0] + input_shape = input_cfg.shape + self.input_size = tuple(input_shape[2:4][::-1]) + + outputs = self.swapper_model.get_outputs() + for out in outputs: + self.output_names.append(out.name) + + + + def set_faceapp_model(self, faceapp): + self.faceapp_model = faceapp + + def swap_video(self, target_image): + # Find faces, returns all faces + ret = self.faceapp_model.get(target_image, max_num=10) + if ret: + img = target_image + target_face = ret + + # Loop through target faces to see if they match our target embeddings + for i in range(len(target_face)): + for j in range(len(self.found_faces_assignments)): + # sim between face in video and already found face + sim = self.findCosineDistance(target_face[i].embedding, self.found_faces_assignments[j]["Embedding"]) + + # if the face[i] in the frame matches afound face[j] AND the found face is active (not []) + if self.parameters["ThreshholdState"]: + threshhold = 2.0 + else: + threshhold = self.parameters["Threshhold"] + + if simimg.shape[1]: + right=img.shape[1] + + bottom = ceil(bbox[3]) + if bottom>img.shape[0]: + bottom=img.shape[0] + + swapped_face_upscaled = swapped_face_upscaled[top:bottom, left:right, 0:3].astype(np.float32) + img_a = img[top:bottom, left:right, 0:3].astype(np.float32) + + img_mask = cv2.warpAffine(img_mask, IM512, (img.shape[1], img.shape[0]), borderValue=0.0) + img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1]) + img_mask = img_mask[top:bottom, left:right, 0:1] + + img_mask = 1.0-img_mask + + swapped_face_upscaled += img_mask*img_a + + img[top:bottom, left:right, 0:3] = swapped_face_upscaled + + + return img.astype(np.uint8) #BGR + + + + + def apply_occlusion(self, img): + data = self.occluder_tensor(img).unsqueeze(0) + data = data.to('cuda') + with lock: + with torch.no_grad(): + pred = self.occluder_model(data) + occlude_mask = (pred > 0).type(torch.float32) + occlude_mask = occlude_mask.squeeze().cpu().numpy()*1.0 + + return occlude_mask + + + def apply_neg_CLIPs(self, img): + clip_mask = np.ones((256, 256)) + CLIPimg = self.clip_transform(img).unsqueeze(0) + + if self.parameters["CLIPText"] != "": + prompts = self.parameters["CLIPText"].split(',') + + with lock: + with torch.no_grad(): + preds = self.clip_session(CLIPimg.repeat(len(prompts),1,1,1), prompts)[0] + + clip_mask = 1 - torch.sigmoid(preds[0][0]) + for i in range(len(prompts)-1): + clip_mask *= 1-torch.sigmoid(preds[i+1][0]) 
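# Illustrative sketch only, not part of this diff: the loop above folds the per-prompt
# CLIPSeg logits into a single keep-mask by multiplying the "not this prompt"
# probabilities, and the lines just below binarize that mask with the CLIPAmount slider.
# A minimal standalone version of the same combine-and-threshold step; the helper name
# combine_neg_clip_masks and the random logits are assumptions for the sketch:
import torch

def combine_neg_clip_masks(preds, amount):
    # preds: [num_prompts, H, W] raw logits, one map per comma-separated prompt
    keep = torch.ones_like(preds[0])
    for p in preds:
        keep = keep * (1 - torch.sigmoid(p))   # low where any prompt matches
    keep[keep > amount] = 1.0                  # 1 = outside all prompt regions
    keep[keep <= amount] = 0.0                 # 0 = matched by a prompt
    return keep

example_mask = combine_neg_clip_masks(torch.randn(2, 256, 256), amount=0.5)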
+ clip_mask = clip_mask.data.cpu().numpy() + + clip_mask[clip_mask>self.parameters["CLIPAmount"]] = 1.0 + clip_mask[clip_mask<=self.parameters["CLIPAmount"]] = 0.0 + + return clip_mask + + def apply_face_parser(self, img): + + # atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat'] + + + with lock: + with torch.no_grad(): + img1 = self.face_parsing_tensor(img.astype(np.uint8)) + img1 = torch.unsqueeze(img1, 0) + img1 = img1.cuda() + out = self.face_parsing_model(img1)[0] + parsing = out.squeeze(0).cpu().numpy().argmax(0) + + vis_parsing_anno = parsing.copy().astype(np.uint8) + vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=1, fy=1, interpolation=cv2.INTER_NEAREST) + vis_parsing_anno_color = np.ones((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1])) + + index = np.where((vis_parsing_anno == 11) | (vis_parsing_anno == 12) | (vis_parsing_anno == 13)) + # index = np.where(vis_parsing_anno == 11) + vis_parsing_anno_color[index[0], index[1]] = 0.0 + # kernel = np.ones((2, 2)) + # vis_parsing_anno_color = cv2.erode(vis_parsing_anno_color, kernel, iterations=10) + return vis_parsing_anno_color + + def apply_face_parser_nose(self, img): + + # atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat'] + + + with lock: + with torch.no_grad(): + img1 = self.face_parsing_tensor(img) + img1 = torch.unsqueeze(img1, 0) + img1 = img1.cuda() + out = self.face_parsing_model(img1)[0] + parsing = out.squeeze(0).cpu().numpy().argmax(0) + + vis_parsing_anno = parsing.copy().astype(np.uint8) + vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=1, fy=1, interpolation=cv2.INTER_NEAREST) + vis_parsing_anno_color = np.ones((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1])) + + index = np.where((vis_parsing_anno == 10) | (vis_parsing_anno == 17)) + vis_parsing_anno_color[index[0], index[1]] = 0.0 + + return vis_parsing_anno_color + + def apply_GFPGAN(self, swapped_face_upscaled): + + + temp = swapped_face_upscaled + + # preprocess + # temp = cv2.resize(temp, (512, 512)) + temp = temp / 255.0 + # temp = temp.astype('float32') + temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB) + temp[:,:,0] = (temp[:,:,0]-0.5)/0.5 + temp[:,:,1] = (temp[:,:,1]-0.5)/0.5 + temp[:,:,2] = (temp[:,:,2]-0.5)/0.5 + temp = np.float32(temp[np.newaxis,:,:,:]) + temp = temp.transpose(0, 3, 1, 2) + + ort_inputs = {"input": temp} + if self.io_binding: + io_binding = self.GFPGAN_model.io_binding() + io_binding.bind_cpu_input("input", temp) + io_binding.bind_output("1288", "cuda") + + self.GFPGAN_model.run_with_iobinding(io_binding) + ort_outs = io_binding.copy_outputs_to_cpu() + else: + + ort_outs = self.GFPGAN_model.run(None, ort_inputs) + + output = ort_outs[0][0] + + # postprocess + output = output.clip(-1,1) + output = (output + 1) / 2 + output = output.transpose(1, 2, 0) + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + output = (output * 255.0).round() + + + + + + temp2 = float(self.parameters["GFPGANAmount"])/100.0 + swapped_face_upscaled = cv2.addWeighted(output, temp2, swapped_face_upscaled, 1.0-temp2,0) + + return swapped_face_upscaled + + def apply_fake_diff(self, swapped_face, original_face): + fake_diff = swapped_face.astype(np.float32) - original_face.astype(np.float32) + fake_diff = 
np.abs(fake_diff).mean(axis=2) + fake_diff[:2,:] = 0 + fake_diff[-2:,:] = 0 + fake_diff[:,:2] = 0 + fake_diff[:,-2:] = 0 + + fthresh = int(self.parameters["DiffAmount"]) + fake_diff[fake_diff<fthresh] = 0 + fake_diff[fake_diff>=fthresh] = 255 + + return fake_diff + + + + + # @profile diff --git a/rope/__pycache__/Coordinator.cpython-310.pyc b/rope/__pycache__/Coordinator.cpython-310.pyc new file mode 100644 index 0000000..a37f0b5 Binary files /dev/null and b/rope/__pycache__/Coordinator.cpython-310.pyc differ diff --git a/rope/__pycache__/GUI.cpython-310.pyc b/rope/__pycache__/GUI.cpython-310.pyc new file mode 100644 index 0000000..8fa93e5 Binary files /dev/null and b/rope/__pycache__/GUI.cpython-310.pyc differ diff --git a/rope/__pycache__/VideoManager.cpython-310.pyc b/rope/__pycache__/VideoManager.cpython-310.pyc new file mode 100644 index 0000000..0dff44d Binary files /dev/null and b/rope/__pycache__/VideoManager.cpython-310.pyc differ diff --git a/rope/__pycache__/core_working.cpython-310.pyc b/rope/__pycache__/core_working.cpython-310.pyc new file mode 100644 index 0000000..68cc0e2 Binary files /dev/null and b/rope/__pycache__/core_working.cpython-310.pyc differ diff --git a/rope/external/__pycache__/clipseg.cpython-310.pyc b/rope/external/__pycache__/clipseg.cpython-310.pyc new file mode 100644 index 0000000..5a87675 Binary files /dev/null and b/rope/external/__pycache__/clipseg.cpython-310.pyc differ diff --git a/rope/external/__pycache__/model.cpython-310.pyc b/rope/external/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000..ece0232 Binary files /dev/null and b/rope/external/__pycache__/model.cpython-310.pyc differ diff --git a/rope/external/__pycache__/resnet.cpython-310.pyc b/rope/external/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000..1466d3b Binary files /dev/null and b/rope/external/__pycache__/resnet.cpython-310.pyc differ diff --git a/rope/external/cliplib/__init__.py b/rope/external/cliplib/__init__.py new file mode 100644 index 0000000..dcc5619 --- /dev/null +++ b/rope/external/cliplib/__init__.py @@ -0,0 +1 @@ +from .clip import * diff --git a/rope/external/cliplib/__pycache__/__init__.cpython-310.pyc b/rope/external/cliplib/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..ee34a7e Binary files /dev/null and b/rope/external/cliplib/__pycache__/__init__.cpython-310.pyc differ diff --git a/rope/external/cliplib/__pycache__/clip.cpython-310.pyc b/rope/external/cliplib/__pycache__/clip.cpython-310.pyc new file mode 100644 index 0000000..1ac97de Binary files /dev/null and b/rope/external/cliplib/__pycache__/clip.cpython-310.pyc differ diff --git a/rope/external/cliplib/__pycache__/model.cpython-310.pyc b/rope/external/cliplib/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000..da27328 Binary files /dev/null and b/rope/external/cliplib/__pycache__/model.cpython-310.pyc differ diff --git a/rope/external/cliplib/__pycache__/simple_tokenizer.cpython-310.pyc b/rope/external/cliplib/__pycache__/simple_tokenizer.cpython-310.pyc new file mode 100644 index 0000000..900707b Binary files /dev/null and b/rope/external/cliplib/__pycache__/simple_tokenizer.cpython-310.pyc differ diff --git a/rope/external/cliplib/bpe_simple_vocab_16e6.txt.gz b/rope/external/cliplib/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000..7b5088a Binary files /dev/null and b/rope/external/cliplib/bpe_simple_vocab_16e6.txt.gz differ diff --git a/rope/external/cliplib/clip.py b/rope/external/cliplib/clip.py new file mode 100644 index 0000000..f7a5da5 ---
/dev/null +++ b/rope/external/cliplib/clip.py @@ -0,0 +1,245 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List +from pkg_resources import packaging + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 
0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def _node_get(node: torch._C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type. 
+ + From https://github.com/pytorch/pytorch/pull/82628 + """ + sel = node.kindOf(key) + return getattr(node, sel)(key) + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if _node_get(inputs[i].node(), "value") == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/rope/external/cliplib/model.py b/rope/external/cliplib/model.py new file mode 100644 index 0000000..232b779 --- /dev/null +++ b/rope/external/cliplib/model.py @@ -0,0 +1,436 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + 
k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + 
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * 
np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in 
["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/rope/external/cliplib/simple_tokenizer.py b/rope/external/cliplib/simple_tokenizer.py new file mode 100644 index 0000000..0a66286 --- /dev/null +++ b/rope/external/cliplib/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'</w>' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '</w>',) + pairs = get_pairs(word) + + if not pairs: + return token+'</w>' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') + return text diff --git a/rope/external/clipseg.py b/rope/external/clipseg.py new file mode 100644 index 0000000..ca47a65 --- /dev/null +++ b/rope/external/clipseg.py @@ -0,0 +1,538 @@ +import math +from os.path import basename, dirname, join, isfile +import torch +from torch import nn +from torch.nn import functional as nnf +from torch.nn.modules.activation import ReLU + + +def get_prompt_list(prompt): + if prompt == 'plain': + return ['{}'] + elif prompt == 'fixed': + return ['a photo of a {}.'] + elif prompt == 'shuffle': + return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] + elif prompt == 'shuffle+': + return ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.', + 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.', + 'a bad photo of a {}.', 'a photo of the {}.'] + else: + raise ValueError('Invalid
value for prompt') + + +def forward_multihead_attention(x, b, with_aff=False, attn_mask=None): + """ + Simplified version of multihead attention (taken from torch source code but without tons of if clauses). + The mlp and layer norm come from CLIP. + x: input. + b: multihead attention module. + """ + + x_ = b.ln_1(x) + q, k, v = nnf.linear(x_, b.attn.in_proj_weight, b.attn.in_proj_bias).chunk(3, dim=-1) + tgt_len, bsz, embed_dim = q.size() + + head_dim = embed_dim // b.attn.num_heads + scaling = float(head_dim) ** -0.5 + + q = q.contiguous().view(tgt_len, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) + k = k.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) + v = v.contiguous().view(-1, bsz * b.attn.num_heads, b.attn.head_dim).transpose(0, 1) + + q = q * scaling + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) # n_heads * batch_size, tokens^2, tokens^2 + if attn_mask is not None: + + + attn_mask_type, attn_mask = attn_mask + n_heads = attn_output_weights.size(0) // attn_mask.size(0) + attn_mask = attn_mask.repeat(n_heads, 1) + + if attn_mask_type == 'cls_token': + # the mask only affects similarities compared to the readout-token. + attn_output_weights[:, 0, 1:] = attn_output_weights[:, 0, 1:] * attn_mask[None,...] + # attn_output_weights[:, 0, 0] = 0*attn_output_weights[:, 0, 0] + + if attn_mask_type == 'all': + # print(attn_output_weights.shape, attn_mask[:, None].shape) + attn_output_weights[:, 1:, 1:] = attn_output_weights[:, 1:, 1:] * attn_mask[:, None] + + + attn_output_weights = torch.softmax(attn_output_weights, dim=-1) + + attn_output = torch.bmm(attn_output_weights, v) + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output = b.attn.out_proj(attn_output) + + x = x + attn_output + x = x + b.mlp(b.ln_2(x)) + + if with_aff: + return x, attn_output_weights + else: + return x + + +class CLIPDenseBase(nn.Module): + + def __init__(self, version, reduce_cond, reduce_dim, prompt, n_tokens): + super().__init__() + + from rope.external.cliplib import clip + + # prec = torch.FloatTensor + self.clip_model, _ = clip.load(version, device='cpu', jit=False) + self.model = self.clip_model.visual + + # if not None, scale conv weights such that we obtain n_tokens. 
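# Hedged aside, illustrative only and not part of this diff: visual_forward below uses
# n_tokens to resize the ViT patch-embedding kernel so an arbitrary input size still
# yields roughly n_tokens x n_tokens patch tokens. The helper name and the
# 350-pixel / 768-channel example values here are assumptions for the sketch:
import torch
from torch.nn import functional as nnf

def rescaled_patch_embed(x_inp, conv_weight, n_tokens):
    stride = x_inp.shape[2] // n_tokens                 # e.g. 350 // 14 = 25
    w = nnf.interpolate(conv_weight, (stride, stride),  # resize the 16x16 kernel to stride x stride
                        mode='bilinear', align_corners=True)
    return nnf.conv2d(x_inp, w, stride=stride)          # -> [B, width, 14, 14] for these numbers

tokens = rescaled_patch_embed(torch.randn(1, 3, 350, 350), torch.randn(768, 3, 16, 16), 14)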
+ self.n_tokens = n_tokens + + for p in self.clip_model.parameters(): + p.requires_grad_(False) + + # conditional + if reduce_cond is not None: + self.reduce_cond = nn.Linear(512, reduce_cond) + for p in self.reduce_cond.parameters(): + p.requires_grad_(False) + else: + self.reduce_cond = None + + self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + + self.reduce = nn.Linear(768, reduce_dim) + + self.prompt_list = get_prompt_list(prompt) + + # precomputed prompts + import pickle + if isfile('precomputed_prompt_vectors.pickle'): + precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb')) + self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()} + else: + self.precomputed_prompts = dict() + + def rescaled_pos_emb(self, new_size): + assert len(new_size) == 2 + + a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape) + b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T + return torch.cat([self.model.positional_embedding[:1], b]) + + def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None): + + + with torch.no_grad(): + + inp_size = x_inp.shape[2:] + + if self.n_tokens is not None: + stride2 = x_inp.shape[2] // self.n_tokens + conv_weight2 = nnf.interpolate(self.model.conv1.weight, (stride2, stride2), mode='bilinear', align_corners=True) + x = nnf.conv2d(x_inp, conv_weight2, bias=self.model.conv1.bias, stride=stride2, dilation=self.model.conv1.dilation) + else: + x = self.model.conv1(x_inp) # shape = [*, width, grid, grid] + + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + + standard_n_tokens = 50 if self.model.conv1.kernel_size[0] == 32 else 197 + + if x.shape[1] != standard_n_tokens: + new_shape = int(math.sqrt(x.shape[1]-1)) + x = x + self.rescaled_pos_emb((new_shape, new_shape)).to(x.dtype)[None,:,:] + else: + x = x + self.model.positional_embedding.to(x.dtype) + + x = self.model.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + + activations, affinities = [], [] + for i, res_block in enumerate(self.model.transformer.resblocks): + + if mask is not None: + mask_layer, mask_type, mask_tensor = mask + if mask_layer == i or mask_layer == 'all': + # import ipdb; ipdb.set_trace() + size = int(math.sqrt(x.shape[0] - 1)) + + attn_mask = (mask_type, nnf.interpolate(mask_tensor.unsqueeze(1).float(), (size, size)).view(mask_tensor.shape[0], size * size)) + + else: + attn_mask = None + else: + attn_mask = None + + x, aff_per_head = forward_multihead_attention(x, res_block, with_aff=True, attn_mask=attn_mask) + + if i in extract_layers: + affinities += [aff_per_head] + + #if self.n_tokens is not None: + # activations += [nnf.interpolate(x, inp_size, mode='bilinear', align_corners=True)] + #else: + activations += [x] + + if len(extract_layers) > 0 and i == max(extract_layers) and skip: + print('early skip') + break + + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_post(x[:, 0, :]) + + if self.model.proj is not None: + x = x @ self.model.proj + + return x, activations, affinities + + def sample_prompts(self, words, prompt_list=None): + + prompt_list = prompt_list if prompt_list is 
not None else self.prompt_list + + prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) + prompts = [prompt_list[i] for i in prompt_indices] + return [promt.format(w) for promt, w in zip(prompts, words)] + + def get_cond_vec(self, conditional, batch_size): + # compute conditional from a single string + if conditional is not None and type(conditional) == str: + cond = self.compute_conditional(conditional) + cond = cond.repeat(batch_size, 1) + + # compute conditional from string list/tuple + elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str: + assert len(conditional) == batch_size + cond = self.compute_conditional(conditional) + + # use conditional directly + elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2: + cond = conditional + + # compute conditional from image + elif conditional is not None and type(conditional) == torch.Tensor: + with torch.no_grad(): + cond, _, _ = self.visual_forward(conditional) + else: + raise ValueError('invalid conditional') + return cond + + def compute_conditional(self, conditional): + from rope.external.cliplib import clip + + dev = next(self.parameters()).device + + if type(conditional) in {list, tuple}: + text_tokens = clip.tokenize(conditional).to(dev) + cond = self.clip_model.encode_text(text_tokens) + else: + if conditional in self.precomputed_prompts: + cond = self.precomputed_prompts[conditional].float().to(dev) + else: + text_tokens = clip.tokenize([conditional]).to(dev) + cond = self.clip_model.encode_text(text_tokens)[0] + + if self.shift_vector is not None: + return cond + self.shift_vector + else: + return cond + + +def clip_load_untrained(version): + assert version == 'ViT-B/16' + from clip.model import CLIP + from clip.clip import _MODELS, _download + model = torch.jit.load(_download(_MODELS['ViT-B/16'])).eval() + state_dict = model.state_dict() + + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + return CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers) + + +class CLIPDensePredT(CLIPDenseBase): + + def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed', + extra_blocks=0, reduce_cond=None, fix_shift=False, + learn_trans_conv_only=False, limit_to_clip_only=False, upsample=False, + add_calibration=False, rev_activations=False, trans_conv=None, n_tokens=None, complex_trans_conv=False): + + super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens) + # device = 'cpu' + + self.extract_layers = extract_layers + self.cond_layer = cond_layer + self.limit_to_clip_only = limit_to_clip_only + self.process_cond = None 
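# Hedged usage sketch, illustrative only and not part of this diff: VideoManager drives
# this class as self.clip_session(img_batch, prompts)[0], where prompts is the
# comma-split CLIPText string. The constructor values and checkpoint filename below are
# assumptions, not values taken from this repository, and constructing the model also
# pulls the CLIP backbone weights via clip.load():
import torch
from rope.external.clipseg import CLIPDensePredT

def load_clipseg_sketch(weights_path):
    # reduce_dim=64 / complex_trans_conv=True are assumed to match the rd64 CLIPSeg release
    model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
    model.eval()
    model.load_state_dict(torch.load(weights_path, map_location='cpu'), strict=False)
    return model

# session = load_clipseg_sketch('rd64-uni-refined.pth')                 # assumed filename
# preds = session(torch.randn(2, 3, 256, 256), ['hair', 'glasses'])[0]  # one mask per prompt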
+ self.rev_activations = rev_activations + + depth = len(extract_layers) + + if add_calibration: + self.calibration_conds = 1 + + self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None + + self.add_activation1 = True + + self.version = version + + self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version] + + if fix_shift: + # self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'clip_text_shift_vector.pth')), requires_grad=False) + self.shift_vector = nn.Parameter(torch.load(join(dirname(basename(__file__)), 'shift_text_to_vis.pth')), requires_grad=False) + # self.shift_vector = nn.Parameter(-1*torch.load(join(dirname(basename(__file__)), 'shift2.pth')), requires_grad=False) + else: + self.shift_vector = None + + if trans_conv is None: + trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version] + else: + # explicitly define transposed conv kernel size + trans_conv_ks = (trans_conv, trans_conv) + + if not complex_trans_conv: + self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + else: + assert trans_conv_ks[0] == trans_conv_ks[1] + + tp_kernels = (trans_conv_ks[0] // 4, trans_conv_ks[0] // 4) + + self.trans_conv = nn.Sequential( + nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1), + nn.ReLU(), + nn.ConvTranspose2d(reduce_dim, reduce_dim // 2, kernel_size=tp_kernels[0], stride=tp_kernels[0]), + nn.ReLU(), + nn.ConvTranspose2d(reduce_dim // 2, 1, kernel_size=tp_kernels[1], stride=tp_kernels[1]), + ) + +# self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + + assert len(self.extract_layers) == depth + + self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)]) + self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))]) + self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)]) + + # refinement and trans conv + + if learn_trans_conv_only: + for p in self.parameters(): + p.requires_grad_(False) + + for p in self.trans_conv.parameters(): + p.requires_grad_(True) + + self.prompt_list = get_prompt_list(prompt) + + + def forward(self, inp_image, conditional=None, return_features=False, mask=None): + + assert type(return_features) == bool + + inp_image = inp_image.to(self.model.positional_embedding.device) + + if mask is not None: + raise ValueError('mask not supported') + + # x_inp = normalize(inp_image) + x_inp = inp_image + + bs, dev = inp_image.shape[0], x_inp.device + + cond = self.get_cond_vec(conditional, bs) + + visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + list(self.extract_layers)) + + activation1 = activations[0] + activations = activations[1:] + + _activations = activations[::-1] if not self.rev_activations else activations + + a = None + for i, (activation, block, reduce) in enumerate(zip(_activations, self.blocks, self.reduces)): + + if a is not None: + a = reduce(activation) + a + else: + a = reduce(activation) + + if i == self.cond_layer: + if self.reduce_cond is not None: + cond = self.reduce_cond(cond) + + a = self.film_mul(cond) * a + self.film_add(cond) + + a = block(a) + + for block in self.extra_blocks: + a = a + block(a) + + a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens + + size = int(math.sqrt(a.shape[2])) + + a = a.view(bs, a.shape[1], size, size) + + a = self.trans_conv(a) + + if self.n_tokens is 
not None: + a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear', align_corners=True) + + if self.upsample_proj is not None: + a = self.upsample_proj(a) + a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear') + + if return_features: + return a, visual_q, cond, [activation1] + activations + else: + return a, + + + +class CLIPDensePredTMasked(CLIPDensePredT): + + def __init__(self, version='ViT-B/32', extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, + prompt='fixed', extra_blocks=0, reduce_cond=None, fix_shift=False, learn_trans_conv_only=False, + refine=None, limit_to_clip_only=False, upsample=False, add_calibration=False, n_tokens=None): + + super().__init__(version=version, extract_layers=extract_layers, cond_layer=cond_layer, reduce_dim=reduce_dim, + n_heads=n_heads, prompt=prompt, extra_blocks=extra_blocks, reduce_cond=reduce_cond, + fix_shift=fix_shift, learn_trans_conv_only=learn_trans_conv_only, + limit_to_clip_only=limit_to_clip_only, upsample=upsample, add_calibration=add_calibration, + n_tokens=n_tokens) + + def visual_forward_masked(self, img_s, seg_s): + return super().visual_forward(img_s, mask=('all', 'cls_token', seg_s)) + + def forward(self, img_q, cond_or_img_s, seg_s=None, return_features=False): + + if seg_s is None: + cond = cond_or_img_s + else: + img_s = cond_or_img_s + + with torch.no_grad(): + cond, _, _ = self.visual_forward_masked(img_s, seg_s) + + return super().forward(img_q, cond, return_features=return_features) + + + +class CLIPDenseBaseline(CLIPDenseBase): + + def __init__(self, version='ViT-B/32', cond_layer=0, + extract_layer=9, reduce_dim=128, reduce2_dim=None, prompt='fixed', + reduce_cond=None, limit_to_clip_only=False, n_tokens=None): + + super().__init__(version, reduce_cond, reduce_dim, prompt, n_tokens) + device = 'cpu' + + # self.cond_layer = cond_layer + self.extract_layer = extract_layer + self.limit_to_clip_only = limit_to_clip_only + self.shift_vector = None + + self.token_shape = {'ViT-B/32': (7, 7), 'ViT-B/16': (14, 14)}[version] + + assert reduce2_dim is not None + + self.reduce2 = nn.Sequential( + nn.Linear(reduce_dim, reduce2_dim), + nn.ReLU(), + nn.Linear(reduce2_dim, reduce_dim) + ) + + trans_conv_ks = {'ViT-B/32': (32, 32), 'ViT-B/16': (16, 16)}[version] + self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + + + def forward(self, inp_image, conditional=None, return_features=False): + + inp_image = inp_image.to(self.model.positional_embedding.device) + + # x_inp = normalize(inp_image) + x_inp = inp_image + + bs, dev = inp_image.shape[0], x_inp.device + + cond = self.get_cond_vec(conditional, bs) + + visual_q, activations, affinities = self.visual_forward(x_inp, extract_layers=[self.extract_layer]) + + a = activations[0] + a = self.reduce(a) + a = self.film_mul(cond) * a + self.film_add(cond) + + if self.reduce2 is not None: + a = self.reduce2(a) + + # the original model would execute a transformer block here + + a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens + + size = int(math.sqrt(a.shape[2])) + + a = a.view(bs, a.shape[1], size, size) + a = self.trans_conv(a) + + if return_features: + return a, visual_q, cond, activations + else: + return a, + + +class CLIPSegMultiLabel(nn.Module): + + def __init__(self, model) -> None: + super().__init__() + + from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC + + self.pascal_classes = VOC + + from models.clipseg import CLIPDensePredT + from general_utils import load_model + # self.clipseg = 
load_model('rd64-vit16-neg0.2-phrasecut', strict=False) + self.clipseg = load_model(model, strict=False) + + self.clipseg.eval() + + def forward(self, x): + + bs = x.shape[0] + out = torch.ones(21, bs, 352, 352).to(x.device) * -10 + + for class_id, class_name in enumerate(self.pascal_classes): + + fac = 3 if class_name == 'background' else 1 + + with torch.no_grad(): + pred = torch.sigmoid(self.clipseg(x, class_name)[0][:,0]) * fac + + out[class_id] += pred + + + out = out.permute(1, 0, 2, 3) + + return out + + # construct output tensor + diff --git a/rope/external/model.py b/rope/external/model.py new file mode 100644 index 0000000..eb1239f --- /dev/null +++ b/rope/external/model.py @@ -0,0 +1,283 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from rope.external.resnet import Resnet18 +# from modules.bn import InPlaceABNSync as BatchNorm2d + + +class ConvBNReLU(nn.Module): + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, + out_chan, + kernel_size = ks, + stride = stride, + padding = padding, + bias = False) + self.bn = nn.BatchNorm2d(out_chan) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + +class BiSeNetOutput(nn.Module): + def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = self.conv_out(x) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class AttentionRefinementModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + self.init_weight() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + +class ContextPath(nn.Module): + def __init__(self, *args, **kwargs): + super(ContextPath, self).__init__() + self.resnet = Resnet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 
= AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + self.init_weight() + + def forward(self, x): + H0, W0 = x.size()[2:] + feat8, feat16, feat32 = self.resnet(x) + H8, W8 = feat8.size()[2:] + H16, W16 = feat16.size()[2:] + H32, W32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (H32, W32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +### This is not used, since I replace this with the resnet feature with the same size +class SpatialPath(nn.Module): + def __init__(self, *args, **kwargs): + super(SpatialPath, self).__init__() + self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) + self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) + self.init_weight() + + def forward(self, x): + feat = self.conv1(x) + feat = self.conv2(feat) + feat = self.conv3(feat) + feat = self.conv_out(feat) + return feat + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class FeatureFusionModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, + out_chan//4, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.conv2 = nn.Conv2d(out_chan//4, + out_chan, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + self.init_weight() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = 
torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class BiSeNet(nn.Module): + def __init__(self, n_classes, *args, **kwargs): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + ## here self.sp is deleted + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, n_classes) + self.conv_out16 = BiSeNetOutput(128, 64, n_classes) + self.conv_out32 = BiSeNetOutput(128, 64, n_classes) + self.init_weight() + + def forward(self, x): + H, W = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature + feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + feat_out = self.conv_out(feat_fuse) + feat_out16 = self.conv_out16(feat_cp8) + feat_out32 = self.conv_out32(feat_cp16) + + feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) + feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) + feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) + return feat_out, feat_out16, feat_out32 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] + for name, child in self.named_children(): + child_wd_params, child_nowd_params = child.get_params() + if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): + lr_mul_wd_params += child_wd_params + lr_mul_nowd_params += child_nowd_params + else: + wd_params += child_wd_params + nowd_params += child_nowd_params + return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params + + +if __name__ == "__main__": + net = BiSeNet(19) + net.cuda() + net.eval() + in_ten = torch.randn(16, 3, 640, 480).cuda() + out, out16, out32 = net(in_ten) + print(out.shape) + + net.get_params() diff --git a/rope/external/resnet.py b/rope/external/resnet.py new file mode 100644 index 0000000..64969da --- /dev/null +++ b/rope/external/resnet.py @@ -0,0 +1,109 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as modelzoo + +# from modules.bn import InPlaceABNSync as BatchNorm2d + +resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu 
= nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum-1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class Resnet18(nn.Module): + def __init__(self): + super(Resnet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + self.init_weight() + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 + + def init_weight(self): + state_dict = modelzoo.load_url(resnet18_url) + self_state_dict = self.state_dict() + for k, v in state_dict.items(): + if 'fc' in k: continue + self_state_dict.update({k: v}) + self.load_state_dict(self_state_dict) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +if __name__ == "__main__": + net = Resnet18() + x = torch.randn(16, 3, 224, 224) + out = net(x) + print(out[0].size()) + print(out[1].size()) + print(out[2].size()) + net.get_params() diff --git a/rope/external/vitseg.py b/rope/external/vitseg.py new file mode 100644 index 0000000..d3231e5 --- /dev/null +++ b/rope/external/vitseg.py @@ -0,0 +1,286 @@ +import math +from posixpath import basename, dirname, join +# import clip +from clip.model import convert_weights +import torch +import json +from torch import nn +from torch.nn import functional as nnf +from torch.nn.modules import activation +from torch.nn.modules.activation import ReLU +from torchvision import transforms + +normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) + +from torchvision.models import ResNet + + +def process_prompts(conditional, prompt_list, conditional_map): + # DEPRECATED + + # randomly sample a synonym + words = [conditional_map[int(i)] for i in conditional] + words = [syns[torch.multinomial(torch.ones(len(syns)), 1, replacement=True).item()] for syns in words] + words = [w.replace('_', ' ') for w in words] + + if prompt_list is not None: + prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) + prompts = [prompt_list[i] for i in prompt_indices] + else: + 
prompts = ['a photo of {}'] * (len(words)) + + return [promt.format(w) for promt, w in zip(prompts, words)] + + +class VITDenseBase(nn.Module): + + def rescaled_pos_emb(self, new_size): + assert len(new_size) == 2 + + a = self.model.positional_embedding[1:].T.view(1, 768, *self.token_shape) + b = nnf.interpolate(a, new_size, mode='bicubic', align_corners=False).squeeze(0).view(768, new_size[0]*new_size[1]).T + return torch.cat([self.model.positional_embedding[:1], b]) + + def visual_forward(self, x_inp, extract_layers=(), skip=False, mask=None): + + with torch.no_grad(): + + x_inp = nnf.interpolate(x_inp, (384, 384)) + + x = self.model.patch_embed(x_inp) + cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.model.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.model.pos_drop(x + self.model.pos_embed) + + activations = [] + for i, block in enumerate(self.model.blocks): + x = block(x) + + if i in extract_layers: + # permute to be compatible with CLIP + activations += [x.permute(1,0,2)] + + x = self.model.norm(x) + x = self.model.head(self.model.pre_logits(x[:, 0])) + + # again for CLIP compatibility + # x = x.permute(1, 0, 2) + + return x, activations, None + + def sample_prompts(self, words, prompt_list=None): + + prompt_list = prompt_list if prompt_list is not None else self.prompt_list + + prompt_indices = torch.multinomial(torch.ones(len(prompt_list)), len(words), replacement=True) + prompts = [prompt_list[i] for i in prompt_indices] + return [promt.format(w) for promt, w in zip(prompts, words)] + + def get_cond_vec(self, conditional, batch_size): + # compute conditional from a single string + if conditional is not None and type(conditional) == str: + cond = self.compute_conditional(conditional) + cond = cond.repeat(batch_size, 1) + + # compute conditional from string list/tuple + elif conditional is not None and type(conditional) in {list, tuple} and type(conditional[0]) == str: + assert len(conditional) == batch_size + cond = self.compute_conditional(conditional) + + # use conditional directly + elif conditional is not None and type(conditional) == torch.Tensor and conditional.ndim == 2: + cond = conditional + + # compute conditional from image + elif conditional is not None and type(conditional) == torch.Tensor: + with torch.no_grad(): + cond, _, _ = self.visual_forward(conditional) + else: + raise ValueError('invalid conditional') + return cond + + def compute_conditional(self, conditional): + import clip + + dev = next(self.parameters()).device + + if type(conditional) in {list, tuple}: + text_tokens = clip.tokenize(conditional).to(dev) + cond = self.clip_model.encode_text(text_tokens) + else: + if conditional in self.precomputed_prompts: + cond = self.precomputed_prompts[conditional].float().to(dev) + else: + text_tokens = clip.tokenize([conditional]).to(dev) + cond = self.clip_model.encode_text(text_tokens)[0] + + return cond + + +class VITDensePredT(VITDenseBase): + + def __init__(self, extract_layers=(3, 6, 9), cond_layer=0, reduce_dim=128, n_heads=4, prompt='fixed', + depth=3, extra_blocks=0, reduce_cond=None, fix_shift=False, + learn_trans_conv_only=False, refine=None, limit_to_clip_only=False, upsample=False, + add_calibration=False, process_cond=None, not_pretrained=False): + super().__init__() + # device = 'cpu' + + self.extract_layers = extract_layers + self.cond_layer = cond_layer + 
self.limit_to_clip_only = limit_to_clip_only + self.process_cond = None + + if add_calibration: + self.calibration_conds = 1 + + self.upsample_proj = nn.Conv2d(reduce_dim, 1, kernel_size=1) if upsample else None + + self.add_activation1 = True + + import timm + self.model = timm.create_model('vit_base_patch16_384', pretrained=True) + self.model.head = nn.Linear(768, 512 if reduce_cond is None else reduce_cond) + + for p in self.model.parameters(): + p.requires_grad_(False) + + import clip + self.clip_model, _ = clip.load('ViT-B/16', device='cpu', jit=False) + # del self.clip_model.visual + + + self.token_shape = (14, 14) + + # conditional + if reduce_cond is not None: + self.reduce_cond = nn.Linear(512, reduce_cond) + for p in self.reduce_cond.parameters(): + p.requires_grad_(False) + else: + self.reduce_cond = None + + # self.film = AVAILABLE_BLOCKS['film'](512, 128) + self.film_mul = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + self.film_add = nn.Linear(512 if reduce_cond is None else reduce_cond, reduce_dim) + + # DEPRECATED + # self.conditional_map = {c['id']: c['synonyms'] for c in json.load(open(cond_map))} + + assert len(self.extract_layers) == depth + + self.reduces = nn.ModuleList([nn.Linear(768, reduce_dim) for _ in range(depth)]) + self.blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(len(self.extract_layers))]) + self.extra_blocks = nn.ModuleList([nn.TransformerEncoderLayer(d_model=reduce_dim, nhead=n_heads) for _ in range(extra_blocks)]) + + trans_conv_ks = (16, 16) + self.trans_conv = nn.ConvTranspose2d(reduce_dim, 1, trans_conv_ks, stride=trans_conv_ks) + + # refinement and trans conv + + if learn_trans_conv_only: + for p in self.parameters(): + p.requires_grad_(False) + + for p in self.trans_conv.parameters(): + p.requires_grad_(True) + + if prompt == 'fixed': + self.prompt_list = ['a photo of a {}.'] + elif prompt == 'shuffle': + self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.'] + elif prompt == 'shuffle+': + self.prompt_list = ['a photo of a {}.', 'a photograph of a {}.', 'an image of a {}.', '{}.', + 'a cropped photo of a {}.', 'a good photo of a {}.', 'a photo of one {}.', + 'a bad photo of a {}.', 'a photo of the {}.'] + elif prompt == 'shuffle_clip': + from models.clip_prompts import imagenet_templates + self.prompt_list = imagenet_templates + + if process_cond is not None: + if process_cond == 'clamp' or process_cond[0] == 'clamp': + + val = process_cond[1] if type(process_cond) in {list, tuple} else 0.2 + + def clamp_vec(x): + return torch.clamp(x, -val, val) + + self.process_cond = clamp_vec + + elif process_cond.endswith('.pth'): + + shift = torch.load(process_cond) + def add_shift(x): + return x + shift.to(x.device) + + self.process_cond = add_shift + + import pickle + precomp = pickle.load(open('precomputed_prompt_vectors.pickle', 'rb')) + self.precomputed_prompts = {k: torch.from_numpy(v) for k, v in precomp.items()} + + + def forward(self, inp_image, conditional=None, return_features=False, mask=None): + + assert type(return_features) == bool + + # inp_image = inp_image.to(self.model.positional_embedding.device) + + if mask is not None: + raise ValueError('mask not supported') + + # x_inp = normalize(inp_image) + x_inp = inp_image + + bs, dev = inp_image.shape[0], x_inp.device + + inp_image_size = inp_image.shape[2:] + + cond = self.get_cond_vec(conditional, bs) + + visual_q, activations, _ = self.visual_forward(x_inp, extract_layers=[0] + 
list(self.extract_layers)) + + activation1 = activations[0] + activations = activations[1:] + + a = None + for i, (activation, block, reduce) in enumerate(zip(activations[::-1], self.blocks, self.reduces)): + + if a is not None: + a = reduce(activation) + a + else: + a = reduce(activation) + + if i == self.cond_layer: + if self.reduce_cond is not None: + cond = self.reduce_cond(cond) + + a = self.film_mul(cond) * a + self.film_add(cond) + + a = block(a) + + for block in self.extra_blocks: + a = a + block(a) + + a = a[1:].permute(1, 2, 0) # rm cls token and -> BS, Feats, Tokens + + size = int(math.sqrt(a.shape[2])) + + a = a.view(bs, a.shape[1], size, size) + + if self.trans_conv is not None: + a = self.trans_conv(a) + + if self.upsample_proj is not None: + a = self.upsample_proj(a) + a = nnf.interpolate(a, x_inp.shape[2:], mode='bilinear') + + a = nnf.interpolate(a, inp_image_size) + + if return_features: + return a, visual_q, cond, [activation1] + activations + else: + return a, diff --git a/rope/media/CLIP.png b/rope/media/CLIP.png new file mode 100644 index 0000000..c265abe Binary files /dev/null and b/rope/media/CLIP.png differ diff --git a/rope/media/CLIP.png~ b/rope/media/CLIP.png~ new file mode 100644 index 0000000..ae8abf5 Binary files /dev/null and b/rope/media/CLIP.png~ differ diff --git a/rope/media/blur.png b/rope/media/blur.png new file mode 100644 index 0000000..ef1c258 Binary files /dev/null and b/rope/media/blur.png differ diff --git a/rope/media/delemb.png b/rope/media/delemb.png new file mode 100644 index 0000000..75e8dd2 Binary files /dev/null and b/rope/media/delemb.png differ diff --git a/rope/media/diff.png b/rope/media/diff.png new file mode 100644 index 0000000..39f93b5 Binary files /dev/null and b/rope/media/diff.png differ diff --git a/rope/media/gfpgan_logo.png b/rope/media/gfpgan_logo.png new file mode 100644 index 0000000..f019378 Binary files /dev/null and b/rope/media/gfpgan_logo.png differ diff --git a/rope/media/maskblur.png b/rope/media/maskblur.png new file mode 100644 index 0000000..a42ea77 Binary files /dev/null and b/rope/media/maskblur.png differ diff --git a/rope/media/maskside.png b/rope/media/maskside.png new file mode 100644 index 0000000..c35a842 Binary files /dev/null and b/rope/media/maskside.png differ diff --git a/rope/media/maskup.png b/rope/media/maskup.png new file mode 100644 index 0000000..4789066 Binary files /dev/null and b/rope/media/maskup.png differ diff --git a/rope/media/occluder.png b/rope/media/occluder.png new file mode 100644 index 0000000..c7eba81 Binary files /dev/null and b/rope/media/occluder.png differ diff --git a/rope/media/parse.png b/rope/media/parse.png new file mode 100644 index 0000000..a110c3b Binary files /dev/null and b/rope/media/parse.png differ diff --git a/rope/media/play.png b/rope/media/play.png new file mode 100644 index 0000000..2d1b910 Binary files /dev/null and b/rope/media/play.png differ diff --git a/rope/media/rec.png b/rope/media/rec.png new file mode 100644 index 0000000..b4adfcb Binary files /dev/null and b/rope/media/rec.png differ diff --git a/rope/media/save.png b/rope/media/save.png new file mode 100644 index 0000000..6686422 Binary files /dev/null and b/rope/media/save.png differ diff --git a/rope/media/srcface.png b/rope/media/srcface.png new file mode 100644 index 0000000..36afa0c Binary files /dev/null and b/rope/media/srcface.png differ diff --git a/rope/media/swap.png b/rope/media/swap.png new file mode 100644 index 0000000..8003a7a Binary files /dev/null and b/rope/media/swap.png 
differ diff --git a/rope/media/tarface.png b/rope/media/tarface.png new file mode 100644 index 0000000..2e0e279 Binary files /dev/null and b/rope/media/tarface.png differ diff --git a/rope/media/tarfacedel.png b/rope/media/tarfacedel.png new file mode 100644 index 0000000..7a5f8f3 Binary files /dev/null and b/rope/media/tarfacedel.png differ diff --git a/rope/media/threads.png b/rope/media/threads.png new file mode 100644 index 0000000..019af38 Binary files /dev/null and b/rope/media/threads.png differ diff --git a/rope/media/thresh.png b/rope/media/thresh.png new file mode 100644 index 0000000..efb2b34 Binary files /dev/null and b/rope/media/thresh.png differ
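
For reference, a minimal sketch of how the vendored CLIPDensePredT from rope/external/clipseg.py can be exercised for text-prompted mask prediction. The checkpoint name 'rd64-uni-refined.pth', the 352x352 input size, and the reduce_dim=64 / complex_trans_conv=True settings are assumptions chosen to match the publicly released CLIPSeg "refined" weights and may not be exactly what the application loads; strict=False is used because such checkpoints normally omit the frozen CLIP weights.

import torch
from rope.external.clipseg import CLIPDensePredT

# Assumed configuration: intended to match the public CLIPSeg "rd64-uni-refined" weights.
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
state = torch.load('rd64-uni-refined.pth', map_location='cpu')   # hypothetical checkpoint path
model.load_state_dict(state, strict=False)                       # frozen CLIP weights are not stored in the checkpoint
model.eval()

img = torch.randn(1, 3, 352, 352)         # stand-in for a normalized RGB frame crop
prompts = ['face', 'hand']
with torch.no_grad():
    # forward() returns a 1-tuple; element 0 has shape (batch, 1, H, W)
    logits = [model(img, p)[0] for p in prompts]
masks = [torch.sigmoid(l)[0, 0] for l in logits]                  # one soft mask per prompt, values in [0, 1]

Each call conditions the decoder on a single text prompt via get_cond_vec, so multi-prompt masking is simply a loop over prompts as shown.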
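
Similarly, a minimal sketch of driving the BiSeNet network added in rope/external/model.py as a face parser. The 19-class head and the '79999_iter.pth' checkpoint name are assumptions based on the face-parsing setup this architecture is commonly paired with; note that constructing BiSeNet triggers Resnet18.init_weight, which downloads the torchvision ResNet-18 weights, so the first run needs network access.

import torch
from torchvision import transforms
from rope.external.model import BiSeNet

net = BiSeNet(n_classes=19)     # 19 classes is the usual face-parsing configuration (assumption)
# net.load_state_dict(torch.load('79999_iter.pth', map_location='cpu'))  # hypothetical checkpoint name
net.eval()

normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
inp = normalize(torch.rand(3, 512, 512)).unsqueeze(0)    # stand-in for an aligned 512x512 face crop
with torch.no_grad():
    feat_out, feat_out16, feat_out32 = net(inp)          # three output heads, each (1, 19, 512, 512)
parsing = feat_out.argmax(dim=1)                         # per-pixel class-index map

Only the main head (feat_out) is needed at inference time; the two auxiliary heads exist for training-time supervision of the context path.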