@@ -15,6 +15,7 @@
 from mmda.utils.dataset_utils import (
     load_cosmos,
     load_flickr,
+    load_handwriting,
     load_imagenet,
     load_kitti,
     load_leafy_spurge,
@@ -25,22 +26,24 @@
     load_tiil,
 )
 from mmda.utils.embed_data import (
+    chronos_ts,
     clap_audio,
     clap_text,
     clip_imgs,
     clip_text,
     cosplace_img,
     dinov2,
+    fair_clip_imgs,
+    fair_clip_text,
     gtr_text,
 )
-from mmda.utils.imagebind_utils import ImageBindInference
 from mmda.utils.video_audio_utils import (
     get_video_emb,
     prepare_audio_for_imagebind,
     process_audio,
 )
 
-BATCH_SIZE = 256
+BATCH_SIZE = 758
 
 
 @hydra.main(version_base=None, config_path="../config", config_name="main")
@@ -92,6 +95,8 @@ def main(cfg: DictConfig) -> None:  # noqa: PLR0915, C901, PLR0912
             pickle.dump(clap_audio_features, f)
 
     elif dataset == "MSRVTT":
+        from mmda.utils.imagebind_utils import ImageBindInference
+
         _, captions, video_info_sen_order, video_dict = load_msrvtt(cfg_dataset)
         id_order, img_paths, audio_start_secs, audio_num_secs = get_video_emb(
             cfg_dataset, video_dict
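Moving the ImageBindInference import from module level into the MSRVTT branch defers loading ImageBind, and its heavy model weights, until that branch actually runs, so embedding any other dataset no longer requires the ImageBind dependencies. A minimal sketch of the same deferred-import pattern; the guard function, try/except, and error message are illustrative assumptions rather than code from this commit:

def get_imagebind_inference():
    """Load ImageBind lazily so other dataset branches skip the dependency."""
    try:
        # hypothetical guard; the commit itself imports unconditionally
        from mmda.utils.imagebind_utils import ImageBindInference
    except ImportError as err:
        msg = "The MSRVTT branch requires the ImageBind extras to be installed."
        raise ImportError(msg) from err
    return ImageBindInference()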
@@ -372,6 +377,24 @@ def main(cfg: DictConfig) -> None:  # noqa: PLR0915, C901, PLR0912
         text_descriptions = ["An image of " + label + "." for label in orig_labels]
 
         # get text embeddings
+        model = "openai"
+
+        img_emb = fair_clip_imgs(img_path, BATCH_SIZE, model_name=("ViT-L-14", model))
+        with Path(
+            cfg_dataset.paths.save_path, f"ImageNet_img_emb_clip{model}.pkl"
+        ).open("wb") as f:
+            pickle.dump(img_emb, f)
+        print("FairCLIP image embeddings saved")
+
+        text_emb = fair_clip_text(
+            text_descriptions, BATCH_SIZE, model_name=("ViT-L-14", model)
+        )
+        with Path(
+            cfg_dataset.paths.save_path, f"ImageNet_text_emb_clip{model}.pkl"
+        ).open("wb") as f:
+            pickle.dump(text_emb, f)
+        print("FairCLIP text embeddings saved")
+
         text_emb = clip_text(text_descriptions, BATCH_SIZE)
         with Path(cfg_dataset.paths.save_path, "ImageNet_text_emb_clip.pkl").open(
             "wb"
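The fair_clip_imgs and fair_clip_text helpers are added to the imports but their implementations are not part of this diff. Given the model_name=("ViT-L-14", "openai") tuple, they presumably wrap an OpenCLIP backbone the way clip_text and clip_imgs wrap vanilla CLIP. A minimal sketch of such a text encoder, assuming the open_clip package; the function name, batching, and normalization are guesses, not the repo's implementation:

import numpy as np
import open_clip
import torch


def fair_clip_text_sketch(
    texts: list[str],
    batch_size: int,
    model_name: tuple[str, str] = ("ViT-L-14", "openai"),
) -> np.ndarray:
    """Hypothetical stand-in for mmda.utils.embed_data.fair_clip_text."""
    arch, pretrained = model_name
    model, _, _ = open_clip.create_model_and_transforms(arch, pretrained=pretrained)
    tokenizer = open_clip.get_tokenizer(arch)
    model.eval()
    chunks = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            tokens = tokenizer(texts[i : i + batch_size])
            emb = model.encode_text(tokens)
            # unit-normalize, as is conventional for CLIP-style retrieval
            emb = emb / emb.norm(dim=-1, keepdim=True)
            chunks.append(emb.cpu().numpy())
    return np.concatenate(chunks, axis=0)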
@@ -536,6 +559,47 @@ def main(cfg: DictConfig) -> None:  # noqa: PLR0915, C901, PLR0912
             pickle.dump(img_emb, f)
         print("CLIP embeddings saved")
 
+    elif dataset == "handwriting":
+        data, labels, num2alphabet, alphabets_hand = load_handwriting(cfg_dataset)
+        # save the raw handwriting time-series data
+        with Path(cfg_dataset.paths.save_path, "Handwriting_data.pkl").open("wb") as f:
+            pickle.dump(data, f)
+        print("Handwriting data saved")
+
+        embeddings = clip_imgs(alphabets_hand, 256)
+        print("CLIP image embeddings shape:", embeddings.shape)
+        with Path(cfg_dataset.paths.save_path, "Handwriting_emb_clip.pkl").open(
+            "wb"
+        ) as f:
+            pickle.dump(embeddings, f)
+        print("CLIP embeddings saved")
+
+        sentences = [f"Alphabet {num2alphabet[label]}." for label in labels]
+        print(sentences[15:21])
+        embeddings = gtr_text(sentences)
+        # sanity check: identical sentences must map to identical embeddings
+        assert np.allclose(
+            embeddings[15], embeddings[20], atol=1e-3, rtol=1e-4
+        ), "identical sentences should produce identical GTR embeddings"
+        with Path(cfg_dataset.paths.save_path, "Handwriting_emb_gtr.pkl").open(
+            "wb"
+        ) as f:
+            pickle.dump(embeddings, f)
+        print("GTR shape:", embeddings.shape)
+        print("GTR embeddings saved")
+
+        embeddings = chronos_ts(data)
+        # sanity check: every series should get a distinct embedding
+        assert embeddings.shape[0] == len(
+            np.unique(embeddings, axis=0)
+        ), f"Embeddings have repeated entries: {embeddings.shape[0]} != {len(np.unique(embeddings, axis=0))}"
+        print("Chronos shape:", embeddings.shape)
+        with Path(cfg_dataset.paths.save_path, "Handwriting_emb_chronos.pkl").open(
+            "wb"
+        ) as f:
+            pickle.dump(embeddings, f)
+        print("Chronos embeddings saved")
+
     # TODO: add more datasets
     else:
         msg = f"Dataset {dataset} not supported."
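chronos_ts is also new in this commit and its body is not shown. Chronos, Amazon's pretrained time-series model family, exposes an embed() API in the chronos-forecasting package, so the helper plausibly looks like the sketch below; the checkpoint choice and the mean-pooling over token embeddings are assumptions, not the repo's code:

import numpy as np
import torch
from chronos import ChronosPipeline


def chronos_ts_sketch(
    data: np.ndarray, model_name: str = "amazon/chronos-t5-small"
) -> np.ndarray:
    """Hypothetical stand-in for mmda.utils.embed_data.chronos_ts.

    Assumes data has shape (num_series, sequence_length).
    """
    pipeline = ChronosPipeline.from_pretrained(
        model_name,
        device_map="cpu",
        torch_dtype=torch.float32,
    )
    context = torch.tensor(data, dtype=torch.float32)
    # embed() returns per-time-step token embeddings plus tokenizer state
    token_embeddings, _state = pipeline.embed(context)
    # mean-pool over the token dimension to get one vector per series
    return token_embeddings.mean(dim=1).numpy()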