
WIP: epipolar #144

Draft
wants to merge 16 commits into base: main
1,052 changes: 1,052 additions & 0 deletions notebooks/mkl/06 - Gaussian posteriors copy 2.ipynb

Large diffs are not rendered by default.

1,416 changes: 1,416 additions & 0 deletions notebooks/mkl/07c - Inference - v1.ipynb

Large diffs are not rendered by default.

944 changes: 944 additions & 0 deletions notebooks/mkl/10b - Datasets TUM.ipynb

Large diffs are not rendered by default.

141 changes: 141 additions & 0 deletions notebooks/mkl/Collecting_Clicks.ipynb

Large diffs are not rendered by default.

1,500 changes: 749 additions & 751 deletions pixi.lock

Large diffs are not rendered by default.

37 changes: 24 additions & 13 deletions src/b3d/camera.py
@@ -92,7 +92,7 @@ def camera_from_screen_and_depth(


 def camera_from_screen(uv: ScreenCoordinates, intrinsics) -> CameraCoordinates:
-    z = jnp.ones_like(uv.shape[-1:])
+    z = jnp.ones(uv.shape[:-1])
     return camera_from_screen_and_depth(uv, z, intrinsics)


@@ -122,7 +122,9 @@ def camera_from_depth(z: DepthImage, intrinsics) -> CameraCoordinates:
unproject_depth = camera_from_depth


-def screen_from_camera(xyz: CameraCoordinates, intrinsics) -> ScreenCoordinates:
+def screen_from_camera(
+    xyz: CameraCoordinates, intrinsics, culling=False
+) -> ScreenCoordinates:
     """
     Maps to sensor coordinates `uv` from camera coordinates `xyz`, which are
     defined by $(u,v) = (u'/z, v'/z)$, where
@@ -138,25 +140,31 @@ def screen_from_camera(xyz: CameraCoordinates, intrinsics) -> ScreenCoordinates:
Returns:
(...,2) array of screen coordinates.
"""
-    _, _, fx, fy, cx, cy, _, _ = intrinsics
+    _, _, fx, fy, cx, cy, near, far = intrinsics
     x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
+
+    # TODO: What is the right way of doing this? Returning infs?
+    # Record frustum membership *before* clipping; clipping first would make
+    # this test trivially true.
+    in_range = ((near <= z) & (z <= far)) | (not culling)
+
+    # TODO: check this
+    z = jnp.clip(z, near, far)
-    u = x * fx / z + cx
-    v = y * fy / z + cy
+    u_ = x * fx / z + cx
+    v_ = y * fy / z + cy
+
+    u = jnp.where(in_range, u_, jnp.inf)
+    v = jnp.where(in_range, v_, jnp.inf)

     return jnp.stack([u, v], axis=-1)


screen_from_xyz = screen_from_camera


-def screen_from_world(x, cam, intr):
+def screen_from_world(x, cam, intr, culling=False):
     """Maps to screen coordinates `uv` from world coordinates `xyz`."""
-    return screen_from_camera(cam.inv().apply(x), intr)
+    return screen_from_camera(cam.inv().apply(x), intr, culling=culling)


 def world_from_screen(uv, cam, intr):
     """Maps to world coordinates `xyz` from screen coords `uv`."""
     return cam.apply(camera_from_screen(uv, intr))
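A small usage sketch of the new `culling` flag (not part of this diff; illustrative values, assuming the `Intrinsics` field order and `Pose.id()` used elsewhere in this PR):

```python
import jax.numpy as jnp
from b3d.camera import Intrinsics, screen_from_world
from b3d.pose import Pose

intr = Intrinsics(640, 480, 525.0, 525.0, 319.5, 239.5, 1e-2, 1e2)
cam = Pose.id()  # camera at the world origin, looking down +z
xs = jnp.array(
    [
        [0.1, -0.2, 2.0],  # inside the near/far range
        [0.0, 0.0, 1e3],  # beyond `far`
    ]
)
screen_from_world(xs, cam, intr)  # both points projected (with clipped depth)
screen_from_world(xs, cam, intr, culling=True)  # second row becomes [inf, inf]
```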


def camera_matrix_from_intrinsics(intr: Intrinsics) -> CameraMatrix3x3:
@@ -216,6 +224,9 @@ def homogeneous_coordinates(xs, z=jnp.array(1.0)):
return jnp.concatenate([xs, jnp.ones_like(xs[..., :1])], axis=-1) * z[..., None]


+homogeneous = homogeneous_coordinates


def planar_coordinates(xs):
"""
    Maps homogeneous to planar coordinates, e.g.,
2 changes: 2 additions & 0 deletions src/b3d/chisight/sfm/__init__.py
@@ -0,0 +1,2 @@
from .eight_point import *
from .epipolar import *
51 changes: 51 additions & 0 deletions src/b3d/chisight/sfm/camera_inference.py
@@ -0,0 +1,51 @@
"""
Computation of the camera matrix from world points and their projections.
References:
> Terzakis--Lourakis, "A Consistently Fast and Globally Optimal Solution to the Perspective-n-Point Problem"
> Hartley--Zisserman, "Multiple View Geometry in Computer Vision", 2nd ed.
"""

from jax import numpy as jnp

from b3d.types import Matrix3x4, Point3D


def solve_camera_projection_constraints(Xs: Point3D, ys: Point3D) -> Matrix3x4:
"""
Solve for the camera projection matrix given 3D points and their 2D projections,
as described in Chapter 7 ("Computation of the Camera Matrix P") of
> Hartley--Zisserman, "Multiple View Geometry in Computer Vision" (2nd ed).
Args:
Xs: 3D points in world coordinates, shape (N, 3).
        ys: Homogeneous normalized image coordinates, shape (N, 3).
Returns:
Camera projection matrix, shape (3, 4).
"""
    # We change notation from B3D notation
    # to Hartley--Zisserman, for ease of comparison.
    # Note: the DLT constraints below act on homogeneous 4-vectors,
    # so we append a unit coordinate to the world points.
    n = Xs.shape[0]
    X = jnp.concatenate([Xs, jnp.ones((n, 1))], axis=1)  # (N, 4) homogeneous
    x = ys[:, 0]
    y = ys[:, 1]
    w = ys[:, 2]

    A = jnp.concatenate(
        [
            jnp.block(
                [
                    [jnp.zeros(4), -w[i] * X[i], y[i] * X[i]],
                    [w[i] * X[i], jnp.zeros(4), -x[i] * X[i]],
                    [-y[i] * X[i], x[i] * X[i], jnp.zeros(4)],
                ]
            )
            for i in range(n)
        ],
        axis=0,
    )

_, _, vt = jnp.linalg.svd(A)
P = vt[-1].reshape(3, 4)
return P
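A hypothetical smoke test for the DLT solver above (not part of this diff): project known world points with a known `P` and check that the recovered matrix agrees up to scale.

```python
import jax.numpy as jnp

P_true = jnp.array(
    [
        [1.0, 0.0, 0.0, 0.2],
        [0.0, 1.0, 0.0, -0.1],
        [0.0, 0.0, 1.0, 0.5],
    ]
)
Xs = jnp.array(
    [
        [0.0, 0.0, 2.0],
        [1.0, -1.0, 3.0],
        [0.5, 0.2, 4.0],
        [-0.3, 0.8, 2.5],
        [0.7, 0.1, 5.0],
        [-1.0, -0.5, 3.5],
    ]
)
ys = jnp.concatenate([Xs, jnp.ones((6, 1))], axis=1) @ P_true.T  # (N, 3) homogeneous
P_est = solve_camera_projection_constraints(Xs, ys)
# The solution is determined only up to scale (and sign):
scale = P_true[2, 2] / P_est[2, 2]
assert jnp.allclose(P_est * scale, P_true, atol=1e-4)
```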
285 changes: 285 additions & 0 deletions src/b3d/chisight/sfm/datasets.py
@@ -0,0 +1,285 @@
import glob
import os
import re
import subprocess
from pathlib import Path

import imageio.v3 as iio
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from b3d.camera import Intrinsics, camera_from_depth, camera_from_screen_and_depth
from b3d.pose import Pose

_DOWNLOAD_BASH_SCRIPT = """#!/bin/bash
# Check if both sequence and target_folder arguments are provided
if [ $# -ne 2 ]; then
echo "Usage: $0 <sequence_url> <target_folder>"
exit 1
fi
# Assign arguments to variables
sequence_url=$1
target_folder=$2
sequence=$(basename $sequence_url)
echo "Downloading $sequence to $target_folder..."
# Ensure the target folder exists
mkdir -p "$target_folder"
# Download the file using wget
wget "$sequence_url" -P "$target_folder"
# Extract the tar.gz file
tar -xzf "$target_folder/$sequence" -C "$target_folder"
# Remove the tar.gz file
rm "$target_folder/$sequence"
"""


class TumRgbdData:
    """
    Helper class to handle RGB-D sequences from the TUM RGB-D SLAM benchmark dataset.
The dataset can be downloaded from the following link:
> https://cvg.cit.tum.de/data/datasets/rgbd-dataset/download
Example Usage:
```
# Grab a sequence URL from set
# a target folder to store the data
# > https://cvg.cit.tum.de/data/datasets/rgbd-dataset/download
sequence_url = "https://cvg.cit.tum.de/rgbd/dataset/freiburg1/rgbd_dataset_freiburg1_xyz.tgz"
target_folder = "~/workspace/rgbd_slam_dataset_freiburg"
# Download and extract the sequence data into
# a new folder under the target folder
sequence_folder = TumRgbdData._download_from_url(sequence_url, target_folder)
data = TumRgbdData(sequence_folder)
# Get the i'th RGB image
# Note that rgb, depth, and pose sequences are not synchronized, so the i'th RGB image
# and the i'th depth image and pose are not guaranteed to be from the same time.
i = 100
rgb = data.get_rgb(i)
# This returns i'th RGB image and the CLOSEST (in time) available depth image and pose
rgb, depth, pose = data.get_synced(i)
# Plot the RGB and depth images side by side
fig, axs = plt.subplots(1, 3, figsize=(10,5))
axs[0].imshow(rgb)
axs[1].imshow(np.where(depth>0, depth, np.nan))
axs[2].imshow(rgb, alpha=1.)
axs[2].imshow(np.where(depth>0, depth, np.nan), alpha=0.75)
```
"""

def __init__(self, path):
"""
Args:
            path (Path): Path to one of the TUM RGB-D sequences, e.g.,
                .../data/rgbd_dataset_freiburg2_desk
"""

self.path = path
self.name = path.stem

self.gt_data = np.loadtxt(
path / "groundtruth.txt",
comments="#",
dtype=[
("timestamp", "f8"),
("tx", "f8"),
("ty", "f8"),
("tz", "f8"),
("qx", "f8"),
("qy", "f8"),
("qz", "f8"),
("qw", "f8"),
],
)

self.rgb_data = np.loadtxt(
path / "rgb.txt",
comments="#",
dtype=[("timestamp", "f8"), ("filename", "U50")],
)

self.depth_data = np.loadtxt(
path / "depth.txt",
comments="#",
dtype=[("timestamp", "f8"), ("filename", "U50")],
)

@staticmethod
def _list_datasets(root):
wildcard = f"{str(Path(root))}/rgbd_dataset_freiburg*"
datasets = [Path(s).stem for s in glob.glob(wildcard) if os.path.isdir(s)]
return datasets

@staticmethod
def _show_available_datasets(root):
datasets = TumRgbdData._list_datasets(root)
print(f"\033[1mAvailable datasets\033[0m in \033[94m{str(root)}/...\033[0m:")
for i, name in enumerate(datasets):
print(f"\033[95m({i}) {name}\033[0m")
return datasets

@staticmethod
def _download_from_url(sequence_url, target_folder):
# Target folder for the sequence data
sequence_folder = Path(target_folder) / Path(sequence_url).stem
sequence_name = Path(sequence_url).stem
# Check if the target folder exists
if os.path.exists(sequence_folder):
print(f"Sequence \033[1m{sequence_name}\033[0m already exists.")
return sequence_folder

try:
# Execute the Bash script using subprocess and pass in the arguments
print("Downloading and extracting...this might take a minute....")
result = subprocess.run(
["bash", "-c", _DOWNLOAD_BASH_SCRIPT, "_", sequence_url, target_folder],
check=True,
text=True,
capture_output=True,
)

# Print the output of the script
print("Script Output:...\n", result.stdout)
print("Script executed successfully.")
return sequence_folder

except subprocess.CalledProcessError as e:
print(f"An error occurred while executing the script: {e}")
print(f"Error output: {e.stderr}")

def len(self):
"""Returns the number of (RGB) frames in the dataset."""
return len(self.rgb_data)

@property
def shape(self):
"""Returns the shape of the RGB images."""
return (self.len(),) + self.get_rgb(0).shape[:2]

def get_rgb(self, i):
"""Returns the RGB image at index i."""
return iio.imread(self.path / self.rgb_data[i][1])

    def get_depth(self, i):
        """Returns the depth image at index i, in meters."""
        # TUM depth PNGs store depth in units of 1/5000 m.
        return iio.imread(self.path / self.depth_data[i][1]) / 5_000

def get_pose(self, i):
"""Returns the pose at index i."""
_, tx, ty, tz, qx, qy, qz, qw = self.gt_data[i]
return Pose(jnp.array([tx, ty, tz]), jnp.array([qx, qy, qz, qw]))

    def get_synced(self, i):
        """Returns the RGB image at index i, together with the closest-in-time depth image and pose."""
t = self.rgb_data[i]["timestamp"]
i_pose = np.argmin(np.abs(self.gt_data["timestamp"] - t))
i_depth = np.argmin(np.abs(self.depth_data["timestamp"] - t))
return self.get_rgb(i), self.get_depth(i_depth), self.get_pose(i_pose)

def get_timestamp(self, i):
return self.rgb_data[i]["timestamp"]

def __getitem__(self, i):
"""Returns the RGB, depth image, and pose at index i."""
if isinstance(i, int):
return self.get_synced(i)

if isinstance(i, slice):
i = range(*i.indices(len(self.rgb_data)))

rs, ds, ps = [], [], []
for j in i:
r, d, p = self.get_synced(j)
rs.append(r)
ds.append(d)
ps.append(p)

return np.array(rs), np.array(ds), Pose.stack_poses(ps)

def get_intrinsics(self, index=0):
"""Returns the camera intrinsics."""
# See
# > https://cvg.cit.tum.de/data/datasets/rgbd-dataset/file_formats#intrinsic_camera_calibration_of_the_kinect
intr0 = Intrinsics(640, 480, 525.0, 525.0, 319.5, 239.5, 1e-2, 1e-4)
intr1 = Intrinsics(640, 480, 517.3, 516.5, 318.6, 255.3, 1e-2, 1e-4)
intr2 = Intrinsics(640, 480, 520.9, 521.0, 325.1, 249.7, 1e-2, 1e-4)
intr3 = Intrinsics(640, 480, 535.4, 539.2, 320.1, 247.6, 1e-2, 1e-4)

return [intr0, intr1, intr2, intr3][index]

@staticmethod
def _extract_number_after_freiburg(input_string):
match = re.search(r"freiburg(\d+)", input_string)
if match:
return int(match.group(1))
else:
return None

    def plot_multiple_frames(self, ids, axs=None):
        n = len(ids)
        fig = None
        if axs is None:
            fig, axs = plt.subplots(1, n, figsize=(n * 3, 5))
for t, i in enumerate(ids):
rgb = self.get_rgb(i)
axs[t].set_title(f"Frame: {i}")
axs[t].imshow(rgb)

return fig, axs

    def plot_synced(self, i, axs=None):
        rgb, depth, _ = self.get_synced(i)
        fig = None
        if axs is None:
            fig, axs = plt.subplots(1, 3, figsize=(10, 5))
axs[0].imshow(rgb)
axs[1].imshow(np.where(depth > 0, depth, np.nan))
axs[2].imshow(rgb, alpha=1.0)
axs[2].imshow(np.where(depth > 0, depth, np.nan), alpha=0.75)
return fig, axs

@classmethod
def _sequence_url_from_sequence_name(cls, sequence_name):
n = cls._extract_number_after_freiburg(sequence_name)
return f"https://cvg.cit.tum.de/rgbd/dataset/freiburg{n}/{sequence_name}.tgz"

    def _get_colored_world_points(self, i, intr=None):
        """
        Returns world points, their colors, and their validity
        for a given frame `i`.
        """
        intr = intr or self.get_intrinsics()
        rgb_im, depth_im, cam = self.get_synced(i)
        xs = camera_from_depth(depth_im, intr).reshape(-1, 3)
        cs = rgb_im.reshape(-1, 3)
        valid = (depth_im > 0).reshape(-1)
        return cam(xs), cs, valid

    def _approximate_world_points(self, uvs, i, intr=None):
        """
        Returns world points and their validity given
        2D sensor coordinates `uvs` from a given frame `i`.
        """
"""
intr = intr or self.get_intrinsics()
_, depth_im, cam = self.get_synced(i)
zs = vals_from_im(uvs, jnp.array(depth_im))
xs = camera_from_screen_and_depth(uvs, zs, intr)
valid = zs > 0
return cam(xs), valid


def val_from_im(uv, im):
return im[uv[1].astype(jnp.int32), uv[0].astype(jnp.int32)]


vals_from_im = jax.vmap(val_from_im, in_axes=(0, None))
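A usage sketch for the helpers above (not part of this diff; assumes `data = TumRgbdData(sequence_folder)` as in the class docstring): read depth values under a set of 2D pixel coordinates, e.g. clicked keypoints, and lift them to world points.

```python
import jax.numpy as jnp

uvs = jnp.array([[320.0, 240.0], [100.0, 50.0]])  # (u, v) pixel coordinates
zs = vals_from_im(uvs, jnp.array(data.get_depth(0)))  # raw depth lookups
xs_world, valid = data._approximate_world_points(uvs, i=0)
```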
343 changes: 343 additions & 0 deletions src/b3d/chisight/sfm/eight_point.py
@@ -0,0 +1,343 @@
"""
Implementation of the eight point algorithm and other relevant pieces.
Notation and terminology:
Normalized image coordinates: "Normalized" means "intrinsics-free" coordinates, which
in our language are just the coordinates in the camera frame.
`x`: World coordinates
`y`: Camera coordinates (3D), i.e. normalized image coordinates.
`uv` or sometimes only `u`: Sensor coordinates (2D)
`intr`: Camera intrinsics
Example:
```
    from b3d.chisight.sfm.eight_point import normalized_eight_point, poses_from_essential
# Load data and so on
...
# Choose a pair of frames
t0 = 0
t1 = 1
# Choose a subset of keypoints to run the algorithm on
key = keysplit(key)
sub = jax.random.choice(key, jnp.where(vis[t0]*vis[t1])[0], (8,), replace=False)
# Normalized image coordinates from sensor coordinates
ys0 = camera.camera_from_screen(uvs[t0,sub], intr)
ys1 = camera.camera_from_screen(uvs[t1,sub], intr)
# Estimate essential matrix and extract
# possible choices for the second camera pose
E = normalized_eight_point(ys0, ys1)
ps = poses_from_essential(E)
```
"""

from typing import Tuple

import jax
from jax import numpy as jnp

from b3d.pose import Pose
from b3d.types import Array, Int, Matrix3x3, Matrix3x4, Point3D


def cross_product_matrix(a) -> Matrix3x3:
"""
Returns matrix A such that for any vector b,
the cross product of a and b is given by Ab.
"""
# > https://en.wikipedia.org/wiki/Cross_product#Conversion_to_matrix_multiplication
return jnp.array([[0, -a[2], a[1]], [a[2], 0, -a[0]], [-a[1], a[0], 0]])
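A minimal sanity-check sketch of this identity (not part of this diff):

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
assert jnp.allclose(cross_product_matrix(a) @ b, jnp.cross(a, b))
```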


def essential_from_pose(p) -> Matrix3x3:
"""
Essential matrix from B3D camera pose.
Args:
`p`: Camera pose
Returns:
Essential matrix
"""
# Two things to note:
#
# 1) A camera projection matrix [R | t] that maps
# *world* coordinates x to *camera* coordinates y = Rx + t
# corresponds to a B3D camera pose with rotation `R.T` and
# translation `- R.T t` (this is the inverse of [R | t]).
#
# 2) Recall that the essential matrix for a camera
# projection matrix [R | t] is given by E = [t] R,
# where [t] denotes the matrix representation of the
# cross product with t
#
# Therefore, the essential matrix for a B3D camera pose p = Pose(x,q)
# is given by E = [-Q^T x] Q^T, where x is the position and Q
# is the rotation matrix.
x = p.pos
Q = p.rot.as_matrix()
return cross_product_matrix(-Q.T @ x) @ Q.T
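A sanity-check sketch of the derivation above (not part of this diff; assumes the first camera is at the identity, and uses `Pose.from_pos_matrix` as later in this file): for any world point `x`, the normalized coordinates `y0 = x` and `y1 = p.inv().apply(x)` should satisfy the epipolar constraint.

```python
import jax.numpy as jnp
from b3d.pose import Pose

theta = 0.3
R = jnp.array(
    [
        [jnp.cos(theta), -jnp.sin(theta), 0.0],
        [jnp.sin(theta), jnp.cos(theta), 0.0],
        [0.0, 0.0, 1.0],
    ]
)
p = Pose.from_pos_matrix(jnp.array([0.5, 0.0, 0.1]), R)
E = essential_from_pose(p)

x = jnp.array([0.2, -0.1, 3.0])  # world point
y0 = x  # camera 0 is the identity, so y0 = x
y1 = p.inv().apply(x)  # normalized coordinates in camera 1
assert jnp.allclose(y1 @ E @ y0, 0.0, atol=1e-6)
```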


def camera_projection_from_pose(p) -> Tuple[Matrix3x3, Point3D]:
"""
    Returns camera projection data R, t from a B3D camera pose.
Args:
`p`: Camera pose
Returns:
Camera projection matrix entries R,t
"""
# Note: A camera projection matrix [R | t] that maps
# *world* coordinates x to *camera* coordinates y = Rx + t
# corresponds to a B3D camera pose with rotation `R.T` and
# translation `- R.T t` (this is the inverse of [R | t]).
x = p.pos
Q = p.rot.as_matrix()
return Q.T, -Q.T @ x


def camera_projection_matrix(p: Pose) -> Matrix3x4:
"""
    Returns camera projection matrix P = [R | t] from a B3D camera pose.
Args:
`p`: Camera pose
Returns:
Camera projection matrix P = [R | t]
"""
# Note: A camera projection matrix [R | t] that maps
# *world* coordinates x to *camera* coordinates y = Rx + t
# corresponds to a B3D camera pose with rotation `R.T` and
# translation `- R.T t` (this is the inverse of [R | t]).
x = p.pos
Q = p.rot.as_matrix()
return jnp.concatenate([Q.T, (-Q.T @ x)[:, None]], axis=1)
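A quick relation-check sketch (not part of this diff; `p` as in the sketch above): applying `P` to homogeneous world coordinates agrees with mapping into the camera frame via the pose inverse.

```python
x = jnp.array([0.3, -0.2, 2.0])
P = camera_projection_matrix(p)
assert jnp.allclose(P @ jnp.append(x, 1.0), p.inv().apply(x), atol=1e-6)
```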


def poses_from_essential(E: Matrix3x3) -> Pose:
"""
Extract the 4 possible choices for the second camera matrix
from essential matrix.
Args:
`E`: Essential matrix
Returns:
Stacked camera poses
"""
    # According to [Hartley-Zisserman, "Multiple View Geometry" (2nd ed.), Result 9.19],
    # for a given essential matrix E = U diag(1, 1, 0) V^T and first camera matrix
    # P = [I | 0], there are four possible choices for the second camera matrix P':
    #     P' = [U W V^T | +u3], [U W V^T | -u3], [U W^T V^T | +u3], or [U W^T V^T | -u3],
    # where u3 denotes the third column of U and W is the rotation defined below.
u, _, vt = jnp.linalg.svd(E)
x = u[:, -1]
w = jnp.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
r0 = u @ w @ vt
r1 = u @ w.T @ vt

# Make sure these are orientation preserving; det(u) or det(vt) may be -1
# but we are enforcing the equality of the first two singular values so
# we can flip the first two rows / columns to flip the sign.
r0 *= jnp.sign(jnp.linalg.det(r0))
r1 *= jnp.sign(jnp.linalg.det(r1))

# Recall that a **camera projection** matrix [R | t] that maps
# world coordinates x to camera coordinates y = [R | t] hom(x)
# corresponds to a B3D **camera pose** with rotation `R.T` and
# translation `- R.T t`.
return Pose.stack_poses(
[
Pose.from_pos_matrix(-r0.T @ x, r0.T),
Pose.from_pos_matrix(-r1.T @ x, r1.T),
Pose.from_pos_matrix(r0.T @ x, r0.T),
Pose.from_pos_matrix(r1.T @ x, r1.T),
]
)


extract_poses = poses_from_essential


def solve_epipolar_constraints(ys0: Point3D, ys1: Point3D) -> Matrix3x3:
"""
Returns the essential matrix that minimizes
the epipolar constraint `y1.T E y0`.
"""
# We want to solve `y1.T E y0 = 0`, which can be
# rewritten as `(y0 kronecker Y1).T vec(E) = 0`,
# where "kronecker" denotes the Kronecker product and
# vec(E) is the vectorized form of E.
#
# Useful references:
# > https://en.wikipedia.org/wiki/Eight-point_algorithm#Step_1:_Formulating_a_homogeneous_linear_equation
# > https://en.wikipedia.org/wiki/Kronecker_product
Y = jax.vmap(jnp.kron)(ys1, ys0)
_, _, vt = jnp.linalg.svd(Y)
# Solution is the minimal right singular vector
e = vt[-1].reshape((3, 3))
return e
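For reference, a sketch of the vectorization identity used to build `Y` above (not part of this diff): `y1^T E y0 == kron(y1, y0) @ vec(E)`, with `vec(E)` taken row-major to match the `reshape` above.

```python
import jax.numpy as jnp

y0 = jnp.array([0.1, 0.2, 1.0])
y1 = jnp.array([-0.3, 0.4, 1.0])
E = jnp.arange(9.0).reshape(3, 3)
assert jnp.allclose(y1 @ E @ y0, jnp.kron(y1, y0) @ E.reshape(-1))
```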


def epipolar_errors(E: Matrix3x3, ys0: Point3D, ys1: Point3D) -> Array:
"""
Compute the epipolar errors for a given essential matrix and
normalized image coordinates.
Args:
`E`: Essential matrix
`ys0`: Normalized image coordinates at time 0
`ys1`: Normalized image coordinates at time 1
Returns:
Array of Epipolar errors
"""
return jax.vmap(lambda y0, y1: jnp.abs(y1.T @ E @ y0))(ys0, ys1)


def enforce_internal_constraint(E_est: Matrix3x3) -> Matrix3x3:
    """Enforce the internal (rank-2) constraint of the essential matrix."""
    # Following the recipe in Section 11.1 in Hartley-Zisserman,
    # we enforce the rank constraint by setting the smallest
    # singular value to zero.
    #
    # However, that does not enforce the constraint
    # `EE^T - 1/2 trace(EE^T)I = 0` (https://en.wikipedia.org/wiki/Essential_matrix#Properties).
    # Alternatively, one can set the diagonal matrix to diag(1, 1, 0),
    # which ensures that the above constraint is satisfied.
u, s, vt = jnp.linalg.svd(E_est)
return u @ jnp.diag(jnp.array([s[0], s[1], 0.0])) @ vt


def normalize_hartley(x: Array) -> tuple[Array, Array]:
"""
Normalize a homogeneous batch to mean zero and range from -1 to 1;
as suggested in
> Hartley, "In defense of the eight point algorithm", 1997.
"""
x /= x[:, -1:]
u = x[:, :-1]
means = jnp.mean(u, axis=0)
maxes = jnp.max(jnp.abs(u - means), axis=0)
normalizer = jnp.diag(jnp.append(1 / maxes, 1.0)).at[:-1, -1].set(-means / maxes)
return (x @ normalizer.T, normalizer)
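A quick property-check sketch (not part of this diff): after normalization the non-homogeneous coordinates have zero mean and maximum absolute value 1, and the returned `normalizer` reproduces the transformation.

```python
import jax.numpy as jnp

xs = jnp.array([[0.0, 0.0, 1.0], [2.0, 4.0, 1.0], [4.0, 0.0, 1.0]])
xs_n, T = normalize_hartley(xs)
assert jnp.allclose(xs_n, xs @ T.T)
assert jnp.allclose(xs_n[:, :2].mean(axis=0), 0.0, atol=1e-6)
assert jnp.allclose(jnp.abs(xs_n[:, :2]).max(axis=0), 1.0)
```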


def _normalized_eight_point(ys0, ys1):
"""
Normalized 8-point algorithm estimating the essential matrix,
as described in
> Hartley-Zisserman, "Multiple view geometry in computer vision", 2nd ed., Algorithm 11.1.
Args:
`ys0`: Normalized 3D image coordinates at time 0
`ys1`: Normalized 3D image coordinates at time 1
Returns:
Estimated essential matrix
"""
ys0, T0 = normalize_hartley(ys0)
ys1, T1 = normalize_hartley(ys1)
E = solve_epipolar_constraints(ys0, ys1)
E = enforce_internal_constraint(E)
E = T1.T @ E @ T0
return E


normalized_eight_point = jax.jit(_normalized_eight_point)


def _triangulate_linear_hartley(cam0, cam1, y0, y1) -> Point3D:
"""
Linear triangulation method as described in
> Hartley-Zisserman, "Multiple view geometry in computer vision", 2nd ed.; Section 12.2
Args:
`cam0`: Camera pose at time 0
`cam1`: Camera pose at time 1
`y0`: Normalized image coordinates of keypoint at time 0
`y1`: Normalized image coordinates of keypoint at time 1
Returns:
Inferred world point.
"""
# We follow Section 12.2 in Hartley-Zisserman.
# First we map to notation from Hartley-Zisserman.
# Here "underscore" reads "prime", i.e.
# `x_` translates to `x'`
P = camera_projection_matrix(cam0)
P_ = camera_projection_matrix(cam1)
x, y = y0[:2]
x_, y_ = y1[:2]
A = jnp.array(
[x * P[2] - P[0], y * P[2] - P[1], x_ * P_[2] - P_[0], y_ * P_[2] - P_[1]]
)
_, _, vt = jnp.linalg.svd(A)
X = vt[-1]
# TODO: Obviously there is a problem when X[3] is zero. Review this.
# Hartley-Zisserman address this I think.
return X[:3] / X[3]


triangulate_linear_hartley = jax.vmap(_triangulate_linear_hartley, (None, None, 0, 0))


def in_front_count(cam0, cam1, xs_world: Point3D) -> Int:
"""Count the world points that are in front of both cameras."""
ys0 = cam0.inv()(xs_world)
ys1 = cam1.inv()(xs_world)
return jnp.sum((ys0[:, 2] > 0) & (ys1[:, 2] > 0))


def find_best_chirality(cams, ys0, ys1):
xss = jax.vmap(triangulate_linear_hartley, (None, 0, None, None))(
Pose.id(), cams, ys0, ys1
)
counts = jax.vmap(in_front_count, (None, 0, 0))(Pose.id(), cams, xss)
i = jnp.argmax(counts)
return cams[i], xss[i]
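An end-to-end sketch mirroring the module docstring (not part of this diff; assumes `uvs`, `vis`, `intr`, and a JAX `key` from a feature-track dataset): estimate E, enumerate the four pose candidates, and pick the chirality-consistent one together with its triangulated points.

```python
import jax
from jax import numpy as jnp

from b3d import camera

t0, t1 = 0, 1
sub = jax.random.choice(key, jnp.where(vis[t0] * vis[t1])[0], (8,), replace=False)
ys0 = camera.camera_from_screen(uvs[t0, sub], intr)  # homogeneous, z = 1
ys1 = camera.camera_from_screen(uvs[t1, sub], intr)

E = normalized_eight_point(ys0, ys1)
cams = poses_from_essential(E)
cam1, xs_world = find_best_chirality(cams, ys0, ys1)  # resolve the 4-fold ambiguity
```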


# TODO:
# - Check triangulation angles ("orthogonality score")
# - Assess posterior


def _triangulate_linear_midpoint(cam0, cam1, y0, y1):
"""
Returns the mid point of the line segment
between the two rays through the keypoints.
Args:
`cam0`: Camera pose at time 0
`cam1`: Camera pose at time 1
`y0`: Camera coordinates of keypoint at time 0
`y1`: Camera coordinates of keypoint at time 1
Returns:
Inferred world point.
"""
# We need to solve
# c0 + s0*v0 = c1 + s1*v1,
# where ci are the camera positions in the world and vi
# are world vectors through the image keypoints.
v0 = cam0(y0) - cam0.pos
v1 = cam1(y1) - cam1.pos

V = jnp.stack([v0, -v1], axis=1)
c = cam1.pos - cam0.pos

s = jnp.linalg.pinv(V) @ c[:, None]

xs = jnp.array([cam0.pos + s[0] * v0, cam1.pos + s[1] * v1])
x = xs.mean(0)
return x


triangulate_linear_midpoint = jax.vmap(_triangulate_linear_midpoint, (None, None, 0, 0))
418 changes: 418 additions & 0 deletions src/b3d/chisight/sfm/epipolar.py

Large diffs are not rendered by default.

109 changes: 109 additions & 0 deletions src/b3d/chisight/sfm/plotting_utils.py
@@ -0,0 +1,109 @@
import jax.numpy as jnp
import numpy as np
import rerun as rr
from sklearn.utils import Bunch


def create_box_mesh(dims=np.array([1.0, 1.0, 1.0])):
# Define the 8 vertices of the box
w, h, d = dims / 2.0
vertex_positions = np.array(
[
[-w, -h, -d],
[w, -h, -d],
[w, h, -d],
[-w, h, -d],
[-w, -h, d],
[w, -h, d],
[w, h, d],
[-w, h, d],
]
)

# Define the 12 triangles (two per face)
triangle_indices = np.array(
[
[0, 1, 2],
[0, 2, 3], # Front face
[4, 5, 6],
[4, 6, 7], # Back face
[0, 1, 5],
[0, 5, 4], # Bottom face
[2, 3, 7],
[2, 7, 6], # Top face
[0, 3, 7],
[0, 7, 4], # Left face
[1, 2, 6],
[1, 6, 5], # Right face
]
)
vertex_normals = vertex_positions

return vertex_positions, triangle_indices, vertex_normals


def create_box_mesh2(dims=np.array([1.0, 1.0, 1.0])):
# Define the 8 vertices of the box
w, h, d = dims / 2.0
vertex_positions = np.array(
[
[-w, -h, -d],
[w, -h, -d],
[w, h, -d],
[-w, h, -d],
[-2 * w, -2 * h, d],
[2 * w, -2 * h, d],
[2 * w, 2 * h, d],
[-2 * w, 2 * h, d],
]
)

# Define the 12 triangles (two per face)
triangle_indices = np.array(
[
[0, 1, 2],
[0, 2, 3], # Front face
[4, 5, 6],
[4, 6, 7], # Back face
[0, 1, 5],
[0, 5, 4], # Bottom face
[2, 3, 7],
[2, 7, 6], # Top face
[0, 3, 7],
[0, 7, 4], # Left face
[1, 2, 6],
[1, 6, 5], # Right face
]
)
vertex_normals = vertex_positions

return vertex_positions, triangle_indices, vertex_normals


def create_pose_bunch(
p, c=jnp.array([0.7, 0.7, 0.7]), s=1.0, dims=np.array([0.2, 0.2, 1.0])
):
vs, fs, ns = create_box_mesh2(dims=s * dims)

if c is None:
c = jnp.array([0.7, 0.7, 0.7])
cs = c[None, :] * jnp.ones((vs.shape[0], 1))

return Bunch(
vertex_positions=p(vs),
triangle_indices=fs,
vertex_normals=p.rot.apply(ns),
vertex_colors=cs,
)


def log_pose(
s, p, c=jnp.array([0.7, 0.7, 0.7]), scale=1.0, dims=np.array([0.2, 0.2, 1.0])
):
rr.log(
s,
rr.Mesh3D(
**create_pose_bunch(p, c=c, s=scale, dims=dims),
# mesh_material=rr.components.Material(albedo_factor=[255, 255, 255]),
),
)
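Hypothetical usage (not part of this diff; assumes a stacked `cams` pose batch, e.g. the four candidates from `poses_from_essential`): log the candidates to rerun for visual inspection.

```python
import rerun as rr

rr.init("sfm_debug", spawn=True)
for i in range(4):
    log_pose(f"world/cam_candidate_{i}", cams[i], scale=0.5)
```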
9 changes: 9 additions & 0 deletions src/b3d/chisight/sfm/utils.py
@@ -0,0 +1,9 @@
import jax

from b3d.pose import uniform_pose_in_ball

vmap_uniform_pose = jax.jit(
jax.vmap(uniform_pose_in_ball.sample, (0, None, None, None))
)

# uniform_quat_samples_around_identity(key, N, rq)
4 changes: 3 additions & 1 deletion src/b3d/io/__init__.py
@@ -1,7 +1,9 @@
+from b3d.utils import get_shared
+
 from .data_loader import *
 from .feature_track_data import FeatureTrackData
 from .mesh_data import MeshData
 from .utils import *
 from .video_input import VideoInput

-__all__ = ["MeshData", "FeatureTrackData", "VideoInput"]
+__all__ = ["MeshData", "FeatureTrackData", "VideoInput", "get_shared"]
39 changes: 31 additions & 8 deletions src/b3d/io/feature_track_data.py
@@ -13,9 +13,10 @@

 DESCR = """
 FeatureTrackData:
-    Timesteps: {data.uv.shape[0]}
+    Num Frames: {data.uv.shape[0]}
     Num Keypoints: {data.uv.shape[1]}
-    Sensor shape (width x height): {data.rgb.shape[2]} x {data.rgb.shape[1]}
+    Image shape (width x height): {data.rgb.shape[2]} x {data.rgb.shape[1]}
+    FPS: {data.fps}
 """


@@ -132,9 +133,26 @@ def uv(self):
def vis(self):
return self.visibility

+    @property
+    def rgb_float(self):
+        rgb = self.rgbd_images[..., :3]
+        if rgb.max() > 1.0:
+            rgb = rgb / 255
+        return rgb
+
+    @property
+    def rgb_uint(self):
+        rgb = self.rgbd_images[..., :3]
+        if rgb.max() <= 1.0:
+            rgb = (rgb * 255).astype(jnp.uint8)
+        return rgb
+
     @property
     def rgb(self):
-        return self.rgbd_images[..., :3]
+        return self.rgb_float

@property
def visibility(self):
@@ -451,7 +469,9 @@ def min_2D_distance_at_frame0(self):
distances = jnp.where(jnp.eye(distances.shape[0]) == 1.0, jnp.inf, distances)
return jnp.min(distances)

-    def quick_plot(self, t=None, fname=None, ax=None, figsize=(3, 3), downsize=10):
+    def quick_plot(
+        self, t=None, fname=None, ax=None, figsize=(3, 3), downsize=10, ids=None
+    ):
if t is None:
figsize = (figsize[0] * self.num_frames, figsize[1])

@@ -460,23 +480,26 @@ def quick_plot(self, t=None, fname=None, ax=None, figsize=(3, 3), downsize=10):
ax.set_aspect(1)
ax.axis("off")

-        rgb = downsize_images(self.rgb, downsize)
+        rgb = downsize_images(self.rgb_float, downsize)
+
+        if ids is None:
+            ids = np.where(self.vis[t])[0]

         if t is None:
             _h, w = self.rgb.shape[1:3]
             ax.imshow(np.concatenate(rgb, axis=1))
             ax.scatter(
                 *np.concatenate(
                     [
-                        self.uv[t, self.vis[t]] / downsize
-                        + np.array([t * w, 0]) / downsize
+                        self.uv[t, ids] / downsize + np.array([t * w, 0]) / downsize
                         for t in range(self.num_frames)
                     ]
                 ).T,
                 s=1,
             )
         else:
             ax.imshow(rgb[t])
-            ax.scatter(*(self.uv[t, self.vis[t]] / downsize).T, s=1)
+            ax.scatter(*(self.uv[t, ids] / downsize).T, s=1)


### Filter 2D keypoints to ones that are sufficently distant ###
5 changes: 5 additions & 0 deletions src/b3d/pose/pose_utils.py
@@ -43,6 +43,11 @@ def volume_of_cap_around_north_pole(r):
return jnp.pi * (jnp.pi - (jnp.sin(2 * jnp.arccos(r)) + 2 * jnp.arccos(r)))


+def uniform_quat_samples_around_identity(key, N, rq):
+    qs = unit_disc_to_sphere(rq * uniform_samples_from_disc(key, N, d=3))
+    return qs


def uniform_samples_from_SE3_around_identity(key, N, rx=1.0, rq=1.0):
"""
Returns N samples from SE(3) around the identity, where positions
6 changes: 6 additions & 0 deletions src/b3d/types.py
@@ -17,6 +17,12 @@
Vector = Array
Direction = Array
GaussianParticle = Any
+Matrix3x3 = Matrix
+Matrix3x4 = Matrix
+Matrix4x4 = Matrix
+Matrix3 = Matrix
+Point3D = Array
+Point2D = Array

Key: TypeAlias = jax.Array
Pytree = Any