diff --git a/finetrainers/models/cogvideox/lora.py b/finetrainers/models/cogvideox/lora.py index 65d86ee9..d8adc2aa 100644 --- a/finetrainers/models/cogvideox/lora.py +++ b/finetrainers/models/cogvideox/lora.py @@ -3,7 +3,7 @@ import torch from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel from PIL import Image -from transformers import T5EncoderModel, T5Tokenizer +from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer from .utils import prepare_rotary_positional_embeddings @@ -15,7 +15,14 @@ def load_condition_models( cache_dir: Optional[str] = None, **kwargs, ): - tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir) + try: + tokenizer = T5Tokenizer.from_pretrained( + model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir + ) + except Exception: # noqa + tokenizer = AutoTokenizer.from_pretrained( + model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir + ) text_encoder = T5EncoderModel.from_pretrained( model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir ) diff --git a/finetrainers/models/ltx_video/lora.py b/finetrainers/models/ltx_video/lora.py index bdd6ffa3..49ea1db4 100644 --- a/finetrainers/models/ltx_video/lora.py +++ b/finetrainers/models/ltx_video/lora.py @@ -5,7 +5,7 @@ from accelerate.logging import get_logger from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel from PIL import Image -from transformers import T5EncoderModel, T5Tokenizer +from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer logger = get_logger("finetrainers") # pylint: disable=invalid-name @@ -18,7 +18,14 @@ def load_condition_models( cache_dir: Optional[str] = None, **kwargs, ) -> Dict[str, nn.Module]: - tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, 
cache_dir=cache_dir) + try: + tokenizer = T5Tokenizer.from_pretrained( + model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir + ) + except Exception: # noqa + tokenizer = AutoTokenizer.from_pretrained( + model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir + ) text_encoder = T5EncoderModel.from_pretrained( model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir ) diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py index 2a7f2261..89324044 100644 --- a/finetrainers/trainer.py +++ b/finetrainers/trainer.py @@ -54,7 +54,13 @@ ) from .utils.file_utils import string_to_filename from .utils.hub_utils import save_model_card -from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous +from .utils.memory_utils import ( + free_memory, + get_memory_statistics, + make_contiguous, + reset_memory_stats, + synchronize_device, +) from .utils.model_utils import resolve_vae_cls_from_ckpt_path from .utils.optimizer_utils import get_optimizer from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model @@ -259,7 +265,7 @@ def collate_fn(batch): memory_statistics = get_memory_statistics() logger.info(f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}") - torch.cuda.reset_peak_memory_stats(accelerator.device) + reset_memory_stats(accelerator.device) # Precompute latents with self.state.accelerator.main_process_first(): @@ -307,7 +313,7 @@ def collate_fn(batch): memory_statistics = get_memory_statistics() logger.info(f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}") - torch.cuda.reset_peak_memory_stats(accelerator.device) + reset_memory_stats(accelerator.device) # Update dataloader to use precomputed conditions and latents self.dataloader = torch.utils.data.DataLoader( @@ -997,7 +1003,7 @@ def validate(self, step: int, final_validation: bool = False) -> None: free_memory() 
memory_statistics = get_memory_statistics() logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") - torch.cuda.reset_peak_memory_stats(accelerator.device) + reset_memory_stats(accelerator.device) if not final_validation: self.transformer.train() @@ -1120,7 +1126,7 @@ def _delete_components(self) -> None: self.vae = None self.scheduler = None free_memory() - torch.cuda.synchronize(self.state.accelerator.device) + synchronize_device(self.state.accelerator.device) def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline: accelerator = self.state.accelerator diff --git a/finetrainers/utils/memory_utils.py b/finetrainers/utils/memory_utils.py index d7616b19..dcde3d89 100644 --- a/finetrainers/utils/memory_utils.py +++ b/finetrainers/utils/memory_utils.py @@ -9,29 +9,33 @@ def get_memory_statistics(precision: int = 3) -> Dict[str, Any]: - memory_allocated = None - memory_reserved = None - max_memory_allocated = None - max_memory_reserved = None + memory_stats = { + "memory_allocated": None, + "memory_reserved": None, + "max_memory_allocated": None, + "max_memory_reserved": None, + } if torch.cuda.is_available(): device = torch.cuda.current_device() - memory_allocated = torch.cuda.memory_allocated(device) - memory_reserved = torch.cuda.memory_reserved(device) - max_memory_allocated = torch.cuda.max_memory_allocated(device) - max_memory_reserved = torch.cuda.max_memory_reserved(device) + memory_stats.update( + { + "memory_allocated": torch.cuda.memory_allocated(device), + "memory_reserved": torch.cuda.memory_reserved(device), + "max_memory_allocated": torch.cuda.max_memory_allocated(device), + "max_memory_reserved": torch.cuda.max_memory_reserved(device), + } + ) elif torch.backends.mps.is_available(): - memory_allocated = torch.mps.current_allocated_memory() + memory_stats["memory_allocated"] = torch.mps.current_allocated_memory() else: logger.warning("No CUDA, MPS, or ROCm device found. 
Memory statistics are not available.") return { - "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision), - "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision), - "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision), - "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision), + key: (round(bytes_to_gigabytes(value), ndigits=precision) if value is not None else None) + for key, value in memory_stats.items() } @@ -49,6 +53,21 @@ def free_memory() -> None: # TODO(aryan): handle non-cuda devices +def reset_memory_stats(device: torch.device): + # TODO: handle for non-cuda devices + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats(device) + else: + logger.warning("No CUDA device found. Nothing to reset memory of.") + + +def synchronize_device(device: torch.device): + if torch.cuda.is_available(): + torch.cuda.synchronize(device) + else: + logger.warning("No CUDA device found. 
Nothing to synchronize.") + + def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: if isinstance(x, torch.Tensor): return x.contiguous() diff --git a/tests/trainers/__init__.py b/tests/trainers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/cogvideox/__init__.py b/tests/trainers/cogvideox/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/cogvideox/test_cogvideox.py b/tests/trainers/cogvideox/test_cogvideox.py new file mode 100644 index 00000000..e0790322 --- /dev/null +++ b/tests/trainers/cogvideox/test_cogvideox.py @@ -0,0 +1,44 @@ +import sys +import unittest +from pathlib import Path + + +current_file = Path(__file__).resolve() +root_dir = current_file.parents[3] +sys.path.append(str(root_dir)) + +from finetrainers import Args # noqa +from ..test_trainers_common import TrainerTestMixin, parse_resolution_bucket # noqa + + +class CogVideoXTester(unittest.TestCase, TrainerTestMixin): + MODEL_NAME = "cogvideox" + EXPECTED_PRECOMPUTATION_LATENT_KEYS = {"latents"} + EXPECTED_PRECOMPUTATION_CONDITION_KEYS = {"prompt_embeds"} + + def get_training_args(self): + args = Args() + args.model_name = self.MODEL_NAME + args.training_type = "lora" + args.pretrained_model_name_or_path = "finetrainers/dummy-cogvideox" + args.data_root = "" # will be set from the tester method. 
+ args.video_resolution_buckets = [parse_resolution_bucket("9x16x16")] + args.precompute_conditions = True + args.validation_prompts = [] + args.validation_heights = [] + args.validation_widths = [] + return args + + @property + def latent_output_shape(self): + return (8, 3, 2, 2) + + @property + def condition_output_shape(self): + return (226, 32) + + def populate_shapes(self): + for k in self.EXPECTED_PRECOMPUTATION_LATENT_KEYS: + self.EXPECTED_LATENT_SHAPES[k] = self.latent_output_shape + for k in self.EXPECTED_PRECOMPUTATION_CONDITION_KEYS: + self.EXPECTED_CONDITION_SHAPES[k] = self.condition_output_shape diff --git a/tests/trainers/hunyaun_video/__init__.py b/tests/trainers/hunyaun_video/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/hunyaun_video/test_hunyaun_video.py b/tests/trainers/hunyaun_video/test_hunyaun_video.py new file mode 100644 index 00000000..935223ba --- /dev/null +++ b/tests/trainers/hunyaun_video/test_hunyaun_video.py @@ -0,0 +1,51 @@ +import sys +import unittest +from pathlib import Path + + +current_file = Path(__file__).resolve() +root_dir = current_file.parents[3] +sys.path.append(str(root_dir)) + +from finetrainers import Args # noqa +from ..test_trainers_common import TrainerTestMixin, parse_resolution_bucket # noqa + + +class HunyuanVideoTester(unittest.TestCase, TrainerTestMixin): + MODEL_NAME = "hunyuan_video" + EXPECTED_PRECOMPUTATION_LATENT_KEYS = {"latents"} + EXPECTED_PRECOMPUTATION_CONDITION_KEYS = { + "guidance", + "pooled_prompt_embeds", + "prompt_attention_mask", + "prompt_embeds", + } + + def get_training_args(self): + args = Args() + args.model_name = self.MODEL_NAME + args.training_type = "lora" + args.pretrained_model_name_or_path = "finetrainers/dummy-hunyaunvideo" + args.data_root = "" # will be set from the tester method. 
+ args.video_resolution_buckets = [parse_resolution_bucket("9x16x16")] + args.precompute_conditions = True + args.validation_prompts = [] + args.validation_heights = [] + args.validation_widths = [] + return args + + @property + def latent_output_shape(self): + # only tensor object shapes + return (8, 3, 2, 2) + + @property + def condition_output_shape(self): + # only tensor object shapes + return (), (8,), (256,), (256, 16) + + def populate_shapes(self): + for i, k in enumerate(sorted(self.EXPECTED_PRECOMPUTATION_LATENT_KEYS)): + self.EXPECTED_LATENT_SHAPES[k] = self.latent_output_shape + for i, k in enumerate(sorted(self.EXPECTED_PRECOMPUTATION_CONDITION_KEYS)): + self.EXPECTED_CONDITION_SHAPES[k] = self.condition_output_shape[i] diff --git a/tests/trainers/ltx_video/__init__.py b/tests/trainers/ltx_video/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trainers/ltx_video/test_ltx_video.py b/tests/trainers/ltx_video/test_ltx_video.py new file mode 100644 index 00000000..e0de0aa8 --- /dev/null +++ b/tests/trainers/ltx_video/test_ltx_video.py @@ -0,0 +1,50 @@ +import sys +import unittest +from pathlib import Path + + +current_file = Path(__file__).resolve() +root_dir = current_file.parents[3] +sys.path.append(str(root_dir)) + +from finetrainers import Args # noqa +from ..test_trainers_common import TrainerTestMixin, parse_resolution_bucket # noqa + + +class LTXVideoTester(unittest.TestCase, TrainerTestMixin): + MODEL_NAME = "ltx_video" + EXPECTED_PRECOMPUTATION_LATENT_KEYS = {"height", "latents", "latents_mean", "latents_std", "num_frames", "width"} + EXPECTED_PRECOMPUTATION_CONDITION_KEYS = {"prompt_attention_mask", "prompt_embeds"} + + def get_training_args(self): + args = Args() + args.model_name = self.MODEL_NAME + args.training_type = "lora" + args.pretrained_model_name_or_path = "finetrainers/dummy-ltxvideo" + args.data_root = "" # will be set from the tester method. 
+ args.video_resolution_buckets = [parse_resolution_bucket("9x16x16")] + args.precompute_conditions = True + args.validation_prompts = [] + args.validation_heights = [] + args.validation_widths = [] + return args + + @property + def latent_output_shape(self): + # only tensor object shapes + return (16, 3, 4, 4), (), () + + @property + def condition_output_shape(self): + # only tensor object shapes + return (128,), (128, 32) + + def populate_shapes(self): + i = 0 + for k in sorted(self.EXPECTED_PRECOMPUTATION_LATENT_KEYS): + if k in ["height", "num_frames", "width"]: + continue + self.EXPECTED_LATENT_SHAPES[k] = self.latent_output_shape[i] + i += 1 + for i, k in enumerate(sorted(self.EXPECTED_PRECOMPUTATION_CONDITION_KEYS)): + self.EXPECTED_CONDITION_SHAPES[k] = self.condition_output_shape[i] diff --git a/tests/trainers/test_trainers_common.py b/tests/trainers/test_trainers_common.py new file mode 100644 index 00000000..9be7294a --- /dev/null +++ b/tests/trainers/test_trainers_common.py @@ -0,0 +1,176 @@ +import sys +import tempfile +from pathlib import Path +from typing import Tuple + +import torch +from huggingface_hub import snapshot_download + + +current_file = Path(__file__).resolve() +root_dir = current_file.parents[2] +sys.path.append(str(root_dir)) + +from finetrainers import Trainer # noqa +from finetrainers.constants import ( # noqa + PRECOMPUTED_CONDITIONS_DIR_NAME, + PRECOMPUTED_DIR_NAME, + PRECOMPUTED_LATENTS_DIR_NAME, +) +from finetrainers.utils.file_utils import string_to_filename # noqa + + +def parse_resolution_bucket(resolution_bucket: str) -> Tuple[int, ...]: + """Parse a resolution like '512x512' into a tuple of ints (512, 512).""" + return tuple(map(int, resolution_bucket.split("x"))) + + +class TrainerTestMixin: + MODEL_NAME = None + EXPECTED_PRECOMPUTATION_LATENT_KEYS = set() + EXPECTED_LATENT_SHAPES = {} + EXPECTED_PRECOMPUTATION_CONDITION_KEYS = set() + EXPECTED_CONDITION_SHAPES = {} + + def get_training_args(self): + raise 
NotImplementedError + + @property + def latent_output_shape(self): + raise NotImplementedError + + @property + def condition_output_shape(self): + raise NotImplementedError + + def populate_shapes(self): + raise NotImplementedError + + def download_dataset_txt_format(self, cache_dir): + return snapshot_download(repo_id="finetrainers/dummy-disney-dataset", repo_type="dataset", cache_dir=cache_dir) + + def get_precomputation_dir(self, training_args): + """Return the path of the precomputation directory based on the training args.""" + cleaned_model_id = string_to_filename(training_args.pretrained_model_name_or_path) + return Path(training_args.data_root) / f"{training_args.model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}" + + def tearDown(self): + super().tearDown() + self.EXPECTED_LATENT_SHAPES.clear() + self.EXPECTED_CONDITION_SHAPES.clear() + + def _verify_precomputed_files(self, video_paths, all_conditions, all_latents): + """Check that the correct number of precomputed files exist and have the right keys.""" + assert len(video_paths) == len(all_conditions), "Mismatch in conditions file count" + assert len(video_paths) == len(all_latents), "Mismatch in latents file count" + + for latent, condition in zip(all_latents, all_conditions): + latent_keys = sorted(set(torch.load(latent, weights_only=True).keys())) + condition_keys = sorted(set(torch.load(condition, weights_only=True).keys())) + assert latent_keys == sorted( + self.EXPECTED_PRECOMPUTATION_LATENT_KEYS + ), f"Unexpected latent keys: {latent_keys}" + assert condition_keys == sorted( + self.EXPECTED_PRECOMPUTATION_CONDITION_KEYS + ), f"Unexpected condition keys: {condition_keys}" + + def _verify_shapes(self, latent_files, condition_files): + """Check that the shapes of latents and conditions match expected shapes.""" + self.populate_shapes() + for l_path, c_path in zip(latent_files, condition_files): + latent = torch.load(l_path, weights_only=True, map_location="cpu") + condition = torch.load(c_path, 
weights_only=True, map_location="cpu") + + for key in self.EXPECTED_PRECOMPUTATION_LATENT_KEYS: + if not torch.is_tensor(latent[key]): + continue + expected = self.EXPECTED_LATENT_SHAPES[key] + original = tuple(latent[key].shape[1:]) + assert ( + original == expected + ), f"Latent shape mismatch for key: {key}. expected={expected}, got={original}" + + for key in self.EXPECTED_PRECOMPUTATION_CONDITION_KEYS: + if not torch.is_tensor(condition[key]): + continue + expected = self.EXPECTED_CONDITION_SHAPES[key] + original = tuple(condition[key].shape[1:]) + assert ( + original == expected + ), f"Condition shape mismatch for key: {key}. expected={expected}, got={original}" + + def _setup_trainer(self, tmpdir): + """ + Helper method to reduce duplication across tests. + Creates and returns a trainer, along with updated training args. + """ + training_args = self.get_training_args() + training_args.data_root = Path(self.download_dataset_txt_format(cache_dir=tmpdir)) + training_args.video_column = "videos.txt" + training_args.caption_column = "prompt.txt" + training_args.output_dir = tmpdir + + trainer = Trainer(training_args) + # Trainer may update the training_args internally, so refresh the reference + training_args = trainer.args + + trainer.prepare_dataset() + trainer.prepare_models() + return trainer, training_args + + def test_precomputation_txt_format_creates_files(self): + with tempfile.TemporaryDirectory() as tmpdir: + trainer, training_args = self._setup_trainer(tmpdir) + + # Load video paths (only needed in this test) + with open(training_args.data_root / training_args.video_column, "r", encoding="utf-8") as file: + video_paths = [training_args.data_root / line.strip() for line in file if line.strip()] + + trainer.prepare_precomputations() + + precomputation_dir = self.get_precomputation_dir(training_args) + conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME + latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME + + assert 
precomputation_dir.exists(), f"Precomputed dir not found: {precomputation_dir}" + assert conditions_dir.exists(), f"Conditions dir not found: {conditions_dir}" + assert latents_dir.exists(), f"Latents dir not found: {latents_dir}" + + all_conditions = list(conditions_dir.glob("*.pt")) + all_latents = list(latents_dir.glob("*.pt")) + + self._verify_precomputed_files(video_paths, all_conditions, all_latents) + + def test_precomputation_txt_format_matches_shapes(self): + with tempfile.TemporaryDirectory() as tmpdir: + trainer, training_args = self._setup_trainer(tmpdir) + + with self.assertLogs(level="INFO") as captured: + trainer.prepare_precomputations() + assert any( + "Precomputed data not found. Running precomputation." in msg for msg in captured.output + ), "Expected info log about missing precomputed data." + + precomputation_dir = self.get_precomputation_dir(training_args) + conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME + latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME + + latent_files = list(latents_dir.glob("*.pt")) + condition_files = list(conditions_dir.glob("*.pt")) + + self._verify_shapes(latent_files, condition_files) + + def test_precomputation_txt_format_no_redo(self): + with tempfile.TemporaryDirectory() as tmpdir: + trainer, _ = self._setup_trainer(tmpdir) + + # should create new precomputations + trainer.prepare_precomputations() + + # should detect existing precomputations and not redo + with self.assertLogs(level="INFO") as captured: + trainer.prepare_precomputations() + + assert any( + "Precomputed conditions and latents found. Loading precomputed data" in msg for msg in captured.output + ), "Expected info log about found precomputations."