Skip to content

Commit

Permalink
feat: update Gradio script to allow choosing of model to load
Browse files Browse the repository at this point in the history
  • Loading branch information
MrLemur committed Jan 7, 2025
1 parent 45f44de commit 51d87cb
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import argparse
import imageio
import numpy as np
import torch
Expand Down Expand Up @@ -66,14 +67,19 @@ def images_to_video(images, output_path, fps=30):


###############################################################################
# Configuration.
# Arguments.
###############################################################################

seed_everything(0)
parser = argparse.ArgumentParser()
parser.add_argument('config', nargs='?', type=str, help='Path to config file.', default='configs/instant-mesh-large.yaml')
args = parser.parse_args()

config_path = 'configs/instant-mesh-large.yaml'
config = OmegaConf.load(config_path)
config_name = os.path.basename(config_path).replace('.yaml', '')
###############################################################################
# Configuration.
###############################################################################
seed_everything(0)
config = OmegaConf.load(args.config)
config_name = os.path.basename(args.config).replace('.yaml', '')
model_config = config.model_config
infer_config = config.infer_config

Expand All @@ -94,16 +100,22 @@ def images_to_video(images, output_path, fps=30):
)

# load custom white-background UNet
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model", cache_dir=model_cache_dir)
print('Loading custom white-background unet ...')
if os.path.exists(infer_config.unet_path):
unet_ckpt_path = infer_config.unet_path
else:
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model", cache_dir=model_cache_dir)
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
pipeline.unet.load_state_dict(state_dict, strict=True)

pipeline = pipeline.to(device0)

# load reconstruction model
print('Loading reconstruction model ...')
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model", cache_dir=model_cache_dir)
model = instantiate_from_config(model_config)
if os.path.exists(infer_config.model_path):
model_ckpt_path = infer_config.model_path
else:
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename=f"{config_name.replace('-', '_')}.ckpt", repo_type="model", cache_dir=model_cache_dir)
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
model.load_state_dict(state_dict, strict=True)
Expand Down

0 comments on commit 51d87cb

Please sign in to comment.