Skip to content

Commit cac99ae

Browse files
authored
Accept float frame_rate in VideoEncoder (#1061)
1 parent 408b373 commit cac99ae

File tree

6 files changed: 106 additions and 58 deletions.

6 files changed: 106 additions and 58 deletions.

src/torchcodec/_core/Encoder.cpp

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -662,7 +662,7 @@ VideoEncoder::~VideoEncoder() {
662662

663663
VideoEncoder::VideoEncoder(
664664
const torch::Tensor& frames,
665-
int frameRate,
665+
double frameRate,
666666
std::string_view fileName,
667667
const VideoStreamOptions& videoStreamOptions)
668668
: frames_(validateFrames(frames)), inFrameRate_(frameRate) {
@@ -694,7 +694,7 @@ VideoEncoder::VideoEncoder(
694694

695695
VideoEncoder::VideoEncoder(
696696
const torch::Tensor& frames,
697-
int frameRate,
697+
double frameRate,
698698
std::string_view formatName,
699699
std::unique_ptr<AVIOContextHolder> avioContextHolder,
700700
const VideoStreamOptions& videoStreamOptions)
@@ -787,9 +787,9 @@ void VideoEncoder::initializeEncoder(
787787
avCodecContext_->width = outWidth_;
788788
avCodecContext_->height = outHeight_;
789789
avCodecContext_->pix_fmt = outPixelFormat_;
790-
// TODO-VideoEncoder: Verify that frame_rate and time_base are correct
791-
avCodecContext_->time_base = {1, inFrameRate_};
792-
avCodecContext_->framerate = {inFrameRate_, 1};
790+
// TODO-VideoEncoder: Add and utilize output frame_rate option
791+
avCodecContext_->framerate = av_d2q(inFrameRate_, INT_MAX);
792+
avCodecContext_->time_base = av_inv_q(avCodecContext_->framerate);
793793

794794
// Set flag for containers that require extradata to be in the codec context
795795
if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) {
@@ -833,6 +833,10 @@ void VideoEncoder::initializeEncoder(
833833

834834
// Set the stream time base to encode correct frame timestamps
835835
avStream_->time_base = avCodecContext_->time_base;
836+
// Set the stream frame rate to store correct frame durations for some
837+
// containers (webm, mkv)
838+
avStream_->r_frame_rate = avCodecContext_->framerate;
839+
836840
status = avcodec_parameters_from_context(
837841
avStream_->codecpar, avCodecContext_.get());
838842
TORCH_CHECK(

src/torchcodec/_core/Encoder.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -143,13 +143,13 @@ class VideoEncoder {
143143

144144
VideoEncoder(
145145
const torch::Tensor& frames,
146-
int frameRate,
146+
double frameRate,
147147
std::string_view fileName,
148148
const VideoStreamOptions& videoStreamOptions);
149149

150150
VideoEncoder(
151151
const torch::Tensor& frames,
152-
int frameRate,
152+
double frameRate,
153153
std::string_view formatName,
154154
std::unique_ptr<AVIOContextHolder> avioContextHolder,
155155
const VideoStreamOptions& videoStreamOptions);
@@ -172,7 +172,7 @@ class VideoEncoder {
172172
UniqueSwsContext swsContext_;
173173

174174
const torch::Tensor frames_;
175-
int inFrameRate_;
175+
double inFrameRate_;
176176

177177
int inWidth_ = -1;
178178
int inHeight_ = -1;

src/torchcodec/_core/custom_ops.cpp

Lines changed: 9 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
40+
"encode_video_to_file(Tensor frames, float frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, float frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, float frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -639,7 +639,7 @@ void _encode_audio_to_file_like(
639639

640640
void encode_video_to_file(
641641
const at::Tensor& frames,
642-
int64_t frame_rate,
642+
double frame_rate,
643643
std::string_view file_name,
644644
std::optional<std::string_view> codec = std::nullopt,
645645
std::optional<std::string_view> pixel_format = std::nullopt,
@@ -657,17 +657,12 @@ void encode_video_to_file(
657657
unflattenExtraOptions(extra_options.value());
658658
}
659659

660-
VideoEncoder(
661-
frames,
662-
validateInt64ToInt(frame_rate, "frame_rate"),
663-
file_name,
664-
videoStreamOptions)
665-
.encode();
660+
VideoEncoder(frames, frame_rate, file_name, videoStreamOptions).encode();
666661
}
667662

668663
at::Tensor encode_video_to_tensor(
669664
const at::Tensor& frames,
670-
int64_t frame_rate,
665+
double frame_rate,
671666
std::string_view format,
672667
std::optional<std::string_view> codec = std::nullopt,
673668
std::optional<std::string_view> pixel_format = std::nullopt,
@@ -688,7 +683,7 @@ at::Tensor encode_video_to_tensor(
688683

689684
return VideoEncoder(
690685
frames,
691-
validateInt64ToInt(frame_rate, "frame_rate"),
686+
frame_rate,
692687
format,
693688
std::move(avioContextHolder),
694689
videoStreamOptions)
@@ -697,7 +692,7 @@ at::Tensor encode_video_to_tensor(
697692

698693
void _encode_video_to_file_like(
699694
const at::Tensor& frames,
700-
int64_t frame_rate,
695+
double frame_rate,
701696
std::string_view format,
702697
int64_t file_like_context,
703698
std::optional<std::string_view> codec = std::nullopt,
@@ -724,7 +719,7 @@ void _encode_video_to_file_like(
724719

725720
VideoEncoder encoder(
726721
frames,
727-
validateInt64ToInt(frame_rate, "frame_rate"),
722+
frame_rate,
728723
format,
729724
std::move(avioContextHolder),
730725
videoStreamOptions);

src/torchcodec/_core/ops.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,7 @@ def encode_audio_to_file_like(
210210

211211
def encode_video_to_file_like(
212212
frames: torch.Tensor,
213-
frame_rate: int,
213+
frame_rate: float,
214214
format: str,
215215
file_like: Union[io.RawIOBase, io.BufferedIOBase],
216216
codec: Optional[str] = None,
@@ -329,7 +329,7 @@ def _encode_audio_to_file_like_abstract(
329329
@register_fake("torchcodec_ns::encode_video_to_file")
330330
def encode_video_to_file_abstract(
331331
frames: torch.Tensor,
332-
frame_rate: int,
332+
frame_rate: float,
333333
filename: str,
334334
codec: Optional[str] = None,
335335
pixel_format: Optional[str] = None,
@@ -343,7 +343,7 @@ def encode_video_to_file_abstract(
343343
@register_fake("torchcodec_ns::encode_video_to_tensor")
344344
def encode_video_to_tensor_abstract(
345345
frames: torch.Tensor,
346-
frame_rate: int,
346+
frame_rate: float,
347347
format: str,
348348
codec: Optional[str] = None,
349349
pixel_format: Optional[str] = None,
@@ -357,7 +357,7 @@ def encode_video_to_tensor_abstract(
357357
@register_fake("torchcodec_ns::_encode_video_to_file_like")
358358
def _encode_video_to_file_like_abstract(
359359
frames: torch.Tensor,
360-
frame_rate: int,
360+
frame_rate: float,
361361
format: str,
362362
file_like_context: int,
363363
codec: Optional[str] = None,

src/torchcodec/encoders/_video_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,10 +15,10 @@ class VideoEncoder:
1515
tensor of shape ``(N, C, H, W)`` where N is the number of frames,
1616
C is 3 channels (RGB), H is height, and W is width.
1717
Values must be uint8 in the range ``[0, 255]``.
18-
frame_rate (int): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate.
18+
frame_rate (float): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate.
1919
"""
2020

21-
def __init__(self, frames: Tensor, *, frame_rate: int):
21+
def __init__(self, frames: Tensor, *, frame_rate: float):
2222
torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder")
2323
if not isinstance(frames, Tensor):
2424
raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.")

0 commit comments

Comments
 (0)