diff --git a/b3d/nvdiffrast_original/jax/jax_rasterize_gl.cpp b/b3d/nvdiffrast_original/jax/jax_rasterize_gl.cpp index 261e6c98..1c276ae6 100644 --- a/b3d/nvdiffrast_original/jax/jax_rasterize_gl.cpp +++ b/b3d/nvdiffrast_original/jax/jax_rasterize_gl.cpp @@ -18,7 +18,6 @@ RasterizeGLStateWrapper::RasterizeGLStateWrapper(bool enableDB, bool automatic_, cudaDeviceIdx = cudaDeviceIdx_; memset(pState, 0, sizeof(RasterizeGLState)); pState->enableDB = enableDB ? 1 : 0; - std::cout << "initialization" << std::endl; rasterizeInitGLContext(NVDR_CTX_PARAMS, *pState, cudaDeviceIdx_); releaseGLContext(); } @@ -63,10 +62,8 @@ void jax_rasterize_fwd_original_gl(cudaStream_t stream, resolution.resize(2); int ranges[2*d.num_objects]; - cudaStreamSynchronize(stream); NVDR_CHECK_CUDA_ERROR(cudaMemcpy(&resolution[0], _resolution, 2 * sizeof(int), cudaMemcpyDeviceToHost)); NVDR_CHECK_CUDA_ERROR(cudaMemcpy(&ranges[0], _ranges, 2 * d.num_objects * sizeof(int), cudaMemcpyDeviceToHost)); - cudaStreamSynchronize(stream); \ // std::cout << "num_images: " << d.num_images << std::endl; // std::cout << "num_objects: " << d.num_objects << std::endl; @@ -89,7 +86,7 @@ void jax_rasterize_fwd_original_gl(cudaStream_t stream, // const at::cuda::OptionalCUDAGuard device_guard(at::device_of(pos)); RasterizeGLState& s = *stateWrapper.pState; - + int posCount = 4 * d.num_images * d.num_vertices; int triCount = 3 * d.num_triangles; diff --git a/test/test_renderer_fps.py b/test/test_renderer_fps.py index fec50b3f..a9d245cf 100644 --- a/test/test_renderer_fps.py +++ b/test/test_renderer_fps.py @@ -1,17 +1,21 @@ -import argparse import os -import pathlib -import sys +import time + +import jax +import jax.numpy as jnp import numpy as np +import rerun as rr import torch -import imageio +import trimesh +from jax import lax + import b3d -from tqdm import tqdm -import jax.numpy as jnp import b3d.nvdiffrast_original.torch as dr -import time +from b3d.renderer_original import Renderer as RendererOriginal + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -import trimesh + + mesh_path = os.path.join( b3d.get_root_path(), "assets/shared_data_bucket/025_mug/textured.obj" @@ -54,7 +58,7 @@ glctx = dr.RasterizeGLContext() #if use_opengl else dr.RasterizeCudaContext() -import rerun as rr + rr.init("demo") rr.connect("127.0.0.1:8812") @@ -85,32 +89,27 @@ ranges_jax = jnp.array([[0, len(faces_jax)]]) poses = b3d.Pose.from_translation(jnp.array([0.0, 0.0, 5.1]))[None, None, ...] -import jax -print("JAX NVdiffrast Original") -from b3d.renderer_original import Renderer as RendererOriginal - +print("JAX NVdiffrast Original (+lax.scan, -stream synchronize)") for resolution in resolutions: renderer = RendererOriginal(resolution, resolution, 100.0, 100.0, resolution/2.0, resolution/2.0, 0.01, 10.0, num_layers=1) render_jit = jax.jit(renderer.rasterize) num_timestep = 1000 resolution_array = jnp.array([resolution, resolution]).astype(jnp.int32) - sum = 0 - start = time.time() - for _ in range(num_timestep): - output, = render_jit( - vertices_jax_4[None,...], faces_jax, ranges_jax, resolution_array + def body_fn(carry, _): + (output,) = render_jit( + vertices_jax_4[None, ...], faces_jax, ranges_jax, resolution_array ) - sum += output.sum() + return carry + output.sum(), None + start = time.time() + sum_result, _ = lax.scan(body_fn, init=0, length=num_timestep) + sum_result.block_until_ready() end = time.time() - print(sum) + print(sum_result) print(f"Resolution: {resolution}x{resolution}, FPS: {num_timestep/(end-start)}") - -import jax - -print("JAX") +print("JAX (with lax.scan)") for resolution in resolutions: renderer = b3d.Renderer(resolution, resolution, 100.0, 100.0, resolution/2.0, resolution/2.0, 0.01, 10.0, num_layers=1) render_jit = jax.jit(renderer.render_attribute_many) @@ -118,9 +117,12 @@ rr.log("jax", rr.Image(image[0])) num_timestep = 1000 + def body_fn(carry, _): + image, _ = render_jit(poses, vertices_jax, faces_jax, ranges_jax, vertex_colors_jax) + return carry + image.sum(), None start = time.time() - for _ in range(num_timestep): - image = render_jit(poses, vertices_jax, faces_jax, ranges_jax, vertex_colors_jax) + sum_result, _ = lax.scan(body_fn, init=0, length=num_timestep) + sum_result.block_until_ready() end = time.time() print(f"Resolution: {resolution}x{resolution}, FPS: {num_timestep/(end-start)}")