Skip to content

Support TensorRT for video decoding #60

Open
kingjulio8238 wants to merge 9 commits into
Overworldai:mainfrom
kingjulio8238:main
Open

Support TensorRT for video decoding #60
kingjulio8238 wants to merge 9 commits into
Overworldai:mainfrom
kingjulio8238:main

Conversation

@kingjulio8238
Copy link
Copy Markdown

  • 2x faster video decoding using TensorRT
  • Fallback on PyTorch compile: mode=max-autotune, dynamic=False, fullgraph=True

@kingjulio8238
Copy link
Copy Markdown
Author

Results

num_frames=1 ; bs=1
tensorrt average latency: 51.09ms
PyTorch (compile) average latency: 49.20ms
PyTorch (baseline) average latency: 91.58ms
tensorrt improvement over PyTorch (eager): 1.79x
tensorrt improvement over PyTorch (compiled): 0.96x

num_frames=8 ; bs=1
tensorrt average latency: 411.56ms
PyTorch (compile) average latency: 373.24ms
PyTorch (baseline) average latency: 790.41ms
tensorrt improvement over PyTorch (eager): 1.92x
tensorrt improvement over PyTorch (compiled): 0.91x

num_frames=4 ; bs=2
tensorrt average latency: 404.63ms
PyTorch (compile) average latency: 372.81ms
PyTorch (baseline) average latency: 790.85ms
tensorrt improvement over PyTorch (eager): 1.95x
tensorrt improvement over PyTorch (compiled): 0.92x

num_frames=8 ; bs=2
tensorrt average latency: 808.05ms
PyTorch (compile) average latency: 817.54ms
PyTorch (baseline) average latency: 1669.52ms
tensorrt improvement over PyTorch (eager): 2.07x
tensorrt improvement over PyTorch (compiled): 1.01x

@kingjulio8238
Copy link
Copy Markdown
Author

kingjulio8238 commented Aug 30, 2025

Instructions

  • Initialize owl-vaes as git submodule
  • Log into HF (if VAE is on HF)
  • Download vae to owl-vaes/checkpoints/cod_yt_v2/step_515000.pt
  • pip install omegaconf diffusers einops rotary-embedding-torch>=0.8.8 onnx>=1.17.0 onnxruntime-gpu>=1.19.0 tensorrt==10.1.0 --extra-index-url https://pypi.nvidia.com
  • Run benchmark: python benchmark_decoder.py

Benchmark script:

#!/usr/bin/env python3
"""
Simple benchmark script for comparing TensorRT, PyTorch baseline, and PyTorch compiled performance
for video decoding using the owl_vae_bridge module with fixed batch_size=1 and num_frames=8.
"""

import time
import statistics
import os
import sys
import logging
import torch

# Add owl-vaes to path
sys.path.append("./owl-vaes")

from owl_wms.utils.owl_vae_bridge import get_cod_yt_v2_decoder, make_batched_decode_fn

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def setup_environment():
    """Setup environment variables for TensorRT and ONNX export."""
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark")
    os.environ["TENSORRT_ENABLED"] = "true"
    os.environ["TENSORRT_PRECISION"] = "fp16"
    os.environ["TENSORRT_CACHE_DIR"] = "./tensorrt_cache"
    os.environ["REPLACE_LANDSCAPE"] = "1"
    os.environ["REPLACE_SANA"] = "1"
    os.environ["SIMPLIFY_ATTENTION"] = "1"
    logger.info(f"CUDA Device: {torch.cuda.get_device_name()}")

def generate_test_data(batch_size: int, num_frames: int) -> torch.Tensor:
    """Generate test latent data with shape [batch_size, num_frames, 128, 8, 8]."""
    return torch.randn(batch_size, num_frames, 128, 8, 8, dtype=torch.bfloat16, device='cuda')

def benchmark_pytorch_eager(decoder, test_data: torch.Tensor) -> float:
    """Benchmark PyTorch eager performance, returning mean time in ms."""
    logger.info("Benchmarking PyTorch (eager)...")
    # Reshape input: [batch, num_frames, channels, height, width] -> [batch*num_frames, channels, height, width]
    b, n, c, h, w = test_data.shape
    reshaped_data = test_data.view(b * n, c, h, w)
    
    # Warmup
    with torch.no_grad():
        for _ in range(2):
            _ = decoder(reshaped_data)
    torch.cuda.synchronize()
    # Benchmark
    times = []
    for _ in range(5):
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        with torch.no_grad():
            _ = decoder(reshaped_data)
        torch.cuda.synchronize()
        times.append((time.perf_counter() - start_time) * 1000)
    return statistics.mean(times)

def benchmark_pytorch_compiled(decoder, test_data: torch.Tensor) -> float:
    """Benchmark PyTorch compiled performance, returning mean time in ms."""
    logger.info("Benchmarking PyTorch (compiled)...")
    try:
        compiled_decoder = torch.compile(decoder, mode="max-autotune", dynamic=False, fullgraph=True)
        logger.info("Successfully compiled PyTorch decoder")
    except Exception as e:
        logger.warning(f"Failed to compile PyTorch decoder: {e}")
        return float('inf')  # Indicate failure
    
    # Reshape input: [batch, num_frames, channels, height, width] -> [batch*num_frames, channels, height, width]
    b, n, c, h, w = test_data.shape
    reshaped_data = test_data.view(b * n, c, h, w)
    
    # Warmup
    with torch.no_grad():
        for _ in range(2):
            _ = compiled_decoder(reshaped_data)
    torch.cuda.synchronize()
    # Benchmark
    times = []
    for _ in range(5):
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        with torch.no_grad():
            _ = compiled_decoder(reshaped_data)
        torch.cuda.synchronize()
        times.append((time.perf_counter() - start_time) * 1000)
    return statistics.mean(times)

def benchmark_tensorrt(decoder, test_data: torch.Tensor, batch_size: int) -> float:
    """Benchmark TensorRT performance, returning mean time in ms."""
    logger.info("Benchmarking TensorRT...")
    try:
        tensorrt_decode_fn = make_batched_decode_fn(decoder, batch_size=batch_size, use_tensorrt=True)
        logger.info("Successfully created TensorRT decode function")
    except Exception as e:
        logger.warning(f"Failed to create TensorRT decode function: {e}")
        return float('inf')  # Indicate failure
    # Warmup
    with torch.no_grad():
        for _ in range(2):
            _ = tensorrt_decode_fn(test_data)
    torch.cuda.synchronize()
    # Benchmark
    times = []
    for _ in range(5):
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        with torch.no_grad():
            _ = tensorrt_decode_fn(test_data)
        torch.cuda.synchronize()
        times.append((time.perf_counter() - start_time) * 1000)
    return statistics.mean(times)

def main():
    # Fixed parameters
    batch_size = 2
    num_frames = 8
    
    setup_environment()
    logger.info(f"Loading COD-YT v2 decoder for batch_size={batch_size}, num_frames={num_frames}...")
    try:
        decoder = get_cod_yt_v2_decoder().cuda().bfloat16().eval()
        logger.info("Successfully loaded decoder")
    except Exception as e:
        logger.error(f"Failed to load decoder: {e}")
        return
    
    test_data = generate_test_data(batch_size, num_frames)
    logger.info(f"Input shape: {test_data.shape}")
    
    # Run benchmarks
    eager_time = benchmark_pytorch_eager(decoder, test_data)
    compiled_time = benchmark_pytorch_compiled(decoder, test_data)
    tensorrt_time = benchmark_tensorrt(decoder, test_data, batch_size)
    
    # Calculate speedups
    tensorrt_vs_eager = eager_time / tensorrt_time if tensorrt_time != float('inf') else float('inf')
    tensorrt_vs_compiled = compiled_time / tensorrt_time if tensorrt_time != float('inf') else float('inf')
    
    # Print results
    print(f"\ntensorrt average latency: {tensorrt_time:.2f}ms")
    print(f"PyTorch (compile) average latency: {compiled_time:.2f}ms")
    print(f"PyTorch (baseline) average latency: {eager_time:.2f}ms")
    print(f"tensorrt improvement over PyTorch (eager): {tensorrt_vs_eager:.2f}x")
    print(f"tensorrt improvement over PyTorch (compiled): {tensorrt_vs_compiled:.2f}x")

if __name__ == "__main__":
    main()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant