Skip to content

Commit

Permalink
auto download model weights
Browse files Browse the repository at this point in the history
  • Loading branch information
bluestyle97 committed Apr 12, 2024
1 parent d23bedb commit e17ac11
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from omegaconf import OmegaConf
from einops import rearrange, repeat
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler

from src.utils.train_util import instantiate_from_config
Expand Down Expand Up @@ -106,15 +107,23 @@ def render_frames(model, planes, render_cameras, render_size=512, chunk_size=1,

# load custom white-background UNet
print('Loading custom white-background unet ...')
state_dict = torch.load(infer_config.unet_path, map_location='cpu')
if os.path.exists(infer_config.unet_path):
unet_ckpt_path = infer_config.unet_path
else:
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
pipeline.unet.load_state_dict(state_dict, strict=True)

pipeline = pipeline.to(device)

# load reconstruction model
print('Loading reconstruction model ...')
model = instantiate_from_config(model_config)
state_dict = torch.load(infer_config.model_path, map_location='cpu')['state_dict']
if os.path.exists(infer_config.model_path):
model_ckpt_path = infer_config.model_path
else:
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename=f"{config_name.replace('-', '_')}.ckpt", repo_type="model")
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
model.load_state_dict(state_dict, strict=True)

Expand Down

0 comments on commit e17ac11

Please sign in to comment.