Skip to content

Commit

Permalink
latest save
Browse files Browse the repository at this point in the history
  • Loading branch information
theopsall committed Nov 18, 2021
1 parent 89cbac3 commit 8fc77d1
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 24 deletions.
5 changes: 3 additions & 2 deletions deep_video_extraction/extractors/VisualExtractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from utils.utils import clean_GPU, device
from tqdm import tqdm
from gc import collect as gc_collect
import numpy as np


class VisualExtractor(nn.Module):
Expand Down Expand Up @@ -43,9 +44,9 @@ def extract(self, testLoader: DataLoader) -> ndarray:
for batch in testLoader:
batch = batch.to(self.device)
output = self.model(batch).to('cpu')
[out.append(t) for t in output]
[out.append(np.array(t)) for t in output]
del output
del batch
gc_collect()
torch.cuda.empty_cache()
return out
return np.array(out)
63 changes: 42 additions & 21 deletions deep_video_extraction/featureExtraction.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,59 @@
import gc
import sys
from gc import collect as gc_collect
import os

import torch
from torch.utils.data import DataLoader, dataloader
from torchvision.transforms import Compose, Normalize, ToTensor
from cv2 import _OutputArray_DEPTH_MASK_FLT
import numpy as np
from torch.cuda import empty_cache
from torch.utils.data import DataLoader
from tqdm import tqdm

from config import MEAN, STD
from extractors.VisualExtractor import VisualExtractor
from utils import utils
from utils.dataset import VideoDataset
from time import sleep
import time
import sys

VIDEO_PATH = '/home/theo/Documents/deep_video_extraction/visual_features'
VIDEO_PATH = '/media/theo/Hard Disk 2/projects_git/deep_video_extraction/Video_smaller'
OUTPUT = 'output'

# @utils.timeit


@utils.timeit
def extractVisual(directory: str, model: str, layers: int):
def extractVisual(directory: str, model: str, layers: int, output: str = 'output', save: bool = True) -> None:
tree = utils.crawl_directory(directory)
destination = None
predictions = []
visual_extractor = VisualExtractor(model='vgg', layers=2)
for filename in tree:
print(f'Processing {filename}')
dataset = VideoDataset(filename)
dataloader = DataLoader(dataset, batch_size=32,
shuffle=False, num_workers=4)
predictions.append(visual_extractor.extract(dataloader))
return predictions
visual_extractor = VisualExtractor(model=model, layers=layers)
for filepath in tree:
print(f'Processing {filepath}')
dataset = VideoDataset(filepath)
dataloader = DataLoader(dataset, batch_size=16,
shuffle=False, num_workers=2)

filename = filepath.split(os.sep)[-1].split('.')[0]
classname = filepath.split(os.sep)[-2]
destination = os.path.join(output, classname)
if not utils.is_dir(destination):
utils.create_dir(destination)
predictions = visual_extractor.extract(dataloader)

if (save):
np.save(os.path.join(destination, f'{filename}.npy'), predictions)

del predictions
empty_cache()
gc_collect()


def main():
pass
OUTPUT = 'output'
if not utils.is_dir(OUTPUT):
utils.create_dir(OUTPUT)
else:
OUTPUT = OUTPUT + '_' + utils.get_timestamp()
print(f'Creating new with name: {OUTPUT}')
utils.create_dir(OUTPUT)

extractVisual(directory=VIDEO_PATH, model='vgg',
layers=2, output=OUTPUT, save=True)


if __name__ == "__main__":
Expand Down
18 changes: 17 additions & 1 deletion deep_video_extraction/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ def is_dir_empty(directory: str) -> bool:
return len(os.listdir(directory)) == 0


def create_dir(directory: str) -> bool:
try:
return os.mkdir(directory)
except FileExistsError:
print(f'{directory} already exists')
return False


def is_file(filename: str) -> bool:
return os.path.isfile(filename)

Expand All @@ -74,6 +82,10 @@ def crawl_directory(directory: str) -> list:
return tree


def clone_structure(src: str, dst: str) -> None:
pass


def read_video():
pass

Expand All @@ -83,6 +95,10 @@ def allowed_file(filename: str):
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def get_timestamp():
return f'{time()}'


def isolate_audio(path: str):
pass

Expand All @@ -105,7 +121,7 @@ def analyze_video(video: str) -> np.ndarray:
while success:
success, frame = cap.read()
if success:
frame = cv2.resize(frame, (224, 224))
frame = cv2.resize(frame, (124, 124))
batches.append(np.array(frame))
return np.array(batches)

Expand Down

0 comments on commit 8fc77d1

Please sign in to comment.