87 changes: 73 additions & 14 deletions torchrec_dlrm/dlrm_main.py
@@ -36,6 +36,11 @@
from torchrec.optim.optimizers import in_backward_optimizer_filter
from tqdm import tqdm

try:
from distributed_shampoo import DistributedShampoo, SGDPreconditionerConfig
except ImportError:
# DistributedShampoo is optional; the --shampoo_dense / --shampoo_embedding
# code paths below will raise a NameError if the package is unavailable.
pass

# OSS import
try:
# pyre-ignore[21]
@@ -80,6 +85,12 @@ def parse_args(argv: list[str]) -> argparse.Namespace:
default=1,
help="number of epochs to train",
)
parser.add_argument(
"--precondition_frequency",
type=int,
default=100,
help="number of steps before running preconditioner",
)
parser.add_argument(
"--batch_size",
type=int,
@@ -263,6 +274,16 @@ def parse_args(argv: list[str]) -> argparse.Namespace:
action="store_true",
help="Flag to determine if adagrad optimizer should be used.",
)
parser.add_argument(
"--shampoo_embedding",
action="store_true",
help="Use DistributedShampoo optimizer.",
)
parser.add_argument(
"--shampoo_dense",
action="store_true",
help="Use DistributedShampoo optimizer.",
)
parser.add_argument(
"--interaction_type",
type=InteractionType,
@@ -491,8 +512,8 @@ def train_val_test(
args.limit_train_batches,
args.limit_val_batches,
)
val_auroc = _evaluate(args.limit_val_batches, pipeline, val_dataloader, "val")
results.val_aurocs.append(val_auroc)
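# Validation is skipped in this run; a placeholder AUROC of 0.0 is appended instead.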
# val_auroc = _evaluate(args.limit_val_batches, pipeline, val_dataloader, "val")
results.val_aurocs.append(0.0)

test_auroc = _evaluate(args.limit_test_batches, pipeline, test_dataloader, "test")
results.test_auroc = test_auroc
@@ -635,22 +656,23 @@ def main(argv: list[str]) -> None:
)

train_model = DLRMTrain(dlrm_model)
embedding_optimizer = torch.optim.Adagrad if args.adagrad else torch.optim.SGD
# embedding_optimizer = torch.optim.Adagrad if args.adagrad else torch.optim.SGD
# When enabled, this applies the Adagrad update for the embeddings (sparse_arch) directly in the
# backward pass through a fused op.
# TorchRec uses the FBGEMM implementation of EXACT_ADAGRAD: on GPU a fused CUDA kernel is invoked;
# on CPU, FBGEMM_GPU invokes CPU kernels.
# https://github.com/pytorch/FBGEMM/blob/2cb8b0dff3e67f9a009c4299defbd6b99cc12b8f/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py#L676-L678

# Note that lr_decay, weight_decay, and initial_accumulator_value for the Adagrad optimizer
# cannot be specified below in FBGEMM v0.3.2; they are effectively hardcoded to zero.
optimizer_kwargs = {"lr": args.learning_rate}
if args.adagrad:
optimizer_kwargs["eps"] = args.eps
apply_optimizer_in_backward(
embedding_optimizer,
train_model.model.sparse_arch.parameters(),
optimizer_kwargs,
)
# optimizer_kwargs = {"lr": args.learning_rate}
# if args.adagrad:
# optimizer_kwargs["eps"] = args.eps

# apply_optimizer_in_backward(
# embedding_optimizer,
# train_model.model.sparse_arch.parameters(),
# optimizer_kwargs,
# )
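# The fused in-backward path above is disabled; the embedding optimizer is instead
# constructed below via KeyedOptimizerWrapper so that Shampoo (or Adagrad/SGD) can be
# applied to the sparse parameters.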
planner = EmbeddingShardingPlanner(
topology=Topology(
local_world_size=get_local_size(),
@@ -660,7 +682,7 @@
batch_size=args.batch_size,
# If you experience OOM, increase the percentage. See
# https://pytorch.org/torchrec/torchrec.distributed.planner.html#torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation
storage_reservation=HeuristicalStorageReservation(percentage=0.05),
storage_reservation=HeuristicalStorageReservation(percentage=0.2),
)
plan = planner.collective_plan(
train_model, get_default_sharders(), dist.GroupMember.WORLD
@@ -678,18 +700,55 @@
print(table_name, "\n", plan, "\n")

def optimizer_with_params():
if args.adagrad:
if args.shampoo_dense:
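# Note: the dense Shampoo branch uses hand-picked hyperparameters, including a
# fixed lr of 0.001 rather than args.learning_rate.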
return lambda params: DistributedShampoo(
params,
lr=0.001,
betas=(0., 0.999),
epsilon=1e-12,
momentum=0.9,
weight_decay=1e-05,
max_preconditioner_dim=8192,
precondition_frequency=args.precondition_frequency,
grafting_config=SGDPreconditionerConfig(),
)
elif args.adagrad:
return lambda params: torch.optim.Adagrad(
params, lr=args.learning_rate, eps=args.eps
)
else:
return lambda params: torch.optim.SGD(params, lr=args.learning_rate)

def embedding_optimizer_with_params():
if args.shampoo_embedding:
return lambda params: DistributedShampoo(
params,
lr=args.learning_rate,
betas=(0., 0.999),
epsilon=args.eps,
momentum=0.9,
weight_decay=1e-05,
max_preconditioner_dim=8192,
precondition_frequency=args.precondition_frequency,
grafting_config=SGDPreconditionerConfig(),
)
elif args.adagrad:
return lambda params: torch.optim.Adagrad(
params, lr=args.learning_rate, eps=args.eps
)
else:
return lambda params: torch.optim.SGD(params, lr=args.learning_rate)

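# include=True selects the parameters flagged for in-backward optimization (the
# embedding tables), so embedding_optimizer_with_params() is applied to them; the
# default filter below selects the remaining dense parameters.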
embedding_optimizer = KeyedOptimizerWrapper(
dict(in_backward_optimizer_filter(model.named_parameters(), include=True)),
embedding_optimizer_with_params(),
)

dense_optimizer = KeyedOptimizerWrapper(
dict(in_backward_optimizer_filter(model.named_parameters())),
optimizer_with_params(),
)
optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
optimizer = CombinedOptimizer([embedding_optimizer, dense_optimizer])
lr_scheduler = LRPolicyScheduler(
optimizer, args.lr_warmup_steps, args.lr_decay_start, args.lr_decay_steps
)
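
For reference, a minimal single-process sketch of the Shampoo configuration this diff applies to the dense parameters. It assumes the distributed_shampoo package is installed; the toy model, batch, and training step are illustrative only.

    import torch
    from distributed_shampoo import DistributedShampoo, SGDPreconditionerConfig

    # Stand-in for the DLRM dense arch (illustrative only).
    model = torch.nn.Linear(8, 1)
    optimizer = DistributedShampoo(
        model.parameters(),
        lr=0.001,  # the fixed lr used by the --shampoo_dense branch
        betas=(0.0, 0.999),
        epsilon=1e-12,
        momentum=0.9,
        weight_decay=1e-05,
        max_preconditioner_dim=8192,
        precondition_frequency=100,  # the --precondition_frequency default
        grafting_config=SGDPreconditionerConfig(),
    )

    # One illustrative training step.
    optimizer.zero_grad()
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()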