82 changes: 11 additions & 71 deletions transforms.py
@@ -15,21 +15,11 @@ class VideoFilePathToTensor(object):
     It can be composed with torchvision.transforms.Compose().
 
     Args:
-        max_len (int): Maximum output time depth (L <= max_len). Default is None.
-            If it is set to None, it will output all frames.
-        fps (int): sample frame per seconds. It must lower than or equal the origin video fps.
-            Default is None.
-        padding_mode (str): Type of padding. Default to None. Only available when max_len is not None.
-            - None: won't padding, video length is variable.
-            - 'zero': padding the rest empty frames to zeros.
-            - 'last': padding the rest empty frames to the last frame.
+        fps (int): The number of frames per second to sample (the size of the L dimension). Default is None.
     """
 
-    def __init__(self, max_len=None, fps=None, padding_mode=None):
-        self.max_len = max_len
+    def __init__(self, fps=None):
         self.fps = fps
-        assert padding_mode in (None, 'zero', 'last')
-        self.padding_mode = padding_mode
         self.channels = 3  # only 3-channel video is supported
@@ -38,67 +28,17 @@ def __call__(self, path):
             path (str): path of video file.
 
         Returns:
-            torch.Tensor: Video Tensor (C x L x H x W)
+            torch.Tensor: Video Tensor (L x C x H x W)
         """
 
-        # open video file
-        cap = cv2.VideoCapture(path)
-        assert(cap.isOpened())
-
-        # calculate sample_factor to reset fps
-        sample_factor = 1
-        if self.fps:
-            old_fps = cap.get(cv2.CAP_PROP_FPS)  # fps of video
-            sample_factor = int(old_fps / self.fps)
-            assert(sample_factor >= 1)
-
-        # init empty output frames (C x L x H x W)
-        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-        time_len = None
-        if self.max_len:
-            # time length has an upper bound
-            if self.padding_mode:
-                # pad all videos to the same time length
-                time_len = self.max_len
-            else:
-                # videos have variable time length
-                time_len = min(int(num_frames / sample_factor), self.max_len)
-        else:
-            # time length is unlimited
-            time_len = int(num_frames / sample_factor)
-
-        frames = torch.FloatTensor(self.channels, time_len, height, width)
-
-        for index in range(time_len):
-            frame_index = sample_factor * index
-
-            # read frame
-            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
-            ret, frame = cap.read()
-            if ret:
-                # successfully read frame; convert BGR to RGB
-                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                frame = torch.from_numpy(frame)
-                # (H x W x C) to (C x H x W)
-                frame = frame.permute(2, 0, 1)
-                frames[:, index, :, :] = frame.float()
-            else:
-                # reached the end of the video
-                if self.padding_mode == 'zero':
-                    # fill the remaining frames with 0.0
-                    frames[:, index:, :, :] = 0
-                elif self.padding_mode == 'last':
-                    # fill the remaining frames with the last frame
-                    assert(index > 0)
-                    frames[:, index:, :, :] = frames[:, index-1, :, :].view(self.channels, 1, height, width)
-                break
+        frames, _, _ = torchvision.io.read_video(path)
+        frames = frames.transpose(3, 1).transpose(2, 3)  # (T, H, W, C) -> (T, C, H, W)
+
+        factor = frames.shape[0] // self.fps
+
         frames /= 255
-        cap.release()
+        indexes = [i for i in range(0, frames.shape[0], factor + 1)]
+        frames = frames[indexes]
 
         return frames
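A note on the new implementation: as written, `frames /= 255` will raise at runtime, because `torchvision.io.read_video` returns `uint8` frames and PyTorch refuses in-place float division on an integer tensor, and the default `fps=None` crashes the `frames.shape[0] // self.fps` line. A minimal runnable sketch with those two issues guarded; the `pts_unit='sec'` argument and the slicing shortcut are my additions, not part of this PR:

```python
import torchvision


class VideoFilePathToTensor(object):
    """Load a video file into a float tensor of shape (L, C, H, W)."""

    def __init__(self, fps=None):
        self.fps = fps
        self.channels = 3  # only 3-channel video is supported

    def __call__(self, path):
        # read_video returns (video, audio, info); video is uint8, (T, H, W, C)
        frames, _, _ = torchvision.io.read_video(path, pts_unit='sec')
        frames = frames.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)

        # Convert to float *before* scaling; in-place division fails on uint8.
        frames = frames.float() / 255

        if self.fps is not None:
            # Same sampling rule as the PR: step size (T // fps) + 1. Note this
            # keeps roughly `fps` frames in total, not per second of video.
            factor = frames.shape[0] // self.fps
            frames = frames[::factor + 1]

        return frames
```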


@@ -372,4 +312,4 @@ def __call__(self, video):
         return grayscaled_video
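And a usage sketch, per the class docstring's note that the transform composes with `torchvision.transforms.Compose()`; the clip path and `fps` value here are hypothetical:

```python
import torchvision.transforms as transforms

from transforms import VideoFilePathToTensor  # this repo's module

loader = transforms.Compose([
    VideoFilePathToTensor(fps=10),
])

video = loader('data/example.mp4')  # hypothetical clip
print(video.shape)  # (L, C, H, W) after this PR; previously (C, L, H, W)
```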