Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Debugging performance discrepancy between PyTorch and JAX variants of NVDiffrast #21

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions b3d/nvdiffrast_original/jax/jax_rasterize_gl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ RasterizeGLStateWrapper::RasterizeGLStateWrapper(bool enableDB, bool automatic_,
cudaDeviceIdx = cudaDeviceIdx_;
memset(pState, 0, sizeof(RasterizeGLState));
pState->enableDB = enableDB ? 1 : 0;
std::cout << "initialization" << std::endl;
rasterizeInitGLContext(NVDR_CTX_PARAMS, *pState, cudaDeviceIdx_);
releaseGLContext();
}
Expand Down Expand Up @@ -63,10 +62,8 @@ void jax_rasterize_fwd_original_gl(cudaStream_t stream,
resolution.resize(2);
int ranges[2*d.num_objects];

cudaStreamSynchronize(stream);
NVDR_CHECK_CUDA_ERROR(cudaMemcpy(&resolution[0], _resolution, 2 * sizeof(int), cudaMemcpyDeviceToHost));
NVDR_CHECK_CUDA_ERROR(cudaMemcpy(&ranges[0], _ranges, 2 * d.num_objects * sizeof(int), cudaMemcpyDeviceToHost));
cudaStreamSynchronize(stream);
\
// std::cout << "num_images: " << d.num_images << std::endl;
// std::cout << "num_objects: " << d.num_objects << std::endl;
Expand All @@ -89,7 +86,7 @@ void jax_rasterize_fwd_original_gl(cudaStream_t stream,
// const at::cuda::OptionalCUDAGuard device_guard(at::device_of(pos));
RasterizeGLState& s = *stateWrapper.pState;


int posCount = 4 * d.num_images * d.num_vertices;
int triCount = 3 * d.num_triangles;

Expand Down
54 changes: 28 additions & 26 deletions test/test_renderer_fps.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import argparse
import os
import pathlib
import sys
import time

import jax
import jax.numpy as jnp
import numpy as np
import rerun as rr
import torch
import imageio
import trimesh
from jax import lax

import b3d
from tqdm import tqdm
import jax.numpy as jnp
import b3d.nvdiffrast_original.torch as dr
import time
from b3d.renderer_original import Renderer as RendererOriginal

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import trimesh



mesh_path = os.path.join(
b3d.get_root_path(), "assets/shared_data_bucket/025_mug/textured.obj"
Expand Down Expand Up @@ -54,7 +58,7 @@

glctx = dr.RasterizeGLContext() #if use_opengl else dr.RasterizeCudaContext()

import rerun as rr

rr.init("demo")
rr.connect("127.0.0.1:8812")

Expand Down Expand Up @@ -85,42 +89,40 @@
ranges_jax = jnp.array([[0, len(faces_jax)]])
poses = b3d.Pose.from_translation(jnp.array([0.0, 0.0, 5.1]))[None, None, ...]

import jax
print("JAX NVdiffrast Original")
from b3d.renderer_original import Renderer as RendererOriginal

print("JAX NVdiffrast Original (+lax.scan, -stream synchronize)")
for resolution in resolutions:
renderer = RendererOriginal(resolution, resolution, 100.0, 100.0, resolution/2.0, resolution/2.0, 0.01, 10.0, num_layers=1)
render_jit = jax.jit(renderer.rasterize)
num_timestep = 1000
resolution_array = jnp.array([resolution, resolution]).astype(jnp.int32)
sum = 0
start = time.time()
for _ in range(num_timestep):
output, = render_jit(
vertices_jax_4[None,...], faces_jax, ranges_jax, resolution_array
def body_fn(carry, _):
(output,) = render_jit(
vertices_jax_4[None, ...], faces_jax, ranges_jax, resolution_array
)
sum += output.sum()
return carry + output.sum(), None
start = time.time()
sum_result, _ = lax.scan(body_fn, init=0, length=num_timestep)
sum_result.block_until_ready()
end = time.time()
print(sum)
print(sum_result)

print(f"Resolution: {resolution}x{resolution}, FPS: {num_timestep/(end-start)}")



import jax

print("JAX")
print("JAX (with lax.scan)")
for resolution in resolutions:
renderer = b3d.Renderer(resolution, resolution, 100.0, 100.0, resolution/2.0, resolution/2.0, 0.01, 10.0, num_layers=1)
render_jit = jax.jit(renderer.render_attribute_many)
image,_ = render_jit(poses, vertices_jax, faces_jax, ranges_jax, vertex_colors_jax)
rr.log("jax", rr.Image(image[0]))

num_timestep = 1000
def body_fn(carry, _):
image, _ = render_jit(poses, vertices_jax, faces_jax, ranges_jax, vertex_colors_jax)
return carry + image.sum(), None
start = time.time()
for _ in range(num_timestep):
image = render_jit(poses, vertices_jax, faces_jax, ranges_jax, vertex_colors_jax)
sum_result, _ = lax.scan(body_fn, init=0, length=num_timestep)
sum_result.block_until_ready()
end = time.time()

print(f"Resolution: {resolution}x{resolution}, FPS: {num_timestep/(end-start)}")
Expand Down