JeffreyXiang commited on
Commit
b7b00e2
·
1 Parent(s): cd41f5f

Add multiimage and gaussian

Browse files
app.py CHANGED
@@ -9,7 +9,6 @@ from typing import *
9
  import torch
10
  import numpy as np
11
  import imageio
12
- import uuid
13
  from easydict import EasyDict as edict
14
  from PIL import Image
15
  from trellis.pipelines import TrellisImageTo3DPipeline
@@ -24,17 +23,15 @@ os.makedirs(TMP_DIR, exist_ok=True)
24
 
25
  def start_session(req: gr.Request):
26
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
27
- print(f'Creating user directory: {user_dir}')
28
  os.makedirs(user_dir, exist_ok=True)
29
 
30
 
31
  def end_session(req: gr.Request):
32
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
33
- print(f'Removing user directory: {user_dir}')
34
  shutil.rmtree(user_dir)
35
 
36
 
37
- def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
38
  """
39
  Preprocess the input image.
40
 
@@ -42,14 +39,28 @@ def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
42
  image (Image.Image): The input image.
43
 
44
  Returns:
45
- str: uuid of the trial.
46
  Image.Image: The preprocessed image.
47
  """
48
  processed_image = pipeline.preprocess_image(image)
49
  return processed_image
50
 
51
 
52
- def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  return {
54
  'gaussian': {
55
  **gs.init_params,
@@ -63,7 +74,6 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
63
  'vertices': mesh.vertices.cpu().numpy(),
64
  'faces': mesh.faces.cpu().numpy(),
65
  },
66
- 'trial_id': trial_id,
67
  }
68
 
69
 
@@ -87,7 +97,7 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
87
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
88
  )
89
 
90
- return gs, mesh, state['trial_id']
91
 
92
 
93
  def get_seed(randomize_seed: bool, seed: int) -> int:
@@ -100,11 +110,14 @@ def get_seed(randomize_seed: bool, seed: int) -> int:
100
  @spaces.GPU
101
  def image_to_3d(
102
  image: Image.Image,
 
 
103
  seed: int,
104
  ss_guidance_strength: float,
105
  ss_sampling_steps: int,
106
  slat_guidance_strength: float,
107
  slat_sampling_steps: int,
 
108
  req: gr.Request,
109
  ) -> Tuple[dict, str]:
110
  """
@@ -112,43 +125,62 @@ def image_to_3d(
112
 
113
  Args:
114
  image (Image.Image): The input image.
 
 
115
  seed (int): The random seed.
116
  ss_guidance_strength (float): The guidance strength for sparse structure generation.
117
  ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
118
  slat_guidance_strength (float): The guidance strength for structured latent generation.
119
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
 
120
 
121
  Returns:
122
  dict: The information of the generated 3D model.
123
  str: The path to the video of the 3D model.
124
  """
125
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
126
- outputs = pipeline.run(
127
- image,
128
- seed=seed,
129
- formats=["gaussian", "mesh"],
130
- preprocess_image=False,
131
- sparse_structure_sampler_params={
132
- "steps": ss_sampling_steps,
133
- "cfg_strength": ss_guidance_strength,
134
- },
135
- slat_sampler_params={
136
- "steps": slat_sampling_steps,
137
- "cfg_strength": slat_guidance_strength,
138
- },
139
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
141
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
142
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
143
- trial_id = uuid.uuid4()
144
- video_path = os.path.join(user_dir, f"{trial_id}.mp4")
145
  imageio.mimsave(video_path, video, fps=15)
146
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
147
  torch.cuda.empty_cache()
148
  return state, video_path
149
 
150
 
151
- @spaces.GPU
152
  def extract_glb(
153
  state: dict,
154
  mesh_simplify: float,
@@ -167,24 +199,83 @@ def extract_glb(
167
  str: The path to the extracted GLB file.
168
  """
169
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
170
- gs, mesh, trial_id = unpack_state(state)
171
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
172
- glb_path = os.path.join(user_dir, f"{trial_id}.glb")
173
  glb.export(glb_path)
174
  torch.cuda.empty_cache()
175
  return glb_path, glb_path
176
 
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  with gr.Blocks(delete_cache=(600, 600)) as demo:
179
  gr.Markdown("""
180
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
181
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
182
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
 
 
183
  """)
184
 
185
  with gr.Row():
186
  with gr.Column():
187
- image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
 
 
 
 
 
 
 
 
 
188
 
189
  with gr.Accordion(label="Generation Settings", open=False):
190
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
@@ -197,6 +288,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
197
  with gr.Row():
198
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
199
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
 
200
 
201
  generate_btn = gr.Button("Generate")
202
 
@@ -204,17 +296,26 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
204
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
205
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
206
 
207
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
 
 
 
 
 
208
 
209
  with gr.Column():
210
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
211
- model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
212
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
213
 
 
 
 
 
 
214
  output_buf = gr.State()
215
 
216
  # Example images at the bottom of the page
217
- with gr.Row():
218
  examples = gr.Examples(
219
  examples=[
220
  f'assets/example_image/{image}'
@@ -226,16 +327,39 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
226
  run_on_click=True,
227
  examples_per_page=64,
228
  )
 
 
 
 
 
 
 
 
 
229
 
230
  # Handlers
231
  demo.load(start_session)
232
  demo.unload(end_session)
233
 
 
 
 
 
 
 
 
 
 
234
  image_prompt.upload(
235
  preprocess_image,
236
  inputs=[image_prompt],
237
  outputs=[image_prompt],
238
  )
 
 
 
 
 
239
 
240
  generate_btn.click(
241
  get_seed,
@@ -243,16 +367,16 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
243
  outputs=[seed],
244
  ).then(
245
  image_to_3d,
246
- inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
247
  outputs=[output_buf, video_output],
248
  ).then(
249
- lambda: gr.Button(interactive=True),
250
- outputs=[extract_glb_btn],
251
  )
252
 
253
  video_output.clear(
254
- lambda: gr.Button(interactive=False),
255
- outputs=[extract_glb_btn],
256
  )
257
 
258
  extract_glb_btn.click(
@@ -263,6 +387,15 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
263
  lambda: gr.Button(interactive=True),
264
  outputs=[download_glb],
265
  )
 
 
 
 
 
 
 
 
 
266
 
267
  model_output.clear(
268
  lambda: gr.Button(interactive=False),
 
9
  import torch
10
  import numpy as np
11
  import imageio
 
12
  from easydict import EasyDict as edict
13
  from PIL import Image
14
  from trellis.pipelines import TrellisImageTo3DPipeline
 
23
 
24
  def start_session(req: gr.Request):
25
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
26
  os.makedirs(user_dir, exist_ok=True)
27
 
28
 
29
  def end_session(req: gr.Request):
30
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
 
31
  shutil.rmtree(user_dir)
32
 
33
 
34
+ def preprocess_image(image: Image.Image) -> Image.Image:
35
  """
36
  Preprocess the input image.
37
 
 
39
  image (Image.Image): The input image.
40
 
41
  Returns:
 
42
  Image.Image: The preprocessed image.
43
  """
44
  processed_image = pipeline.preprocess_image(image)
45
  return processed_image
46
 
47
 
48
+ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
49
+ """
50
+ Preprocess a list of input images.
51
+
52
+ Args:
53
+ images (List[Tuple[Image.Image, str]]): The input images.
54
+
55
+ Returns:
56
+ List[Image.Image]: The preprocessed images.
57
+ """
58
+ images = [image[0] for image in images]
59
+ processed_images = [pipeline.preprocess_image(image) for image in images]
60
+ return processed_images
61
+
62
+
63
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
64
  return {
65
  'gaussian': {
66
  **gs.init_params,
 
74
  'vertices': mesh.vertices.cpu().numpy(),
75
  'faces': mesh.faces.cpu().numpy(),
76
  },
 
77
  }
78
 
79
 
 
97
  faces=torch.tensor(state['mesh']['faces'], device='cuda'),
98
  )
99
 
100
+ return gs, mesh
101
 
102
 
103
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
110
  @spaces.GPU
111
  def image_to_3d(
112
  image: Image.Image,
113
+ multiimages: List[Tuple[Image.Image, str]],
114
+ is_multiimage: bool,
115
  seed: int,
116
  ss_guidance_strength: float,
117
  ss_sampling_steps: int,
118
  slat_guidance_strength: float,
119
  slat_sampling_steps: int,
120
+ multiimage_algo: Literal["multidiffusion", "stochastic"],
121
  req: gr.Request,
122
  ) -> Tuple[dict, str]:
123
  """
 
125
 
126
  Args:
127
  image (Image.Image): The input image.
128
+ multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
129
+ is_multiimage (bool): Whether is in multi-image mode.
130
  seed (int): The random seed.
131
  ss_guidance_strength (float): The guidance strength for sparse structure generation.
132
  ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
133
  slat_guidance_strength (float): The guidance strength for structured latent generation.
134
  slat_sampling_steps (int): The number of sampling steps for structured latent generation.
135
+ multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
136
 
137
  Returns:
138
  dict: The information of the generated 3D model.
139
  str: The path to the video of the 3D model.
140
  """
141
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
142
+ if not is_multiimage:
143
+ outputs = pipeline.run(
144
+ image,
145
+ seed=seed,
146
+ formats=["gaussian", "mesh"],
147
+ preprocess_image=False,
148
+ sparse_structure_sampler_params={
149
+ "steps": ss_sampling_steps,
150
+ "cfg_strength": ss_guidance_strength,
151
+ },
152
+ slat_sampler_params={
153
+ "steps": slat_sampling_steps,
154
+ "cfg_strength": slat_guidance_strength,
155
+ },
156
+ )
157
+ else:
158
+ outputs = pipeline.run_multi_image(
159
+ [image[0] for image in multiimages],
160
+ seed=seed,
161
+ formats=["gaussian", "mesh"],
162
+ preprocess_image=False,
163
+ sparse_structure_sampler_params={
164
+ "steps": ss_sampling_steps,
165
+ "cfg_strength": ss_guidance_strength,
166
+ },
167
+ slat_sampler_params={
168
+ "steps": slat_sampling_steps,
169
+ "cfg_strength": slat_guidance_strength,
170
+ },
171
+ mode=multiimage_algo,
172
+ )
173
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
174
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
175
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
176
+ video_path = os.path.join(user_dir, 'sample.mp4')
 
177
  imageio.mimsave(video_path, video, fps=15)
178
+ state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
179
  torch.cuda.empty_cache()
180
  return state, video_path
181
 
182
 
183
+ @spaces.GPU(duration=90)
184
  def extract_glb(
185
  state: dict,
186
  mesh_simplify: float,
 
199
  str: The path to the extracted GLB file.
200
  """
201
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
202
+ gs, mesh = unpack_state(state)
203
  glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
204
+ glb_path = os.path.join(user_dir, 'sample.glb')
205
  glb.export(glb_path)
206
  torch.cuda.empty_cache()
207
  return glb_path, glb_path
208
 
209
 
210
+ @spaces.GPU
211
+ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
212
+ """
213
+ Extract a Gaussian file from the 3D model.
214
+
215
+ Args:
216
+ state (dict): The state of the generated 3D model.
217
+
218
+ Returns:
219
+ str: The path to the extracted Gaussian file.
220
+ """
221
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
222
+ gs, _ = unpack_state(state)
223
+ gaussian_path = os.path.join(user_dir, 'sample.ply')
224
+ gs.save_ply(gaussian_path)
225
+ torch.cuda.empty_cache()
226
+ return gaussian_path, gaussian_path
227
+
228
+
229
+ def prepare_multi_example() -> List[Image.Image]:
230
+ multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
231
+ images = []
232
+ for case in multi_case:
233
+ _images = []
234
+ for i in range(1, 4):
235
+ img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
236
+ W, H = img.size
237
+ img = img.resize((int(W / H * 512), 512))
238
+ _images.append(np.array(img))
239
+ images.append(Image.fromarray(np.concatenate(_images, axis=1)))
240
+ return images
241
+
242
+
243
+ def split_image(image: Image.Image) -> List[Image.Image]:
244
+ """
245
+ Split an image into multiple views.
246
+ """
247
+ image = np.array(image)
248
+ alpha = image[..., 3]
249
+ alpha = np.any(alpha>0, axis=0)
250
+ start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
251
+ end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
252
+ images = []
253
+ for s, e in zip(start_pos, end_pos):
254
+ images.append(Image.fromarray(image[:, s:e+1]))
255
+ return [preprocess_image(image) for image in images]
256
+
257
+
258
  with gr.Blocks(delete_cache=(600, 600)) as demo:
259
  gr.Markdown("""
260
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
261
  * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
262
  * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
263
+
264
+ ✨New: 1) Experimental multi-image support. 2) Gaussian file extraction.
265
  """)
266
 
267
  with gr.Row():
268
  with gr.Column():
269
+ with gr.Tabs() as input_tabs:
270
+ with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
271
+ image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
272
+ with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
273
+ multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
274
+ gr.Markdown("""
275
+ Input different views of the object in separate images.
276
+
277
+ *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
278
+ """)
279
 
280
  with gr.Accordion(label="Generation Settings", open=False):
281
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
 
288
  with gr.Row():
289
  slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
290
  slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
291
+ multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")
292
 
293
  generate_btn = gr.Button("Generate")
294
 
 
296
  mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
297
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
298
 
299
+ with gr.Row():
300
+ extract_glb_btn = gr.Button("Extract GLB", interactive=False)
301
+ extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
302
+ gr.Markdown("""
303
+ *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
304
+ """)
305
 
306
  with gr.Column():
307
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
308
+ model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
 
309
 
310
+ with gr.Row():
311
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
312
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
313
+
314
+ is_multiimage = gr.State(False)
315
  output_buf = gr.State()
316
 
317
  # Example images at the bottom of the page
318
+ with gr.Row() as single_image_example:
319
  examples = gr.Examples(
320
  examples=[
321
  f'assets/example_image/{image}'
 
327
  run_on_click=True,
328
  examples_per_page=64,
329
  )
330
+ with gr.Row(visible=False) as multiimage_example:
331
+ examples_multi = gr.Examples(
332
+ examples=prepare_multi_example(),
333
+ inputs=[image_prompt],
334
+ fn=split_image,
335
+ outputs=[multiimage_prompt],
336
+ run_on_click=True,
337
+ examples_per_page=8,
338
+ )
339
 
340
  # Handlers
341
  demo.load(start_session)
342
  demo.unload(end_session)
343
 
344
+ single_image_input_tab.select(
345
+ lambda: tuple([False, gr.Row.update(visible=True), gr.Row.update(visible=False)]),
346
+ outputs=[is_multiimage, single_image_example, multiimage_example]
347
+ )
348
+ multiimage_input_tab.select(
349
+ lambda: tuple([True, gr.Row.update(visible=False), gr.Row.update(visible=True)]),
350
+ outputs=[is_multiimage, single_image_example, multiimage_example]
351
+ )
352
+
353
  image_prompt.upload(
354
  preprocess_image,
355
  inputs=[image_prompt],
356
  outputs=[image_prompt],
357
  )
358
+ multiimage_prompt.upload(
359
+ preprocess_images,
360
+ inputs=[multiimage_prompt],
361
+ outputs=[multiimage_prompt],
362
+ )
363
 
364
  generate_btn.click(
365
  get_seed,
 
367
  outputs=[seed],
368
  ).then(
369
  image_to_3d,
370
+ inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
371
  outputs=[output_buf, video_output],
372
  ).then(
373
+ lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
374
+ outputs=[extract_glb_btn, extract_gs_btn],
375
  )
376
 
377
  video_output.clear(
378
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
379
+ outputs=[extract_glb_btn, extract_gs_btn],
380
  )
381
 
382
  extract_glb_btn.click(
 
387
  lambda: gr.Button(interactive=True),
388
  outputs=[download_glb],
389
  )
390
+
391
+ extract_gs_btn.click(
392
+ extract_gaussian,
393
+ inputs=[output_buf],
394
+ outputs=[model_output, download_gs],
395
+ ).then(
396
+ lambda: gr.Button(interactive=True),
397
+ outputs=[download_gs],
398
+ )
399
 
400
  model_output.clear(
401
  lambda: gr.Button(interactive=False),
assets/example_multi_image/character_1.png ADDED
assets/example_multi_image/character_2.png ADDED
assets/example_multi_image/character_3.png ADDED
assets/example_multi_image/mushroom_1.png ADDED
assets/example_multi_image/mushroom_2.png ADDED
assets/example_multi_image/mushroom_3.png ADDED
assets/example_multi_image/orangeguy_1.png ADDED
assets/example_multi_image/orangeguy_2.png ADDED
assets/example_multi_image/orangeguy_3.png ADDED
assets/example_multi_image/popmart_1.png ADDED
assets/example_multi_image/popmart_2.png ADDED
assets/example_multi_image/popmart_3.png ADDED
assets/example_multi_image/rabbit_1.png ADDED
assets/example_multi_image/rabbit_2.png ADDED
assets/example_multi_image/rabbit_3.png ADDED
assets/example_multi_image/tiger_1.png ADDED
assets/example_multi_image/tiger_2.png ADDED
assets/example_multi_image/tiger_3.png ADDED
assets/example_multi_image/yoimiya_1.png ADDED
assets/example_multi_image/yoimiya_2.png ADDED
assets/example_multi_image/yoimiya_3.png ADDED
trellis/pipelines/trellis_image_to_3d.py CHANGED
@@ -1,4 +1,5 @@
1
  from typing import *
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
@@ -281,3 +282,95 @@ class TrellisImageTo3DPipeline(Pipeline):
281
  coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
282
  slat = self.sample_slat(cond, coords, slat_sampler_params)
283
  return self.decode_slat(slat, formats)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import *
2
+ from contextlib import contextmanager
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
 
282
  coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
283
  slat = self.sample_slat(cond, coords, slat_sampler_params)
284
  return self.decode_slat(slat, formats)
285
+
286
+ @contextmanager
287
+ def inject_sampler_multi_image(
288
+ self,
289
+ sampler_name: str,
290
+ num_images: int,
291
+ num_steps: int,
292
+ mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
293
+ ):
294
+ """
295
+ Inject a sampler with multiple images as condition.
296
+
297
+ Args:
298
+ sampler_name (str): The name of the sampler to inject.
299
+ num_images (int): The number of images to condition on.
300
+ num_steps (int): The number of steps to run the sampler for.
301
+ """
302
+ sampler = getattr(self, sampler_name)
303
+ setattr(sampler, f'_old_inference_model', sampler._inference_model)
304
+
305
+ if mode == 'stochastic':
306
+ if num_images > num_steps:
307
+ print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. "
308
+ "This may lead to performance degradation.\033[0m")
309
+
310
+ cond_indices = (np.arange(num_steps) % num_images).tolist()
311
+ def _new_inference_model(self, model, x_t, t, cond, **kwargs):
312
+ cond_idx = cond_indices.pop(0)
313
+ cond_i = cond[cond_idx:cond_idx+1]
314
+ return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs)
315
+
316
+ elif mode =='multidiffusion':
317
+ from .samplers import FlowEulerSampler
318
+ def _new_inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
319
+ if cfg_interval[0] <= t <= cfg_interval[1]:
320
+ preds = []
321
+ for i in range(len(cond)):
322
+ preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs))
323
+ pred = sum(preds) / len(preds)
324
+ neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs)
325
+ return (1 + cfg_strength) * pred - cfg_strength * neg_pred
326
+ else:
327
+ preds = []
328
+ for i in range(len(cond)):
329
+ preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs))
330
+ pred = sum(preds) / len(preds)
331
+ return pred
332
+
333
+ else:
334
+ raise ValueError(f"Unsupported mode: {mode}")
335
+
336
+ sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler))
337
+
338
+ yield
339
+
340
+ sampler._inference_model = sampler._old_inference_model
341
+ delattr(sampler, f'_old_inference_model')
342
+
343
+ @torch.no_grad()
344
+ def run_multi_image(
345
+ self,
346
+ images: List[Image.Image],
347
+ num_samples: int = 1,
348
+ seed: int = 42,
349
+ sparse_structure_sampler_params: dict = {},
350
+ slat_sampler_params: dict = {},
351
+ formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
352
+ preprocess_image: bool = True,
353
+ mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
354
+ ) -> dict:
355
+ """
356
+ Run the pipeline with multiple images as condition
357
+
358
+ Args:
359
+ images (List[Image.Image]): The multi-view images of the assets
360
+ num_samples (int): The number of samples to generate.
361
+ sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
362
+ slat_sampler_params (dict): Additional parameters for the structured latent sampler.
363
+ preprocess_image (bool): Whether to preprocess the image.
364
+ """
365
+ if preprocess_image:
366
+ images = [self.preprocess_image(image) for image in images]
367
+ cond = self.get_cond(images)
368
+ cond['neg_cond'] = cond['neg_cond'][:1]
369
+ torch.manual_seed(seed)
370
+ ss_steps = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}.get('steps')
371
+ with self.inject_sampler_multi_image('sparse_structure_sampler', len(images), ss_steps, mode=mode):
372
+ coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
373
+ slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps')
374
+ with self.inject_sampler_multi_image('slat_sampler', len(images), slat_steps, mode=mode):
375
+ slat = self.sample_slat(cond, coords, slat_sampler_params)
376
+ return self.decode_slat(slat, formats)
trellis/representations/gaussian/gaussian_model.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import numpy as np
3
  from plyfile import PlyData, PlyElement
4
  from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation
 
5
 
6
 
7
  class Gaussian:
@@ -120,14 +121,21 @@ class Gaussian:
120
  for i in range(self._rotation.shape[1]):
121
  l.append('rot_{}'.format(i))
122
  return l
123
-
124
- def save_ply(self, path):
125
  xyz = self.get_xyz.detach().cpu().numpy()
126
  normals = np.zeros_like(xyz)
127
  f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
128
  opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy()
129
  scale = torch.log(self.get_scaling).detach().cpu().numpy()
130
  rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()
 
 
 
 
 
 
 
131
 
132
  dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
133
 
@@ -137,7 +145,7 @@ class Gaussian:
137
  el = PlyElement.describe(elements, 'vertex')
138
  PlyData([el]).write(path)
139
 
140
- def load_ply(self, path):
141
  plydata = PlyData.read(path)
142
 
143
  xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
@@ -172,6 +180,13 @@ class Gaussian:
172
  for idx, attr_name in enumerate(rot_names):
173
  rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
174
 
 
 
 
 
 
 
 
175
  # convert to actual gaussian attributes
176
  xyz = torch.tensor(xyz, dtype=torch.float, device=self.device)
177
  features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
 
2
  import numpy as np
3
  from plyfile import PlyData, PlyElement
4
  from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation
5
+ import utils3d
6
 
7
 
8
  class Gaussian:
 
121
  for i in range(self._rotation.shape[1]):
122
  l.append('rot_{}'.format(i))
123
  return l
124
+
125
+ def save_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
126
  xyz = self.get_xyz.detach().cpu().numpy()
127
  normals = np.zeros_like(xyz)
128
  f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
129
  opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy()
130
  scale = torch.log(self.get_scaling).detach().cpu().numpy()
131
  rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()
132
+
133
+ if transform is not None:
134
+ transform = np.array(transform)
135
+ xyz = np.matmul(xyz, transform.T)
136
+ rotation = utils3d.numpy.quaternion_to_matrix(rotation)
137
+ rotation = np.matmul(transform, rotation)
138
+ rotation = utils3d.numpy.matrix_to_quaternion(rotation)
139
 
140
  dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
141
 
 
145
  el = PlyElement.describe(elements, 'vertex')
146
  PlyData([el]).write(path)
147
 
148
+ def load_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
149
  plydata = PlyData.read(path)
150
 
151
  xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
 
180
  for idx, attr_name in enumerate(rot_names):
181
  rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
182
 
183
+ if transform is not None:
184
+ transform = np.array(transform)
185
+ xyz = np.matmul(xyz, transform)
186
+ rotation = utils3d.numpy.quaternion_to_matrix(rotation)
187
+ rotation = np.matmul(rotation, transform)
188
+ rotation = utils3d.numpy.matrix_to_quaternion(rotation)
189
+
190
  # convert to actual gaussian attributes
191
  xyz = torch.tensor(xyz, dtype=torch.float, device=self.device)
192
  features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
trellis/utils/postprocessing_utils.py CHANGED
@@ -14,6 +14,7 @@ import cv2
14
  from PIL import Image
15
  from .random_utils import sphere_hammersley_sequence
16
  from .render_utils import render_multiview
 
17
  from ..representations import Strivec, Gaussian, MeshExtractResult
18
 
19
 
@@ -454,5 +455,133 @@ def to_glb(
454
 
455
  # rotate mesh (from z-up to y-up)
456
  vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
457
- mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture))
 
 
 
 
 
458
  return mesh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  from PIL import Image
15
  from .random_utils import sphere_hammersley_sequence
16
  from .render_utils import render_multiview
17
+ from ..renderers import GaussianRenderer
18
  from ..representations import Strivec, Gaussian, MeshExtractResult
19
 
20
 
 
455
 
456
  # rotate mesh (from z-up to y-up)
457
  vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
458
+ material = trimesh.visual.material.PBRMaterial(
459
+ roughnessFactor=1.0,
460
+ baseColorTexture=texture,
461
+ baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8)
462
+ )
463
+ mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material))
464
  return mesh
465
+
466
+
467
+ def simplify_gs(
468
+ gs: Gaussian,
469
+ simplify: float = 0.95,
470
+ verbose: bool = True,
471
+ ):
472
+ """
473
+ Simplify 3D Gaussians
474
+ NOTE: this function is not used in the current implementation for the unsatisfactory performance.
475
+
476
+ Args:
477
+ gs (Gaussian): 3D Gaussian.
478
+ simplify (float): Ratio of Gaussians to remove in simplification.
479
+ """
480
+ if simplify <= 0:
481
+ return gs
482
+
483
+ # simplify
484
+ observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100)
485
+ observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations]
486
+
487
+ # Following https://arxiv.org/pdf/2411.06019
488
+ renderer = GaussianRenderer({
489
+ "resolution": 1024,
490
+ "near": 0.8,
491
+ "far": 1.6,
492
+ "ssaa": 1,
493
+ "bg_color": (0,0,0),
494
+ })
495
+ new_gs = Gaussian(**gs.init_params)
496
+ new_gs._features_dc = gs._features_dc.clone()
497
+ new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None
498
+ new_gs._opacity = torch.nn.Parameter(gs._opacity.clone())
499
+ new_gs._rotation = torch.nn.Parameter(gs._rotation.clone())
500
+ new_gs._scaling = torch.nn.Parameter(gs._scaling.clone())
501
+ new_gs._xyz = torch.nn.Parameter(gs._xyz.clone())
502
+
503
+ start_lr = [1e-4, 1e-3, 5e-3, 0.025]
504
+ end_lr = [1e-6, 1e-5, 5e-5, 0.00025]
505
+ optimizer = torch.optim.Adam([
506
+ {"params": new_gs._xyz, "lr": start_lr[0]},
507
+ {"params": new_gs._rotation, "lr": start_lr[1]},
508
+ {"params": new_gs._scaling, "lr": start_lr[2]},
509
+ {"params": new_gs._opacity, "lr": start_lr[3]},
510
+ ], lr=start_lr[0])
511
+
512
+ def exp_anealing(optimizer, step, total_steps, start_lr, end_lr):
513
+ return start_lr * (end_lr / start_lr) ** (step / total_steps)
514
+
515
+ def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr):
516
+ return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))
517
+
518
+ _zeta = new_gs.get_opacity.clone().detach().squeeze()
519
+ _lambda = torch.zeros_like(_zeta)
520
+ _delta = 1e-7
521
+ _interval = 10
522
+ num_target = int((1 - simplify) * _zeta.shape[0])
523
+
524
+ with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar:
525
+ for i in range(2500):
526
+ # prune
527
+ if i % 100 == 0:
528
+ mask = new_gs.get_opacity.squeeze() > 0.05
529
+ mask = torch.nonzero(mask).squeeze()
530
+ new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask])
531
+ new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask])
532
+ new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask])
533
+ new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask])
534
+ new_gs._features_dc = new_gs._features_dc[mask]
535
+ new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None
536
+ _zeta = _zeta[mask]
537
+ _lambda = _lambda[mask]
538
+ # update optimizer state
539
+ for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]):
540
+ stored_state = optimizer.state[param_group['params'][0]]
541
+ if 'exp_avg' in stored_state:
542
+ stored_state['exp_avg'] = stored_state['exp_avg'][mask]
543
+ stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask]
544
+ del optimizer.state[param_group['params'][0]]
545
+ param_group['params'][0] = new_param
546
+ optimizer.state[param_group['params'][0]] = stored_state
547
+
548
+ opacity = new_gs.get_opacity.squeeze()
549
+
550
+ # sparisfy
551
+ if i % _interval == 0:
552
+ _zeta = _lambda + opacity.detach()
553
+ if opacity.shape[0] > num_target:
554
+ index = _zeta.topk(num_target)[1]
555
+ _m = torch.ones_like(_zeta, dtype=torch.bool)
556
+ _m[index] = 0
557
+ _zeta[_m] = 0
558
+ _lambda = _lambda + opacity.detach() - _zeta
559
+
560
+ # sample a random view
561
+ view_idx = np.random.randint(len(observations))
562
+ observation = observations[view_idx]
563
+ extrinsic = extrinsics[view_idx]
564
+ intrinsic = intrinsics[view_idx]
565
+
566
+ color = renderer.render(new_gs, extrinsic, intrinsic)['color']
567
+ rgb_loss = torch.nn.functional.l1_loss(color, observation)
568
+ loss = rgb_loss + \
569
+ _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2))
570
+
571
+ optimizer.zero_grad()
572
+ loss.backward()
573
+ optimizer.step()
574
+
575
+ # update lr
576
+ for j in range(len(optimizer.param_groups)):
577
+ optimizer.param_groups[j]['lr'] = cosine_anealing(optimizer, i, 2500, start_lr[j], end_lr[j])
578
+
579
+ pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()})
580
+ pbar.update()
581
+
582
+ new_gs._xyz = new_gs._xyz.data
583
+ new_gs._rotation = new_gs._rotation.data
584
+ new_gs._scaling = new_gs._scaling.data
585
+ new_gs._opacity = new_gs._opacity.data
586
+
587
+ return new_gs