2 changes: 1 addition & 1 deletion library/lumina_train_util.py
@@ -479,7 +479,6 @@ def time_shift(mu: float, sigma: float, t: torch.Tensor):
# Since we adopt the reverse, the 1-t operations are needed
t = 1 - t
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
t = 1 - t
return t


@@ -1060,6 +1059,7 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
default=1.0,
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
)

parser.add_argument(
"--model_prediction_type",
choices=["raw", "additive", "sigma_scaled"],
3 changes: 3 additions & 0 deletions library/train_util.py
@@ -5525,6 +5525,9 @@ def prepare_accelerator(args: argparse.Namespace):
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
else None
),
(
DistributedDataParallelKwargs(find_unused_parameters=True)
),
Comment on lines +5528 to +5530
Owner:
What was the purpose of this addition? I would appreciate an explanation.

Author:
According to my testing, full fine-tuning of the Lumina image model on multiple GPUs fails with the error "expected gradient for parameter … but none found". Adding this handles the problem, and training runs normally on multi-GPU without errors.

Owner:
Thank you for the explanation. This function is commonly called by all model training scripts, so any changes made here will require testing all models.

I think it might be a good idea to find out why Lumina needs this argument and solve that problem.

Author:
Maybe we can add a flag that enables this only when fine-tuning Lumina models. That could improve flexibility.
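
A rough sketch of that opt-in approach, assuming a hypothetical --ddp_find_unused_parameters flag (the flag name and helper functions below are illustrative, not part of the repository):

import argparse

from accelerate.utils import DistributedDataParallelKwargs


def add_ddp_arguments(parser: argparse.ArgumentParser):
    # Hypothetical flag: off by default so training scripts for other models are unaffected.
    parser.add_argument(
        "--ddp_find_unused_parameters",
        action="store_true",
        help="Pass find_unused_parameters=True to DDP (e.g. for Lumina full fine-tuning on multiple GPUs).",
    )


def build_kwargs_handlers(args: argparse.Namespace):
    # Append the handler only when the flag is set, mirroring how the other
    # optional DDP kwargs in prepare_accelerator are appended and then filtered for None.
    kwargs_handlers = [
        (
            DistributedDataParallelKwargs(find_unused_parameters=True)
            if args.ddp_find_unused_parameters
            else None
        ),
    ]
    return [h for h in kwargs_handlers if h is not None]

This would keep the current default behavior for every other model while letting Lumina full fine-tuning opt in explicitly.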

]
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
134 changes: 71 additions & 63 deletions lumina_train.py
@@ -361,70 +361,78 @@ def train(args):
# prepare the classes required for training
accelerator.print("prepare optimizer, data loader etc.")

if args.blockwise_fused_optimizers:
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
# This balances memory usage and management complexity.

# split params into groups. currently different learning rates are not supported
grouped_params = []
param_group = {}
for group in params_to_optimize:
named_parameters = list(nextdit.named_parameters())
assert len(named_parameters) == len(
group["params"]
), "number of parameters does not match"
for p, np in zip(group["params"], named_parameters):
# determine target layer and block index for each parameter
block_type = "other" # double, single or other
if np[0].startswith("double_blocks"):
block_index = int(np[0].split(".")[1])
block_type = "double"
elif np[0].startswith("single_blocks"):
block_index = int(np[0].split(".")[1])
block_type = "single"
else:
block_index = -1

param_group_key = (block_type, block_index)
if param_group_key not in param_group:
param_group[param_group_key] = []
param_group[param_group_key].append(p)

block_types_and_indices = []
for param_group_key, param_group in param_group.items():
block_types_and_indices.append(param_group_key)
grouped_params.append({"params": param_group, "lr": args.learning_rate})

num_params = 0
for p in param_group:
num_params += p.numel()
accelerator.print(f"block {param_group_key}: {num_params} parameters")

# prepare optimizers for each group
optimizers = []
for group in grouped_params:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
optimizers.append(optimizer)
optimizer = optimizers[0] # avoid error in the following code

logger.info(
f"using {len(optimizers)} optimizers for blockwise fused optimizers"
)

if train_util.is_schedulefree_optimizer(optimizers[0], args):
raise ValueError(
"Schedule-free optimizer is not supported with blockwise fused optimizers"
)
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function
else:
_, _, optimizer = train_util.get_optimizer(
# if args.blockwise_fused_optimizers:
# # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
# # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
# # This balances memory usage and management complexity.

# # split params into groups. currently different learning rates are not supported
# grouped_params = []
# param_group = {}
# for group in params_to_optimize:
# named_parameters = list(nextdit.named_parameters())
# assert len(named_parameters) == len(
# group["params"]
# ), "number of parameters does not match"
# for p, np in zip(group["params"], named_parameters):
# # determine target layer and block index for each parameter
# block_type = "other" # double, single or other
# if np[0].startswith("double_blocks"):
# block_index = int(np[0].split(".")[1])
# block_type = "double"
# elif np[0].startswith("single_blocks"):
# block_index = int(np[0].split(".")[1])
# block_type = "single"
# else:
# block_index = -1

# param_group_key = (block_type, block_index)
# if param_group_key not in param_group:
# param_group[param_group_key] = []
# param_group[param_group_key].append(p)

# block_types_and_indices = []
# for param_group_key, param_group in param_group.items():
# block_types_and_indices.append(param_group_key)
# grouped_params.append({"params": param_group, "lr": args.learning_rate})

# num_params = 0
# for p in param_group:
# num_params += p.numel()
# accelerator.print(f"block {param_group_key}: {num_params} parameters")

# # prepare optimizers for each group
# optimizers = []
# for group in grouped_params:
# _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
# optimizers.append(optimizer)
# optimizer = optimizers[0] # avoid error in the following code

# logger.info(
# f"using {len(optimizers)} optimizers for blockwise fused optimizers"
# )

# if train_util.is_schedulefree_optimizer(optimizers[0], args):
# raise ValueError(
# "Schedule-free optimizer is not supported with blockwise fused optimizers"
# )
# optimizer_train_fn = lambda: None # dummy function
# optimizer_eval_fn = lambda: None # dummy function
# else:
# _, _, optimizer = train_util.get_optimizer(
# args, trainable_params=params_to_optimize
# )
# optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
# optimizer, args
# )

# Currently, when using blockwise_fused_optimizers, the model weights are not updated.
_, _, optimizer = train_util.get_optimizer(
args, trainable_params=params_to_optimize
)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
optimizer, args
)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
optimizer, args
)

# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
@@ -743,7 +751,7 @@ def grad_hook(parameter: torch.Tensor):
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = nextdit(
x=noisy_model_input, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps need to be divided by 1000 to match what the model expects
t=1 - timesteps / 1000, # timesteps need to be divided by 1000 to match what the model expects
cap_feats=gemma2_hidden_states, # Gemma2 hidden states are used as caption features
cap_mask=gemma2_attn_mask.to(
dtype=torch.int32
2 changes: 1 addition & 1 deletion lumina_train_network.py
@@ -268,7 +268,7 @@ def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps):
# NextDiT forward expects (x, t, cap_feats, cap_mask)
model_pred = dit(
x=img, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps need to be divided by 1000 to match what the model expects
t=1 - timesteps / 1000, # timesteps need to be divided by 1000 to match what the model expects
cap_feats=gemma2_hidden_states, # Gemma2 hidden states are used as caption features
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2 attention mask
)