Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
f8fe69e
rename submodules
dmarx May 26, 2022
b1fa850
refactor commits
dmarx May 26, 2022
63cfbec
refactor commits
dmarx May 26, 2022
dfef33a
refactor commits
dmarx May 26, 2022
ce819e9
Merge pull request #173 from pytti-tools/image_model_refactor
dmarx May 26, 2022
5cabe42
isolated cutout factory
dmarx May 27, 2022
719b326
refactor
dmarx May 27, 2022
60ed9e4
refactor cutout augs
dmarx May 27, 2022
1d82ad9
Merge pull request #177 from pytti-tools/cutouts
dmarx May 27, 2022
75f193f
query config for default cuda device
dmarx May 17, 2022
b799242
added pipfile (pipenv)
dmarx May 17, 2022
19d60b5
updated RGBImage to use default cuda device if none specified
dmarx May 17, 2022
f039157
updated PixelImage cuda device and workhorse image model inits
dmarx May 17, 2022
45d62bd
updated VQGANImage cuda device
dmarx May 17, 2022
89f9a79
refactored devices for losses
dmarx May 17, 2022
1146b4a
refactored a bunch more device stuff
dmarx May 17, 2022
afbdcb6
refactored device stuff in transforms.py
dmarx May 17, 2022
5175e76
added a bunch of logging stuff to update_func
dmarx May 17, 2022
48f5f97
trying to fix device in 3D transform
dmarx May 17, 2022
a7f7206
trying to fix device issues with depth
dmarx May 23, 2022
a89c5ae
added device tests
dmarx Jun 10, 2022
fef0b66
removed artifact from merge conflict fix
dmarx Jun 10, 2022
3ac7ddf
Merge pull request #168 from pytti-tools/device_selection
dmarx Jun 10, 2022
542bc00
updated flow loss to respect device selection
dmarx Jun 10, 2022
227db4e
fixed local path in test config
dmarx Jun 10, 2022
1873aa7
fixed local path in test config
dmarx Jun 10, 2022
b2a68b7
added test img to assets
dmarx Jun 10, 2022
054a297
added missing import
dmarx Jun 10, 2022
0dad5b6
fixed local paths
dmarx Jun 10, 2022
b8adebe
added missing import...
dmarx Jun 10, 2022
ad67708
fetch item from generator
dmarx Jun 10, 2022
5ef708c
changed get_flow from static method for self.device
dmarx Jun 10, 2022
8d11aae
fixed path
dmarx Jun 10, 2022
eeb72f8
fixed get_flow invocation post change to class method
dmarx Jun 10, 2022
210393e
fine, get_flow stays a static method for now
dmarx Jun 10, 2022
c8ae534
we're gonna get this get_flow thing right eventually
dmarx Jun 10, 2022
9882b75
fixed issues with flow device
dmarx Jun 10, 2022
8c60b02
fixed setting of missing device param
dmarx Jun 10, 2022
eaf3c51
use open_dict context to add key to param obj
dmarx Jun 10, 2022
dd6edd0
tossing spaghetti at the wall now
dmarx Jun 10, 2022
eb8b156
fixed issue with device placement stemming from DataParallel
dmarx Jun 10, 2022
38562f3
Merge pull request #196 from pytti-tools/device_selection
dmarx Jun 10, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions Pipfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"

[[source]]
url = "https://download.pytorch.org/whl/cu113/"
verify_ssl = false
name = "pytorch"

[packages]
transformers = "==4.15.0"
gdown = "===4.2.0"
ftfy = "==6.0.3"
regex = "*"
tqdm = "==4.62.3"
omegaconf = "==2.1.1"
pytorch-lightning = "==1.5.7"
kornia = "==0.6.2"
einops = "==0.3.2"
imageio-ffmpeg = "==0.4.5"
exrex = "*"
matplotlib-label-lines = "==0.4.3"
pandas = "==1.3.4"
seaborn = "==0.11.2"
scikit-learn = "*"
loguru = "*"
hydra-core = "*"
jupyter = "*"
imageio = "==2.4.1"
PyGLM = "==2.5.7"
adjustText = "*"
Pillow = "*"
torch = "*"
torchvision = "*"
torchaudio = "*"
requests = "*"
pyttitools-adabins = {path = "./vendor/AdaBins"}
pyttitools-gma = {path = "./vendor/GMA"}
clip = {path = "./vendor/CLIP"}
pyttitools-taming-transformers = {path = "./vendor/taming-transformers"}
tensorflow = "*"
protobuf = "==3.9.2"
pyttitools-core = {path = "."}
mmc = {git = "https://github.com/dmarx/Multi-Modal-Comparators"}

[dev-packages]
pytest = "*"
pre-commit = "*"
click = "==8.0.4"
black = "*"

[requires]
python_version = "3.9"
11 changes: 0 additions & 11 deletions src/pytti/Image/__init__.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/pytti/ImageGuide.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
vram_usage_mode,
)
from pytti.AudioParse import SpectralAudioParser
from pytti.Image.differentiable_image import DifferentiableImage
from pytti.Image.PixelImage import PixelImage
from pytti.image_models.differentiable_image import DifferentiableImage
from pytti.image_models.pixel import PixelImage
from pytti.Notebook import tqdm, make_hbox

# from pytti.rotoscoper import update_rotoscopers
Expand Down
12 changes: 8 additions & 4 deletions src/pytti/LossAug/BaseLossClass.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch
from torch import nn
from pytti import DEVICE, replace_grad, parametric_eval
from pytti import replace_grad, parametric_eval


class Loss(nn.Module):
def __init__(self, weight, stop, name):
def __init__(self, weight, stop, name, device=None):
super().__init__()
# self.register_buffer('weight', torch.as_tensor(weight))
# self.register_buffer('stop', torch.as_tensor(stop))
Expand All @@ -13,6 +13,9 @@ def __init__(self, weight, stop, name):
self.input_axes = ("n", "s", "y", "x")
self.name = name
self.enabled = True
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device

def get_loss(self, input, img):
raise NotImplementedError
Expand All @@ -29,10 +32,11 @@ def set_stop(stop):
def __str__(self):
return self.name

def forward(self, input, img, device=DEVICE):
def forward(self, input, img, device=None):
if not self.enabled or self.weight in [0, "0"]:
return 0, 0

if device is None:
device = self.device
weight = torch.as_tensor(parametric_eval(self.weight), device=device)
stop = torch.as_tensor(parametric_eval(self.stop), device=device)
loss_raw = self.get_loss(input, img)
Expand Down
16 changes: 10 additions & 6 deletions src/pytti/LossAug/DepthLossClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
infer_helper = None


def init_AdaBins():
def init_AdaBins(device=None):
global infer_helper
if infer_helper is None:
with vram_usage_mode("AdaBins"):
logger.debug("Loading AdaBins...")
infer_helper = InferenceHelper(dataset="nyu")
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
infer_helper = InferenceHelper(dataset="nyu", device=device)
logger.debug("AdaBins loaded.")


Expand Down Expand Up @@ -55,13 +57,15 @@ def get_loss(self, input, img):

@classmethod
@vram_usage_mode("Depth Loss")
def make_comp(cls, pil_image, device=DEVICE):
depth, _ = DepthLoss.get_depth(pil_image)
def make_comp(cls, pil_image, device=None):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
depth, _ = DepthLoss.get_depth(pil_image, device=device)
return torch.from_numpy(depth).to(device)

@staticmethod
def get_depth(pil_image):
init_AdaBins()
def get_depth(pil_image, device=None):
init_AdaBins(device=device)
width, height = pil_image.size

# if the area of an image is above this, the depth model fails
Expand Down
3 changes: 2 additions & 1 deletion src/pytti/LossAug/LossOrchestratorClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from loguru import logger
from PIL import Image

from pytti.Image import PixelImage
from pytti.image_models import PixelImage

# from pytti.LossAug import build_loss
from pytti.LossAug import TVLoss, HSVLoss, OpticalFlowLoss, TargetFlowLoss
Expand Down Expand Up @@ -125,6 +125,7 @@ def configure_optical_flows(img, params, loss_augs):
TargetFlowLoss.TargetImage(
f"optical flow stabilization:{params.flow_stabilization_weight}",
img.image_shape,
device="cuda",
)
]
for optical_flow in optical_flows:
Expand Down
28 changes: 18 additions & 10 deletions src/pytti/LossAug/MSELossClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# from pytti.Notebook import Rotoscoper
from pytti.rotoscoper import Rotoscoper
from pytti import DEVICE, fetch, parse, vram_usage_mode
from pytti import fetch, parse, vram_usage_mode
import torch


Expand All @@ -19,22 +19,22 @@ def __init__(
stop=-math.inf,
name="direct target loss",
image_shape=None,
device=DEVICE,
device=None,
):
super().__init__(weight, stop, name)
super().__init__(weight, stop, name, device)
self.register_buffer("comp", comp)
if image_shape is None:
height, width = comp.shape[-2:]
image_shape = (width, height)
self.image_shape = image_shape
self.register_buffer("mask", torch.ones(1, 1, 1, 1, device=device))
self.register_buffer("mask", torch.ones(1, 1, 1, 1, device=self.device))
self.use_mask = False

@classmethod
@vram_usage_mode("Loss Augs")
@torch.no_grad()
def TargetImage(
cls, prompt_string, image_shape, pil_image=None, is_path=False, device=DEVICE
cls, prompt_string, image_shape, pil_image=None, is_path=False, device=None
):
# Why is this prompt parsing stuff here? Deprecate in favor of centralized
# parsing functions (if feasible)
Expand All @@ -44,6 +44,8 @@ def TargetImage(
weight, mask = parse(weight, r"_", ["1", ""])
text = text.strip()
mask = mask.strip()
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if pil_image is None and text != "" and is_path:
pil_image = Image.open(fetch(text)).convert("RGB")
im = pil_image.resize(image_shape, Image.LANCZOS)
Expand All @@ -55,12 +57,14 @@ def TargetImage(
comp = cls.make_comp(im)
if image_shape is None:
image_shape = pil_image.size
out = cls(comp, weight, stop, text + " (direct)", image_shape)
out = cls(comp, weight, stop, text + " (direct)", image_shape, device=device)
out.set_mask(mask)
return out

@torch.no_grad()
def set_mask(self, mask, inverted=False, device=DEVICE):
def set_mask(self, mask, inverted=False, device=None):
if device is None:
device = self.device
if isinstance(mask, str) and mask != "":
if mask[0] == "-":
mask = mask[1:]
Expand All @@ -86,16 +90,20 @@ def convert_input(cls, input, img):
return input

@classmethod
def make_comp(cls, pil_image, device=DEVICE):
def make_comp(cls, pil_image, device=None):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
out = (
TF.to_tensor(pil_image)
.unsqueeze(0)
.to(device, memory_format=torch.channels_last)
)
return cls.convert_input(out, None)

def set_comp(self, pil_image, device=DEVICE):
self.comp.set_(type(self).make_comp(pil_image))
def set_comp(self, pil_image, device=None):
if device is None:
device = self.device
self.comp.set_(type(self).make_comp(pil_image, device=device))

def get_loss(self, input, img):
input = type(self).convert_input(input, img)
Expand Down
Loading