diff --git a/transforms.py b/transforms.py
index 071bf56..0761368 100644
--- a/transforms.py
+++ b/transforms.py
@@ -15,21 +15,11 @@ class VideoFilePathToTensor(object):
     It can be composed with torchvision.transforms.Compose().
 
     Args:
-        max_len (int): Maximum output time depth (L <= max_len). Default is None.
-            If it is set to None, it will output all frames.
-        fps (int): sample frame per seconds. It must lower than or equal the origin video fps.
-            Default is None.
-        padding_mode (str): Type of padding. Default to None. Only available when max_len is not None.
-            - None: won't padding, video length is variable.
-            - 'zero': padding the rest empty frames to zeros.
-            - 'last': padding the rest empty frames to the last frame.
+        fps (int): number of frames per second to sample (size of the L dimension).
     """
 
-    def __init__(self, max_len=None, fps=None, padding_mode=None):
-        self.max_len = max_len
+    def __init__(self, fps=None):
         self.fps = fps
-        assert padding_mode in (None, 'zero', 'last')
-        self.padding_mode = padding_mode
         self.channels = 3  # only available to read 3 channels video
 
     def __call__(self, path):
@@ -38,67 +28,17 @@ def __call__(self, path):
             path (str): path of video file.
 
         Returns:
-            torch.Tensor: Video Tensor (C x L x H x W)
+            torch.Tensor: Video Tensor (L x C x H x W)
         """
 
-        # open video file
-        cap = cv2.VideoCapture(path)
-        assert(cap.isOpened())
-
-        # calculate sample_factor to reset fps
-        sample_factor = 1
-        if self.fps:
-            old_fps = cap.get(cv2.CAP_PROP_FPS)  # fps of video
-            sample_factor = int(old_fps / self.fps)
-            assert(sample_factor >= 1)
-
-        # init empty output frames (C x L x H x W)
-        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-        time_len = None
-        if self.max_len:
-            # time length has upper bound
-            if self.padding_mode:
-                # padding all video to the same time length
-                time_len = self.max_len
-            else:
-                # video have variable time length
-                time_len = min(int(num_frames / sample_factor), self.max_len)
-        else:
-            # time length is unlimited
-            time_len = int(num_frames / sample_factor)
-
-        frames = torch.FloatTensor(self.channels, time_len, height, width)
-
-        for index in range(time_len):
-            frame_index = sample_factor * index
-
-            # read frame
-            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
-            ret, frame = cap.read()
-            if ret:
-                # successfully read frame
-                # BGR to RGB
-                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                frame = torch.from_numpy(frame)
-                # (H x W x C) to (C x H x W)
-                frame = frame.permute(2, 0, 1)
-                frames[:, index, :, :] = frame.float()
-            else:
-                # reach the end of the video
-                if self.padding_mode == 'zero':
-                    # fill the rest frames with 0.0
-                    frames[:, index:, :, :] = 0
-                elif self.padding_mode == 'last':
-                    # fill the rest frames with the last frame
-                    assert(index > 0)
-                    frames[:, index:, :, :] = frames[:, index-1, :, :].view(self.channels, 1, height, width)
-                break
+        frames, _, _ = torchvision.io.read_video(path)
+        frames = frames.transpose(3, 1).transpose(2, 3)  # (T x C x H x W)
+
+        factor = frames.shape[0] // self.fps
-        frames /= 255
-        cap.release()
+        indexes = list(range(0, frames.shape[0], factor + 1))
+        frames = frames[indexes]
+
         return frames
 
@@ -372,4 +312,4 @@ def __call__(self, video):
 
 
         return grayscaled_video
-        
\ No newline at end of file
+
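
A minimal standalone sketch of what the rewritten __call__ does, assuming torchvision >= 0.4.0 (where torchvision.io.read_video is available). The `fps is None` guard and the `sample_frames` name are illustrations added here, not part of the patch; without such a guard, the constructor's default fps=None would make `frames.shape[0] // self.fps` raise a TypeError.

import torchvision

def sample_frames(path, fps=None):
    # read_video returns (video, audio, info); video is a uint8
    # tensor of shape (T x H x W x C) in [0, 255] -- note the patch
    # drops the old `frames /= 255` normalization
    frames, _, _ = torchvision.io.read_video(path)
    frames = frames.permute(0, 3, 1, 2)  # (T x C x H x W)
    if fps is not None:
        # keep roughly `fps` evenly spaced frames, as in the patch
        factor = frames.shape[0] // fps
        frames = frames[list(range(0, frames.shape[0], factor + 1))]
    return frames

# usage (hypothetical file path):
# video = VideoFilePathToTensor(fps=16)('data/demo.mp4')  # (L x C x H x W)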