Skip to content

Commit

Permalink
gpu utilization
Browse files Browse the repository at this point in the history
  • Loading branch information
theopsall committed Nov 16, 2021
1 parent 5872b00 commit 89cbac3
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
14 changes: 9 additions & 5 deletions deep_video_extraction/extractors/VisualExtractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torchvision import models, transforms
from utils.utils import clean_GPU, device
from tqdm import tqdm
from gc import collect as gc_collect


class VisualExtractor(nn.Module):
Expand Down Expand Up @@ -39,9 +40,12 @@ def transform(self, x):
def extract(self, testLoader: DataLoader) -> ndarray:
out = []
with torch.no_grad():
with tqdm(testLoader, unit="batch", position=0, leave=True) as tepoch:
for batch in testLoader:
batch = batch.to(self.device)
output = self.model(batch)
out.append([t for t in output])
for batch in testLoader:
batch = batch.to(self.device)
output = self.model(batch).to('cpu')
[out.append(t) for t in output]
del output
del batch
gc_collect()
torch.cuda.empty_cache()
return out
3 changes: 1 addition & 2 deletions deep_video_extraction/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ class VideoDataset(Dataset):
def __init__(self, video_path) -> None:
super().__init__()
self.video_path = video_path
self.frames, self.fps = analyze_video(video_path)
self.frames = analyze_video(video_path)
self.toTensor = transforms.ToTensor()
self.normalize = transforms.Normalize(mean=MEAN, std=STD)
print(self.frames[0].shape)

def __str__(self):
return f'Video DataLoader'
Expand Down
8 changes: 6 additions & 2 deletions deep_video_extraction/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def is_dir(directory: str) -> bool:
return os.path.isdir(directory)


def is_dir_empty(directory: str) -> bool:
return len(os.listdir(directory)) == 0


def is_file(filename: str) -> bool:
return os.path.isfile(filename)

Expand Down Expand Up @@ -94,7 +98,7 @@ def analyze_video(video: str) -> np.ndarray:
fps = int(cap.get(cv2.CAP_PROP_FPS)) + 1
except ValueError:
assert f"Cannot convert video {video} fps to integer"
print(f'Proccessing {video} with: {fps} fps')
# print(f'Proccessing {video} with: {fps} fps')
success = True
batches = []

Expand All @@ -103,7 +107,7 @@ def analyze_video(video: str) -> np.ndarray:
if success:
frame = cv2.resize(frame, (224, 224))
batches.append(np.array(frame))
return np.array(batches), fps
return np.array(batches)


def analyze_video_in_batches(video: str, batch_size: int = 32):
Expand Down

0 comments on commit 89cbac3

Please sign in to comment.