Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/torchcodec/_core/FFMPEGCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ int getNumChannels(const SharedAVCodecContext& avCodecContext) {
#endif
}

int getNumChannels(const AVCodecParameters* codecpar) {
TORCH_CHECK(codecpar != nullptr, "codecpar is null")
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
return codecpar->ch_layout.nb_channels;
#else
return codecpar->channels;
#endif
}

void setDefaultChannelLayout(
UniqueAVCodecContext& avCodecContext,
int numChannels) {
Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/_core/FFMPEGCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);

int getNumChannels(const UniqueAVFrame& avFrame);
int getNumChannels(const SharedAVCodecContext& avCodecContext);
int getNumChannels(const AVCodecParameters* codecpar);

void setDefaultChannelLayout(
UniqueAVCodecContext& avCodecContext,
Expand Down
10 changes: 9 additions & 1 deletion src/torchcodec/_core/Metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ enum class SeekMode { exact, approximate, custom_frame_mappings };
struct StreamMetadata {
// Common (video and audio) fields derived from the AVStream.
int streamIndex;

// See this link for what various values are available:
// https://ffmpeg.org/doxygen/trunk/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48
AVMediaType mediaType;

std::optional<AVCodecID> codecId;
std::optional<std::string> codecName;
std::optional<double> durationSecondsFromHeader;
Expand All @@ -39,13 +41,15 @@ struct StreamMetadata {
// These presentation timestamps are in time base.
std::optional<int64_t> beginStreamPtsFromContent;
std::optional<int64_t> endStreamPtsFromContent;

// These presentation timestamps are in seconds.
std::optional<double> beginStreamPtsSecondsFromContent;
std::optional<double> endStreamPtsSecondsFromContent;

// This can be useful for index-based seeking.
std::optional<int64_t> numFramesFromContent;

// Video-only fields derived from the AVCodecContext.
// Video-only fields
std::optional<int> width;
std::optional<int> height;
std::optional<AVRational> sampleAspectRatio;
Expand All @@ -67,13 +71,17 @@ struct ContainerMetadata {
std::vector<StreamMetadata> allStreamMetadata;
int numAudioStreams = 0;
int numVideoStreams = 0;

// Note that this is the container-level duration, which is usually the max
// of all stream durations available in the container.
std::optional<double> durationSecondsFromHeader;

// Total BitRate level information at the container level in bit/s
std::optional<double> bitRate;

// If set, this is the index to the default audio stream.
std::optional<int> bestAudioStreamIndex;

// If set, this is the index to the default video stream.
std::optional<int> bestVideoStreamIndex;
};
Expand Down
22 changes: 9 additions & 13 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ void SingleStreamDecoder::initializeDecoder() {
", does not match AVStream's index, " +
std::to_string(avStream->index) + ".");
streamMetadata.streamIndex = i;
streamMetadata.mediaType = avStream->codecpar->codec_type;
streamMetadata.codecName = avcodec_get_name(avStream->codecpar->codec_id);
streamMetadata.mediaType = avStream->codecpar->codec_type;
streamMetadata.bitRate = avStream->codecpar->bit_rate;

int64_t frameCount = avStream->nb_frames;
Expand All @@ -133,10 +133,18 @@ void SingleStreamDecoder::initializeDecoder() {
if (fps > 0) {
streamMetadata.averageFpsFromHeader = fps;
}
streamMetadata.width = avStream->codecpar->width;
streamMetadata.height = avStream->codecpar->height;
streamMetadata.sampleAspectRatio =
avStream->codecpar->sample_aspect_ratio;
containerMetadata_.numVideoStreams++;
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
AVSampleFormat format =
static_cast<AVSampleFormat>(avStream->codecpar->format);
streamMetadata.sampleRate =
static_cast<int64_t>(avStream->codecpar->sample_rate);
streamMetadata.numChannels =
static_cast<int64_t>(getNumChannels(avStream->codecpar));

// If the AVSampleFormat is not recognized, we get back nullptr. We have
// to make sure we don't initialize a std::string with nullptr. There's
Expand Down Expand Up @@ -524,11 +532,6 @@ void SingleStreamDecoder::addVideoStream(
auto& streamInfo = streamInfos_[activeStreamIndex_];
streamInfo.videoStreamOptions = videoStreamOptions;

streamMetadata.width = streamInfo.codecContext->width;
streamMetadata.height = streamInfo.codecContext->height;
streamMetadata.sampleAspectRatio =
streamInfo.codecContext->sample_aspect_ratio;

if (seekMode_ == SeekMode::custom_frame_mappings) {
TORCH_CHECK(
customFrameMappings.has_value(),
Expand Down Expand Up @@ -574,13 +577,6 @@ void SingleStreamDecoder::addAudioStream(
auto& streamInfo = streamInfos_[activeStreamIndex_];
streamInfo.audioStreamOptions = audioStreamOptions;

auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];
streamMetadata.sampleRate =
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
streamMetadata.numChannels =
static_cast<int64_t>(getNumChannels(streamInfo.codecContext));

// FFmpeg docs say that the decoder will try to decode natively in this
// format, if it can. Docs don't say what the decoder does when it doesn't
// support that format, but it looks like it does nothing, so this probably
Expand Down
62 changes: 44 additions & 18 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,34 @@ SeekMode seekModeFromString(std::string_view seekMode) {
}
}

// Fills `map` with the metadata entries whose values depend on the
// seek-mode-aware fallback getters of StreamMetadata: duration, frame count,
// stream begin/end times, and average fps. Optional values are written only
// when present; beginStreamSeconds is a plain double and is always written.
void writeFallbackBasedMetadata(
    std::map<std::string, std::string>& map,
    const StreamMetadata& streamMetadata,
    SeekMode seekMode) {
  // Stringify an optional and store it under `key` only when it holds a value.
  auto writeIfPresent = [&map](const char* key, const auto& maybeValue) {
    if (maybeValue.has_value()) {
      map[key] = std::to_string(maybeValue.value());
    }
  };

  writeIfPresent("durationSeconds", streamMetadata.getDurationSeconds(seekMode));
  writeIfPresent("numFrames", streamMetadata.getNumFrames(seekMode));

  // getBeginStreamSeconds() returns a non-optional double: always emitted.
  map["beginStreamSeconds"] =
      std::to_string(streamMetadata.getBeginStreamSeconds(seekMode));

  writeIfPresent("endStreamSeconds", streamMetadata.getEndStreamSeconds(seekMode));
  writeIfPresent("averageFps", streamMetadata.getAverageFps(seekMode));
}

int checkedToPositiveInt(const std::string& str) {
int ret = 0;
try {
Expand Down Expand Up @@ -917,30 +945,28 @@ std::string get_stream_json_metadata(
// In approximate mode: content-based metadata does not exist for any stream.
// In custom_frame_mappings: content-based metadata exists only for the active
// stream.
//
// Our fallback logic assumes content-based metadata is available.
// It is available for decoding on the active stream, but would break
// when getting metadata from non-active streams.
if ((seekMode != SeekMode::custom_frame_mappings) ||
(seekMode == SeekMode::custom_frame_mappings &&
stream_index == activeStreamIndex)) {
if (streamMetadata.getDurationSeconds(seekMode).has_value()) {
map["durationSeconds"] =
std::to_string(streamMetadata.getDurationSeconds(seekMode).value());
}
if (streamMetadata.getNumFrames(seekMode).has_value()) {
map["numFrames"] =
std::to_string(streamMetadata.getNumFrames(seekMode).value());
}
map["beginStreamSeconds"] =
std::to_string(streamMetadata.getBeginStreamSeconds(seekMode));
if (streamMetadata.getEndStreamSeconds(seekMode).has_value()) {
map["endStreamSeconds"] =
std::to_string(streamMetadata.getEndStreamSeconds(seekMode).value());
}
if (streamMetadata.getAverageFps(seekMode).has_value()) {
map["averageFps"] =
std::to_string(streamMetadata.getAverageFps(seekMode).value());
}
writeFallbackBasedMetadata(map, streamMetadata, seekMode);
} else if (seekMode == SeekMode::custom_frame_mappings) {
// If this is not the active stream, then we don't have content-based
// metadata for custom frame mappings. In that case, we want the same
// behavior as we would get with approximate mode. Encoding this behavior in
// the fallback logic itself is tricky and not worth it for this corner
// case. So we hardcode in approximate mode.
//
// TODO: This hacky behavior is only necessary because the custom frame
// mapping is supplied in SingleStreamDecoder::addVideoStream() rather
// than in the constructor. And it's supplied to addVideoStream() and
// not the constructor because we need to know the stream index. If we
// can encode the relevant stream indices into custom frame mappings
// itself, then we can put it in the constructor.
writeFallbackBasedMetadata(map, streamMetadata, SeekMode::approximate);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand these workarounds are needed right now, but I have a really hard time cleanly reasoning about all the comments above.

We might eventually want to revisit the existence of addStream? Maybe we should just have a constructor, just like we do in Python. I think all the "add stream" logic is mainly a relic of when the decoder was potentially thought to be a multi-stream decoder, but it seems like it's hurting us now

We still want to enable existing use-cases of users getting metadata without having to scan, but I'm pretty sure we can support that by passing approximate mode (that's what we do from the public Python APIs).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug, you're right, that's actually the cleanest resolution here: there's no value any more in differentiating the constructor from adding a stream. Created #1064 for follow-up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To double check my understanding, the expected behaviour for custom frame mappings is

  • like exact mode for the active stream
  • like approximate mode for non active streams

The reason we have to make this distinction with these conditions is because we don't know the active stream index during construction (since we addStream separately from the constructor). Once we consolidate addStream into the constructor, we would be able to get rid of all the conditions and just call
writeFallbackBasedMetadata(map, streamMetadata, seekMode);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your understanding is correct @mollyxu !

}

return mapToJson(map);
Expand Down
24 changes: 16 additions & 8 deletions src/torchcodec/decoders/_audio_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,6 @@ def __init__(
torch._C._log_api_usage_once("torchcodec.decoders.AudioDecoder")
self._decoder = create_decoder(source=source, seek_mode="approximate")

core.add_audio_stream(
self._decoder,
stream_index=stream_index,
sample_rate=sample_rate,
num_channels=num_channels,
)

container_metadata = core.get_container_metadata(self._decoder)
self.stream_index = (
container_metadata.best_audio_stream_index
Expand All @@ -81,13 +74,28 @@ def __init__(
"The best audio stream is unknown and there is no specified stream. "
+ ERROR_REPORTING_INSTRUCTIONS
)
if self.stream_index >= len(container_metadata.streams):
raise ValueError(
f"The stream at index {stream_index} is not a valid stream."
)

self.metadata = container_metadata.streams[self.stream_index]
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
if not isinstance(self.metadata, core._metadata.AudioStreamMetadata):
raise ValueError(
f"The stream at index {stream_index} is not an audio stream. "
)

self._desired_sample_rate = (
sample_rate if sample_rate is not None else self.metadata.sample_rate
)

core.add_audio_stream(
self._decoder,
stream_index=stream_index,
sample_rate=sample_rate,
num_channels=num_channels,
)

def get_all_samples(self) -> AudioSamples:
"""Returns all the audio samples from the source.

Expand Down
29 changes: 16 additions & 13 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,16 @@ def __init__(

self._decoder = create_decoder(source=source, seek_mode=seek_mode)

(
self.metadata,
self.stream_index,
self._begin_stream_seconds,
self._end_stream_seconds,
self._num_frames,
) = _get_and_validate_stream_metadata(
decoder=self._decoder, stream_index=stream_index
)

allowed_dimension_orders = ("NCHW", "NHWC")
if dimension_order not in allowed_dimension_orders:
raise ValueError(
Expand All @@ -157,12 +167,11 @@ def __init__(
device = str(device)

device_variant = _get_cuda_backend()

transform_specs = _make_transform_specs(transforms)

core.add_video_stream(
self._decoder,
stream_index=stream_index,
stream_index=self.stream_index,
dimension_order=dimension_order,
num_threads=num_ffmpeg_threads,
device=device,
Expand All @@ -171,16 +180,6 @@ def __init__(
custom_frame_mappings=custom_frame_mappings_data,
)

(
self.metadata,
self.stream_index,
self._begin_stream_seconds,
self._end_stream_seconds,
self._num_frames,
) = _get_and_validate_stream_metadata(
decoder=self._decoder, stream_index=stream_index
)

def __len__(self) -> int:
return self._num_frames

Expand Down Expand Up @@ -413,8 +412,12 @@ def _get_and_validate_stream_metadata(
+ ERROR_REPORTING_INSTRUCTIONS
)

if stream_index >= len(container_metadata.streams):
raise ValueError(f"The stream index {stream_index} is not a valid stream.")

metadata = container_metadata.streams[stream_index]
assert isinstance(metadata, core._metadata.VideoStreamMetadata) # mypy
if not isinstance(metadata, core._metadata.VideoStreamMetadata):
raise ValueError(f"The stream at index {stream_index} is not a video stream. ")

if metadata.begin_stream_seconds is None:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions test/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ def test_create_fails(self, Decoder):
Decoder(123)

# stream index that does not exist
with pytest.raises(ValueError, match="No valid stream found"):
with pytest.raises(ValueError, match="40 is not a valid stream"):
Decoder(NASA_VIDEO.path, stream_index=40)

# stream index that does exist, but it's not audio or video
with pytest.raises(ValueError, match="No valid stream found"):
with pytest.raises(ValueError, match=r"not (a|an) (video|audio) stream"):
Decoder(NASA_VIDEO.path, stream_index=2)

# user mistakenly forgets to specify binary reading when creating a file
Expand Down
5 changes: 1 addition & 4 deletions test/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def test_get_metadata(metadata_getter):
)
if (seek_mode == "custom_frame_mappings") and get_ffmpeg_major_version() in (4, 5):
pytest.skip(reason="ffprobe isn't accurate on ffmpeg 4 and 5")
with_added_video_stream = seek_mode == "custom_frame_mappings"
metadata = metadata_getter(NASA_VIDEO.path)

with_scan = (
Expand Down Expand Up @@ -99,9 +98,7 @@ def test_get_metadata(metadata_getter):
assert best_video_stream_metadata.begin_stream_seconds_from_header == 0
assert best_video_stream_metadata.bit_rate == 128783
assert best_video_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001)
assert best_video_stream_metadata.pixel_aspect_ratio == (
Fraction(1, 1) if with_added_video_stream else None
)
assert best_video_stream_metadata.pixel_aspect_ratio == Fraction(1, 1)
assert best_video_stream_metadata.codec == "h264"
assert best_video_stream_metadata.num_frames_from_content == (
390 if with_scan else None
Expand Down
Loading