-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrender_utils.py
281 lines (229 loc) · 10.3 KB
/
render_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import enum
import types
from typing import List, Mapping, Optional, Text, Tuple, Union
import copy
from PIL import Image
import mediapy as media
from matplotlib import cm
from tqdm import tqdm
import torch
def normalize(x: np.ndarray) -> np.ndarray:
"""Normalization helper function."""
return x / np.linalg.norm(x)
def pad_poses(p: np.ndarray) -> np.ndarray:
"""Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1]."""
bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)
return np.concatenate([p[..., :3, :4], bottom], axis=-2)
def unpad_poses(p: np.ndarray) -> np.ndarray:
"""Remove the homogeneous bottom row from [..., 4, 4] pose matrices."""
return p[..., :3, :4]
def recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Recenter poses around the origin."""
cam2world = average_pose(poses)
transform = np.linalg.inv(pad_poses(cam2world))
poses = transform @ pad_poses(poses)
return unpad_poses(poses), transform
def average_pose(poses: np.ndarray) -> np.ndarray:
"""New pose using average position, z-axis, and up vector of input poses."""
position = poses[:, :3, 3].mean(0)
z_axis = poses[:, :3, 2].mean(0)
up = poses[:, :3, 1].mean(0)
cam2world = viewmatrix(z_axis, up, position)
return cam2world
def viewmatrix(lookdir: np.ndarray, up: np.ndarray,
position: np.ndarray) -> np.ndarray:
"""Construct lookat view matrix."""
vec2 = normalize(lookdir)
vec0 = normalize(np.cross(up, vec2))
vec1 = normalize(np.cross(vec2, vec0))
m = np.stack([vec0, vec1, vec2, position], axis=1)
return m
def focus_point_fn(poses: np.ndarray) -> np.ndarray:
"""Calculate nearest point to all focal axes in poses."""
directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
mt_m = np.transpose(m, [0, 2, 1]) @ m
focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
return focus_pt
def transform_poses_pca(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Transforms poses so principal components lie on XYZ axes.
Args:
poses: a (N, 3, 4) array containing the cameras' camera to world transforms.
Returns:
A tuple (poses, transform), with the transformed poses and the applied
camera_to_world transforms.
"""
t = poses[:, :3, 3]
t_mean = t.mean(axis=0)
t = t - t_mean
eigval, eigvec = np.linalg.eig(t.T @ t)
# Sort eigenvectors in order of largest to smallest eigenvalue.
inds = np.argsort(eigval)[::-1]
eigvec = eigvec[:, inds]
rot = eigvec.T
if np.linalg.det(rot) < 0:
rot = np.diag(np.array([1, 1, -1])) @ rot
transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)
poses_recentered = unpad_poses(transform @ pad_poses(poses))
transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)
# Flip coordinate system if z component of y-axis is negative
if poses_recentered.mean(axis=0)[2, 1] < 0:
poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered
transform = np.diag(np.array([1, -1, -1, 1])) @ transform
return poses_recentered, transform
# points = np.random.rand(3,100)
# points_h = np.concatenate((points,np.ones_like(points[:1])), axis=0)
# (poses_recentered @ points_h)[0]
# (transform @ pad_poses(poses) @ points_h)[0,:3]
# import pdb; pdb.set_trace()
# # Just make sure it's it in the [-1, 1]^3 cube
# scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3]))
# poses_recentered[:, :3, 3] *= scale_factor
# transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform
# return poses_recentered, transform
def generate_ellipse_path(poses: np.ndarray,
n_frames: int = 120,
const_speed: bool = True,
z_variation: float = 0.,
z_phase: float = 0.) -> np.ndarray:
"""Generate an elliptical render path based on the given poses."""
# Calculate the focal point for the path (cameras point toward this).
center = focus_point_fn(poses)
# Path height sits at z=0 (in middle of zero-mean capture pattern).
offset = np.array([center[0], center[1], 0])
# Calculate scaling for ellipse axes based on input camera positions.
sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
# Use ellipse that is symmetric about the focal point in xy.
low = -sc + offset
high = sc + offset
# Optional height variation need not be symmetric
z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
def get_positions(theta):
# Interpolate between bounds with trig functions to get ellipse in x-y.
# Optionally also interpolate in z to change camera height along path.
return np.stack([
low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5),
low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5),
z_variation * (z_low[2] + (z_high - z_low)[2] *
(np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),
], -1)
theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True)
positions = get_positions(theta)
#if const_speed:
# # Resample theta angles so that the velocity is closer to constant.
# lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
# theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1)
# positions = get_positions(theta)
# Throw away duplicated last position.
positions = positions[:-1]
# Set path's up vector to axis closest to average of input pose up vectors.
avg_up = poses[:, :3, 1].mean(0)
avg_up = avg_up / np.linalg.norm(avg_up)
ind_up = np.argmax(np.abs(avg_up))
up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
return np.stack([viewmatrix(p - center, up, p) for p in positions])
def generate_path(viewpoint_cameras, n_frames=480):
c2ws = np.array([np.linalg.inv(np.asarray((cam.world_view_transform.T).cpu().numpy())) for cam in viewpoint_cameras])
pose = c2ws[:,:3,:] @ np.diag([1, -1, -1, 1])
pose_recenter, colmap_to_world_transform = transform_poses_pca(pose)
# generate new poses
new_poses = generate_ellipse_path(poses=pose_recenter, n_frames=n_frames)
# warp back to orignal scale
new_poses = np.linalg.inv(colmap_to_world_transform) @ pad_poses(new_poses)
traj = []
for c2w in new_poses:
c2w = c2w @ np.diag([1, -1, -1, 1])
cam = copy.deepcopy(viewpoint_cameras[0])
cam.image_height = int(cam.image_height / 2) * 2
cam.image_width = int(cam.image_width / 2) * 2
cam.world_view_transform = torch.from_numpy(np.linalg.inv(c2w).T).float().cuda()
cam.full_proj_transform = (cam.world_view_transform.unsqueeze(0).bmm(cam.projection_matrix.unsqueeze(0))).squeeze(0)
cam.camera_center = cam.world_view_transform.inverse()[3, :3]
traj.append(cam)
return traj
def load_img(pth: str) -> np.ndarray:
"""Load an image and cast to float32."""
with open(pth, 'rb') as f:
image = np.array(Image.open(f), dtype=np.float32)
return image
def create_videos(base_dir, input_dir, out_name, num_frames=480):
"""Creates videos out of the images saved to disk."""
# Last two parts of checkpoint path are experiment name and scene name.
video_prefix = f'{out_name}'
zpad = max(5, len(str(num_frames - 1)))
idx_to_str = lambda idx: str(idx).zfill(zpad)
os.makedirs(base_dir, exist_ok=True)
render_dist_curve_fn = np.log
# Load one example frame to get image shape and depth range.
depth_file = os.path.join(input_dir, 'vis', f'depth_{idx_to_str(0)}.tiff')
depth_frame = load_img(depth_file)
shape = depth_frame.shape
p = 3
distance_limits = np.percentile(depth_frame.flatten(), [p, 100 - p])
lo, hi = [render_dist_curve_fn(x) for x in distance_limits]
print(f'Video shape is {shape[:2]}')
video_kwargs = {
'shape': shape[:2],
'codec': 'h264',
'fps': 60,
'crf': 18,
}
for k in ['depth', 'normal', 'color']:
video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4')
input_format = 'gray' if k == 'alpha' else 'rgb'
file_ext = 'png' if k in ['color', 'normal'] else 'tiff'
idx = 0
if k == 'color':
file0 = os.path.join(input_dir, 'renders', f'{idx_to_str(0)}.{file_ext}')
else:
file0 = os.path.join(input_dir, 'vis', f'{k}_{idx_to_str(0)}.{file_ext}')
if not os.path.exists(file0):
print(f'Images missing for tag {k}')
continue
print(f'Making video {video_file}...')
with media.VideoWriter(
video_file, **video_kwargs, input_format=input_format) as writer:
for idx in tqdm(range(num_frames)):
# img_file = os.path.join(input_dir, f'{k}_{idx_to_str(idx)}.{file_ext}')
if k == 'color':
img_file = os.path.join(input_dir, 'renders', f'{idx_to_str(idx)}.{file_ext}')
else:
img_file = os.path.join(input_dir, 'vis', f'{k}_{idx_to_str(idx)}.{file_ext}')
if not os.path.exists(img_file):
ValueError(f'Image file {img_file} does not exist.')
img = load_img(img_file)
if k in ['color', 'normal']:
img = img / 255.
elif k.startswith('depth'):
img = render_dist_curve_fn(img)
img = np.clip((img - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1)
img = cm.get_cmap('turbo')(img)[..., :3]
frame = (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)
writer.add_image(frame)
idx += 1
def save_img_u8(img, pth):
"""Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG."""
with open(pth, 'wb') as f:
Image.fromarray(
(np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)).save(
f, 'PNG')
def save_img_f32(depthmap, pth):
"""Save an image (probably a depthmap) to disk as a float32 TIFF."""
with open(pth, 'wb') as f:
Image.fromarray(np.nan_to_num(depthmap).astype(np.float32)).save(f, 'TIFF')