Skip to content

Commit 408b373

Browse files
mollyxuMolly Xu
andauthored
Fallback to container duration in approximate mode (#989)
Co-authored-by: Molly Xu <[email protected]>
1 parent 04b02b9 commit 408b373

File tree

5 files changed

+34
-35
lines changed

5 files changed

+34
-35
lines changed

src/torchcodec/_core/Metadata.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ std::optional<double> StreamMetadata::getDurationSeconds(
2929
return static_cast<double>(numFramesFromHeader.value()) /
3030
averageFpsFromHeader.value();
3131
}
32+
if (durationSecondsFromContainer.has_value()) {
33+
return durationSecondsFromContainer.value();
34+
}
3235
return std::nullopt;
3336
default:
3437
TORCH_CHECK(false, "Unknown SeekMode");
@@ -80,13 +83,13 @@ std::optional<int64_t> StreamMetadata::getNumFrames(SeekMode seekMode) const {
8083
numFramesFromContent.has_value(), "Missing numFramesFromContent");
8184
return numFramesFromContent.value();
8285
case SeekMode::approximate: {
86+
auto durationSeconds = getDurationSeconds(seekMode);
8387
if (numFramesFromHeader.has_value()) {
8488
return numFramesFromHeader.value();
8589
}
86-
if (averageFpsFromHeader.has_value() &&
87-
durationSecondsFromHeader.has_value()) {
90+
if (averageFpsFromHeader.has_value() && durationSeconds.has_value()) {
8891
return static_cast<int64_t>(
89-
averageFpsFromHeader.value() * durationSecondsFromHeader.value());
92+
averageFpsFromHeader.value() * durationSeconds.value());
9093
}
9194
return std::nullopt;
9295
}

src/torchcodec/_core/Metadata.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ struct StreamMetadata {
3737
std::optional<double> averageFpsFromHeader;
3838
std::optional<double> bitRate;
3939

40+
// Used as fallback in approximate mode when stream duration is unavailable.
41+
std::optional<double> durationSecondsFromContainer;
42+
4043
// More accurate duration, obtained by scanning the file.
4144
// These presentation timestamps are in time base.
4245
std::optional<int64_t> beginStreamPtsFromContent;

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,26 @@ void SingleStreamDecoder::initializeDecoder() {
100100
"Failed to find stream info: ",
101101
getFFMPEGErrorStringFromErrorCode(status));
102102

103+
if (formatContext_->duration > 0) {
104+
AVRational defaultTimeBase{1, AV_TIME_BASE};
105+
containerMetadata_.durationSecondsFromHeader =
106+
ptsToSeconds(formatContext_->duration, defaultTimeBase);
107+
}
108+
109+
if (formatContext_->bit_rate > 0) {
110+
containerMetadata_.bitRate = formatContext_->bit_rate;
111+
}
112+
113+
int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO);
114+
if (bestVideoStream >= 0) {
115+
containerMetadata_.bestVideoStreamIndex = bestVideoStream;
116+
}
117+
118+
int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO);
119+
if (bestAudioStream >= 0) {
120+
containerMetadata_.bestAudioStreamIndex = bestAudioStream;
121+
}
122+
103123
for (unsigned int i = 0; i < formatContext_->nb_streams; i++) {
104124
AVStream* avStream = formatContext_->streams[i];
105125
StreamMetadata streamMetadata;
@@ -157,27 +177,10 @@ void SingleStreamDecoder::initializeDecoder() {
157177
containerMetadata_.numAudioStreams++;
158178
}
159179

160-
containerMetadata_.allStreamMetadata.push_back(streamMetadata);
161-
}
162-
163-
if (formatContext_->duration > 0) {
164-
AVRational defaultTimeBase{1, AV_TIME_BASE};
165-
containerMetadata_.durationSecondsFromHeader =
166-
ptsToSeconds(formatContext_->duration, defaultTimeBase);
167-
}
168-
169-
if (formatContext_->bit_rate > 0) {
170-
containerMetadata_.bitRate = formatContext_->bit_rate;
171-
}
172-
173-
int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO);
174-
if (bestVideoStream >= 0) {
175-
containerMetadata_.bestVideoStreamIndex = bestVideoStream;
176-
}
180+
streamMetadata.durationSecondsFromContainer =
181+
containerMetadata_.durationSecondsFromHeader;
177182

178-
int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO);
179-
if (bestAudioStream >= 0) {
180-
containerMetadata_.bestAudioStreamIndex = bestAudioStream;
183+
containerMetadata_.allStreamMetadata.push_back(streamMetadata);
181184
}
182185

183186
if (seekMode_ == SeekMode::exact) {

src/torchcodec/_core/_metadata.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ class StreamMetadata:
4444
from the actual frames if a :term:`scan` was performed. Otherwise we
4545
fall back to ``duration_seconds_from_header``. If that value is also None,
4646
we instead calculate the duration from ``num_frames_from_header`` and
47-
``average_fps_from_header``.
47+
``average_fps_from_header``. If all of those are unavailable, we fall back
48+
to the container-level ``duration_seconds_from_header``.
4849
"""
4950
begin_stream_seconds: Optional[float]
5051
"""Beginning of the stream, in seconds (float). Conceptually, this

test/test_decoders.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
SINE_MONO_S32,
4545
SINE_MONO_S32_44100,
4646
SINE_MONO_S32_8000,
47-
supports_approximate_mode,
4847
TEST_SRC_2_720P,
4948
TEST_SRC_2_720P_H265,
5049
TEST_SRC_2_720P_MPEG4,
@@ -1465,8 +1464,6 @@ def test_get_frames_at_tensor_indices(self):
14651464
def test_beta_cuda_interface_get_frame_at(
14661465
self, asset, contiguous_indices, seek_mode
14671466
):
1468-
if seek_mode == "approximate" and not supports_approximate_mode(asset):
1469-
pytest.skip("asset doesn't work with approximate mode")
14701467

14711468
if in_fbcode() and asset is AV1_VIDEO:
14721469
pytest.skip("AV1 CUDA not supported internally")
@@ -1513,8 +1510,6 @@ def test_beta_cuda_interface_get_frame_at(
15131510
def test_beta_cuda_interface_get_frames_at(
15141511
self, asset, contiguous_indices, seek_mode
15151512
):
1516-
if seek_mode == "approximate" and not supports_approximate_mode(asset):
1517-
pytest.skip("asset doesn't work with approximate mode")
15181513
if in_fbcode() and asset is AV1_VIDEO:
15191514
pytest.skip("AV1 CUDA not supported internally")
15201515

@@ -1558,8 +1553,6 @@ def test_beta_cuda_interface_get_frames_at(
15581553
)
15591554
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
15601555
def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode):
1561-
if seek_mode == "approximate" and not supports_approximate_mode(asset):
1562-
pytest.skip("asset doesn't work with approximate mode")
15631556
if in_fbcode() and asset is AV1_VIDEO:
15641557
pytest.skip("AV1 CUDA not supported internally")
15651558

@@ -1600,8 +1593,6 @@ def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode):
16001593
)
16011594
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
16021595
def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode):
1603-
if seek_mode == "approximate" and not supports_approximate_mode(asset):
1604-
pytest.skip("asset doesn't work with approximate mode")
16051596
if in_fbcode() and asset is AV1_VIDEO:
16061597
pytest.skip("AV1 CUDA not supported internally")
16071598

@@ -1643,8 +1634,6 @@ def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode):
16431634
)
16441635
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
16451636
def test_beta_cuda_interface_backwards(self, asset, seek_mode):
1646-
if seek_mode == "approximate" and not supports_approximate_mode(asset):
1647-
pytest.skip("asset doesn't work with approximate mode")
16481637
if in_fbcode() and asset is AV1_VIDEO:
16491638
pytest.skip("AV1 CUDA not supported internally")
16501639

0 commit comments

Comments
 (0)