From a17b8dff773942b1648cd947373fe960f18463cd Mon Sep 17 00:00:00 2001
From: Aleksey Morozov <36787333+amrzv@users.noreply.github.com>
Date: Sat, 29 May 2021 14:12:13 +0300
Subject: [PATCH 1/2] enhanced braindance notebook

---
 scripts/braindance.ipynb | 136 +++++++++++++++++++++------------------
 1 file changed, 73 insertions(+), 63 deletions(-)

diff --git a/scripts/braindance.ipynb b/scripts/braindance.ipynb
index 5a51284..9dd0e26 100644
--- a/scripts/braindance.ipynb
+++ b/scripts/braindance.ipynb
@@ -66,9 +66,10 @@
     "import numpy as np\n",
     "from PIL import Image\n",
     "\n",
+    "\n",
     "def load_im_as_example(im):\n",
     "    size = [208, 368]\n",
-    "    w,h = im.size\n",
+    "    w, h = im.size\n",
     "    if np.abs(w/h - size[1]/size[0]) > 0.1:\n",
     "        print(f\"Center cropping image to AR {size[1]/size[0]}\")\n",
     "        if w/h < size[1]/size[0]:\n",
@@ -85,7 +86,7 @@
     "            right = w/2 + size[1]/size[0]*h\n",
     "        im = im.crop(box=(left, top, right, bottom))\n",
     "\n",
-    "    im = im.resize((size[1],size[0]),\n",
+    "    im = im.resize((size[1], size[0]),\n",
     "                   resample=Image.LANCZOS)\n",
     "    im = np.array(im)/127.5-1.0\n",
     "    im = im.astype(np.float32)\n",
@@ -97,12 +98,13 @@
     "                             [0.0, 0.0, 1.0]], dtype=np.float32)\n",
     "    example[\"K_inv\"] = np.linalg.inv(example[\"K\"])\n",
     "\n",
-    "    ## dummy data not used during inference\n",
+    "    # dummy data not used during inference\n",
     "    example[\"dst_img\"] = np.zeros_like(example[\"src_img\"])\n",
-    "    example[\"src_points\"] = np.zeros((1,3), dtype=np.float32)\n",
+    "    example[\"src_points\"] = np.zeros((1, 3), dtype=np.float32)\n",
     "\n",
     "    return example\n",
     "\n",
+    "\n",
     "def load_as_example(path):\n",
     "    im = Image.open(path)\n",
     "    return load_im_as_example(im)"
@@ -128,25 +130,29 @@
     "def normalize(x):\n",
     "    return x/np.linalg.norm(x)\n",
     "\n",
+    "\n",
     "def cosd(x):\n",
     "    return np.cos(np.deg2rad(x))\n",
     "\n",
+    "\n",
     "def sind(x):\n",
     "    return np.sin(np.deg2rad(x))\n",
     "\n",
+    "\n",
     "def look_to(camera_pos, camera_dir, camera_up):\n",
     "    camera_right = normalize(np.cross(camera_up, camera_dir))\n",
     "    R = np.zeros((4, 4))\n",
-    "    R[0,0:3] = normalize(camera_right)\n",
-    "    R[1,0:3] = normalize(np.cross(camera_dir, camera_right))\n",
-    "    R[2,0:3] = normalize(camera_dir)\n",
-    "    R[3,3] = 1\n",
+    "    R[0, :3] = normalize(camera_right)\n",
+    "    R[1, :3] = normalize(np.cross(camera_dir, camera_right))\n",
+    "    R[2, :3] = normalize(camera_dir)\n",
+    "    R[3, 3] = 1\n",
     "    trans_matrix = np.array([[1.0, 0.0, 0.0, -camera_pos[0]],\n",
     "                             [0.0, 1.0, 0.0, -camera_pos[1]],\n",
     "                             [0.0, 0.0, 1.0, -camera_pos[2]],\n",
     "                             [0.0, 0.0, 0.0, 1.0]])\n",
     "    tmp = R@trans_matrix\n",
-    "    return tmp[:3,:3], tmp[:3,3]\n",
+    "    return tmp[:3, :3], tmp[:3, 3]\n",
+    "\n",
     "\n",
     "def rotate_around_axis(angle, axis):\n",
     "    axis = normalize(axis)\n",
@@ -182,6 +188,7 @@
     "import torch\n",
     "from splatting import splatting_function\n",
     "\n",
+    "\n",
     "def render_forward(src_ims, src_dms,\n",
     "                   Rcam, tcam,\n",
     "                   K_src,\n",
@@ -190,7 +197,7 @@
     "    tcam = tcam.to(device=src_ims.device)[None]\n",
     "\n",
     "    R = Rcam\n",
-    "    t = tcam[...,None]\n",
+    "    t = tcam[..., None]\n",
     "    K_src_inv = K_src.inverse()\n",
     "\n",
     "    assert len(src_ims.shape) == 4\n",
@@ -200,39 +207,39 @@
     "\n",
     "    x = np.arange(src_ims[0].shape[1])\n",
     "    y = np.arange(src_ims[0].shape[0])\n",
-    "    coord = np.stack(np.meshgrid(x,y), -1)\n",
-    "    coord = np.concatenate((coord, np.ones_like(coord)[:,:,[0]]), -1) # z=1\n",
+    "    coord = np.stack(np.meshgrid(x, y), -1)\n",
+    "    coord = np.concatenate((coord, np.ones_like(coord)[..., [0]]), -1)  # z=1\n",
     "    coord = coord.astype(np.float32)\n",
     "    coord = torch.as_tensor(coord, dtype=K_src.dtype, device=K_src.device)\n",
-    "    coord = coord[None] # bs, h, w, 3\n",
+    "    coord = coord[None]  # bs, h, w, 3\n",
     "\n",
-    "    D = src_dms[:,:,:,None,None]\n",
+    "    D = src_dms[..., None, None]\n",
     "\n",
-    "    points = K_dst[None,None,None,...]@(R[:,None,None,...]@(D*K_src_inv[None,None,None,...]@coord[:,:,:,:,None])+t[:,None,None,:,:])\n",
+    "    points = K_dst[None, None, None, ...]@(R[:, None, None, ...]@(D*K_src_inv[None, None, None, ...]@coord[..., None])+t[:, None, None, ...])\n",
     "    points = points.squeeze(-1)\n",
     "\n",
-    "    new_z = points[:,:,:,[2]].clone().permute(0,3,1,2) # b,1,h,w\n",
-    "    points = points/torch.clamp(points[:,:,:,[2]], 1e-8, None)\n",
+    "    new_z = points[..., [2]].clone().permute(0, 3, 1, 2)  # b,1,h,w\n",
+    "    points = points/torch.clamp(points[..., [2]], 1e-8, None)\n",
     "\n",
-    "    src_ims = src_ims.permute(0,3,1,2)\n",
+    "    src_ims = src_ims.permute(0, 3, 1, 2)\n",
     "    flow = points - coord\n",
-    "    flow = flow.permute(0,3,1,2)[:,:2,...]\n",
+    "    flow = flow.permute(0, 3, 1, 2)[:, :2, ...]\n",
     "\n",
     "    alpha = 0.5\n",
     "    importance = alpha/new_z\n",
-    "    importance_min = importance.amin((1,2,3),keepdim=True)\n",
-    "    importance_max = importance.amax((1,2,3),keepdim=True)\n",
-    "    importance=(importance-importance_min)/(importance_max-importance_min+1e-6)*10-10\n",
+    "    importance_min = importance.amin((1, 2, 3), keepdim=True)\n",
+    "    importance_max = importance.amax((1, 2, 3), keepdim=True)\n",
+    "    importance = (importance-importance_min)/(importance_max-importance_min+1e-6)*10-10\n",
     "    importance = importance.exp()\n",
     "\n",
     "    input_data = torch.cat([importance*src_ims, importance], 1)\n",
     "    output_data = splatting_function(\"summation\", input_data, flow)\n",
     "\n",
-    "    num = torch.sum(output_data[:,:-1,:,:], dim=0, keepdim=True)\n",
-    "    nom = torch.sum(output_data[:,-1:,:,:], dim=0, keepdim=True)\n",
+    "    num = torch.sum(output_data[:, :-1, ...], dim=0, keepdim=True)\n",
+    "    nom = torch.sum(output_data[:, -1:, ...], dim=0, keepdim=True)\n",
     "\n",
     "    rendered = num/(nom+1e-7)\n",
-    "    rendered = rendered.permute(0,2,3,1)[0,...]\n",
+    "    rendered = rendered.permute(0, 2, 3, 1)[0, ...]\n",
     "    return rendered"
    ],
    "execution_count": null,
    "outputs": []
@@ -256,6 +263,7 @@
     "from geofree import pretrained_models\n",
     "from torch.utils.data.dataloader import default_collate\n",
     "\n",
+    "\n",
     "class Renderer(object):\n",
     "    def __init__(self, model, device):\n",
     "        self.model = pretrained_models(model=model)\n",
@@ -271,8 +279,8 @@
     "        self.step = 0\n",
     "\n",
     "        batch = self.batch = default_collate([example])\n",
-    "        batch[\"R_rel\"] = show_R[None,...]\n",
-    "        batch[\"t_rel\"] = show_t[None,...]\n",
+    "        batch[\"R_rel\"] = show_R[None, ...]\n",
+    "        batch[\"t_rel\"] = show_t[None, ...]\n",
     "\n",
     "        _, cdict, edict = self.model.get_xce(batch)\n",
     "        for k in cdict:\n",
@@ -280,16 +288,16 @@
     "        for k in edict:\n",
     "            edict[k] = edict[k].to(device=self.model.device)\n",
     "\n",
-    "        quant_d, quant_c, dc_indices, embeddings = self.model.get_normalized_c(cdict,edict,fixed_scale=True)\n",
+    "        quant_d, quant_c, dc_indices, embeddings = self.model.get_normalized_c(cdict, edict, fixed_scale=True)\n",
    "\n",
-    "        start_im = start_im[None,...].to(self.model.device).permute(0,3,1,2)\n",
+    "        start_im = start_im[None, ...].to(self.model.device).permute(0, 3, 1, 2)\n",
     "        quant_c, c_indices = self.model.encode_to_c(c=start_im)\n",
     "        cond_rec = self.model.cond_stage_model.decode(quant_c)\n",
     "\n",
-    "        self.current_im = cond_rec.permute(0,2,3,1)[0]\n",
+    "        self.current_im = cond_rec.permute(0, 2, 3, 1)[0]\n",
     "        self.current_sample = c_indices\n",
     "\n",
-    "        self.quant_c = quant_c # to know shape\n",
+    "        self.quant_c = quant_c  # to know shape\n",
     "        # for sampling\n",
     "        self.dc_indices = dc_indices\n",
     "        self.embeddings = embeddings\n",
@@ -297,9 +305,9 @@
     "    def __call__(self):\n",
     "        if self.step < self.current_sample.shape[1]:\n",
     "            z_start_indices = self.current_sample[:, :self.step]\n",
-    "            temperature=None\n",
-    "            top_k=250\n",
-    "            callback=None\n",
+    "            temperature = None\n",
+    "            top_k = 250\n",
+    "            callback = None\n",
     "            index_sample = self.model.sample(z_start_indices, self.dc_indices,\n",
     "                                             steps=1,\n",
     "                                             temperature=temperature if temperature is not None else 1.0,\n",
@@ -308,12 +316,12 @@
     "                                             callback=callback if callback is not None else lambda k: None,\n",
     "                                             embeddings=self.embeddings)\n",
     "            self.current_sample = torch.cat((index_sample,\n",
-    "                                             self.current_sample[:,self.step+1:]),\n",
+    "                                             self.current_sample[:, self.step+1:]),\n",
     "                                            dim=1)\n",
     "\n",
     "        sample_dec = self.model.decode_to_img(self.current_sample,\n",
     "                                              self.quant_c.shape)\n",
-    "        self.current_im = sample_dec.permute(0,2,3,1)[0]\n",
+    "        self.current_im = sample_dec.permute(0, 2, 3, 1)[0]\n",
     "        self.step += 1\n",
     "\n",
     "        if self.step >= self.current_sample.shape[1]:\n",
@@ -325,10 +333,10 @@
     "        return self._active\n",
     "\n",
     "    def reconstruct(self, x):\n",
-    "        x = x.to(self.model.device).permute(0,3,1,2)\n",
+    "        x = x.to(self.model.device).permute(0, 3, 1, 2)\n",
     "        quant_c, c_indices = self.model.encode_to_c(c=x)\n",
     "        x_rec = self.model.cond_stage_model.decode(quant_c)\n",
-    "        return x_rec.permute(0,2,3,1)"
+    "        return x_rec.permute(0, 2, 3, 1)"
    ],
    "execution_count": null,
    "outputs": []
@@ -407,17 +415,17 @@
     "        self.renderer = renderer\n",
     "        self.init_example(example)\n",
     "        self.RENDERING = False\n",
-    "        \n",
+    "\n",
     "    def init_example(self, example):\n",
     "        self.example = example\n",
     "\n",
-    "        ims = example[\"src_img\"][None,...]\n",
+    "        ims = example[\"src_img\"][None, ...]\n",
     "        K = example[\"K\"]\n",
     "\n",
     "        # compute depth for preview\n",
     "        dms = [None]\n",
     "        for i in range(ims.shape[0]):\n",
-    "            midas_in = torch.tensor(ims[i])[None,...].permute(0,3,1,2).to(device)\n",
+    "            midas_in = torch.tensor(ims[i])[None, ...].permute(0, 3, 1, 2).to(device)\n",
     "            scaled_idepth = self.midas.fixed_scale_depth(midas_in, return_inverse_depth=True)\n",
     "            dms[i] = 1.0/scaled_idepth[0].cpu().numpy()\n",
@@ -440,7 +448,7 @@
     "        self.MOUSE_SENSITIVITY = 10.0\n",
     "\n",
     "    def update_camera(self, keys):\n",
-    "        ######### Camera\n",
+    "        \"\"\"Camera\"\"\"\n",
     "        if keys[\"a\"]:\n",
     "            self.camera_pos += self.CAM_SPEED*normalize(np.cross(self.camera_dir, self.camera_up))\n",
     "        if keys[\"d\"]:\n",
@@ -472,7 +480,7 @@
     "                                        self.camera_up))\n",
     "            self.camera_dir = rotation@self.camera_dir\n",
     "\n",
-    "        show_R, show_t = look_to(self.camera_pos, self.camera_dir, self.camera_up) # look from pos in direction dir\n",
+    "        show_R, show_t = look_to(self.camera_pos, self.camera_dir, self.camera_up)  # look from pos in direction dir\n",
     "        show_R = torch.as_tensor(show_R, dtype=torch.float32)\n",
     "        show_t = torch.as_tensor(show_t, dtype=torch.float32)\n",
     "\n",
@@ -486,14 +494,14 @@
     "                                self.show_R, self.show_t,\n",
     "                                K_src=self.K,\n",
     "                                K_dst=self.K)\n",
-    "        \n",
+    "\n",
     "        if keys[\"render\"]:\n",
     "            self.RENDERING = True\n",
     "            self.renderer.init(wrp_im, self.example, self.show_R, self.show_t)\n",
     "\n",
     "        if self.RENDERING:\n",
     "            wrp_im = self.renderer()\n",
-    "        \n",
+    "\n",
     "        if not self.renderer._active or keys[\"stop\"]:\n",
     "            self.RENDERING = False\n",
     "\n",
@@ -517,59 +525,61 @@
     "id": "BXq8YBdwyTQ8"
    },
    "source": [
-    "import IPython\n",
+    "from IPython.display import JSON, HTML\n",
     "from google.colab import output, files\n",
-    "import base64\n",
-    "import io\n",
+    "from base64 import b64encode\n",
     "from io import BytesIO\n",
     "\n",
     "looper = Looper(midas, renderer, example)\n",
     "\n",
+    "\n",
     "def as_png(x):\n",
     "    if hasattr(x, \"detach\"):\n",
     "        x = x.detach().cpu().numpy()\n",
-    "    #x = x.transpose(1,0,2)\n",
+    "    # x = x.transpose(1,0,2)\n",
     "    x = (x+1.0)*127.5\n",
     "    x = x.clip(0, 255).astype(np.uint8)\n",
-    "    data = io.BytesIO()\n",
+    "    data = BytesIO()\n",
     "    Image.fromarray(x).save(data, format=\"png\")\n",
     "    data.seek(0)\n",
     "    data = data.read()\n",
-    "    return base64.b64encode(data).decode()\n",
+    "    return b64encode(data).decode()\n",
+    "\n",
     "\n",
     "def pyloop(data):\n",
     "    if data.get(\"upload\", False):\n",
     "        data = files.upload()\n",
     "        fname = sorted(data.keys())[0]\n",
-    "        I = Image.open(BytesIO(data[fname]))\n",
-    "        looper.init_example(load_im_as_example(I))\n",
+    "        img = Image.open(BytesIO(data[fname]))\n",
+    "        looper.init_example(load_im_as_example(img))\n",
     "\n",
     "    keys = dict()\n",
     "    if \"look\" in data:\n",
     "        keys[\"look\"] = np.array(data[\"look\"])*2.0-1.0\n",
     "    move = data.get(\"direction\", None)\n",
-    "    keys[\"w\"] = move==\"forward\"\n",
-    "    keys[\"a\"] = move==\"left\"\n",
-    "    keys[\"s\"] = move==\"backward\"\n",
-    "    keys[\"d\"] = move==\"right\"\n",
-    "    keys[\"q\"] = move==\"up\"\n",
-    "    keys[\"e\"] = move==\"down\"\n",
-    "    keys[\"render\"] = move==\"render\"\n",
+    "    keys[\"w\"] = move == \"forward\"\n",
+    "    keys[\"a\"] = move == \"left\"\n",
+    "    keys[\"s\"] = move == \"backward\"\n",
+    "    keys[\"d\"] = move == \"right\"\n",
+    "    keys[\"q\"] = move == \"up\"\n",
+    "    keys[\"e\"] = move == \"down\"\n",
+    "    keys[\"render\"] = move == \"render\"\n",
     "    keys[\"stop\"] = data.get(\"stop\", False)\n",
     "    output, rendering = looper.update(keys)\n",
     "\n",
     "    ret = dict()\n",
     "    ret[\"image\"] = as_png(output)\n",
     "    ret[\"loop\"] = rendering\n",
-    "    ret = IPython.display.JSON(ret)\n",
+    "    ret = JSON(ret)\n",
    "\n",
     "    return ret\n",
     "\n",
+    "\n",
     "output.register_callback('pyloop', pyloop)\n",
     "\n",
     "# The front-end for our interactive demo.\n",
     "\n",
-    "html='''\n",
+    "html = '''\n",