diff --git a/torchao/prototype/moe_training/conversion_utils.py b/torchao/prototype/moe_training/conversion_utils.py
index 72056e68b3..2da8186f2d 100644
--- a/torchao/prototype/moe_training/conversion_utils.py
+++ b/torchao/prototype/moe_training/conversion_utils.py
@@ -84,7 +84,7 @@ def _swap_params(
                 f"Does not support a root nn.Parameter with children: {module}"
             )
         if not isinstance(module.data, ScaledGroupedMMTensor):
-            new_data = ScaledGroupedMMTensor(module.data, module.data.dtype)
+            new_data = ScaledGroupedMMTensor(module.data)
             return nn.Parameter(new_data, requires_grad=module.requires_grad)
         return module
 
@@ -110,7 +110,7 @@ def post_order_traversal(
         for param_name, param in module.named_parameters(recurse=False):
             if not isinstance(param.data, ScaledGroupedMMTensor):
                 new_param = nn.Parameter(
-                    ScaledGroupedMMTensor(param.data, param.data.dtype),
+                    ScaledGroupedMMTensor(param.data),
                     requires_grad=param.requires_grad,
                 )
                 setattr(module, param_name, new_param)
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index 29adffd831..5a08074d5d 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Optional
 
 import torch
@@ -18,6 +19,8 @@
     _is_column_major,
 )
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 def _scaled_grouped_mm(
     A: torch.Tensor,
@@ -36,8 +39,8 @@ def _scaled_grouped_mm(
             and in column-major memory layout.
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
-        use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
     """
+    # logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index b41527a4ae..ddcc84f515 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -9,13 +9,15 @@
 
 import torch
 import torch.utils._pytree as pytree
+from torch import nn
 from torch._prims_common import suggest_memory_format
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import MixedPrecisionPolicy
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
 logger: logging.Logger = logging.getLogger(__name__)
 
-
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -44,7 +46,6 @@ class ScaledGroupedMMTensor(torch.Tensor):
     def __new__(
         cls,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
@@ -52,7 +53,7 @@ def __new__(
             strides=tensor.stride(),
             storage_offset=tensor.storage_offset(),
             memory_format=suggest_memory_format(tensor),
-            dtype=dtype,
+            dtype=tensor.dtype,
             layout=tensor.layout,
             device=tensor.device,
             pin_memory=tensor.is_pinned(),
@@ -62,14 +63,11 @@ def __new__(
     def __init__(
         self,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
         self._data = tensor
-        self._dtype = dtype
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
@@ -98,19 +96,10 @@ def __torch_function__(cls, func, types, args, kwargs={}):
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # detach is special case
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
-
-        # unwrap args and kwargs
-        dtype: Optional[torch.dtype] = None
-
-        def unwrap(t):
-            nonlocal dtype
-            if dtype is None:
-                dtype = t._dtype
-            else:
-                assert t._dtype == dtype
-            return t._data
+            return ScaledGroupedMMTensor(args[0]._data)
 
+        # unwrap args/kwargs
+        unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
@@ -125,25 +114,33 @@ def unwrap(t):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, dtype),
+            lambda x: ScaledGroupedMMTensor(x),
             out,
         )
 
     def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+        return f"ScaledGroupedMMTensor(data={self._data})"
 
     def __tensor_flatten__(self):
-        return ["_data"], {"_dtype": self._dtype}
+        return ["_data"]
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return ScaledGroupedMMTensor(
             inner_tensors["_data"],
-            flatten_spec["_dtype"],
         )
 
-    def fsdp_pre_all_gather(self, mesh):
-        all_gather_inputs = (self._data,)
+    # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
+    def fsdp_pre_all_gather(
+        self,
+        mesh: DeviceMesh,
+        outer_size: torch.Size,
+        outer_stride: tuple[int, ...],
+        module: nn.Module,
+        mp_policy: MixedPrecisionPolicy,
+    ):
+        # cast to mixed precision dtype prior to all-gather
+        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
         all_gather_metadata = ()
         return all_gather_inputs, all_gather_metadata
 
@@ -156,6 +153,25 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        output = ScaledGroupedMMTensor(data, param_dtype)
+
+        # For training step 1+, out=unsharded param, so we need to copy data to `out`
+        # if `self._data` and `out` do not share the same storage.
+        # Otherwise, if they do share the same storage, we can just return directly.
+        if out is not None:
+            assert isinstance(out, ScaledGroupedMMTensor), f"{type(out)}"
+            if data.dtype == param_dtype:
+                assert (
+                    data.untyped_storage().data_ptr()
+                    == out._data.untyped_storage().data_ptr()
+                )
+            else:
+                assert out._data.dtype == param_dtype, (
+                    f"{out._data.dtype} {param_dtype}"
+                )
+                out._data.copy_(data)
+            return
+
+        # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
+        output = ScaledGroupedMMTensor(data)
         inner_tensors = (data,)
         return output, inner_tensors
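
Not part of the patch: a minimal sketch of the constructor change, assuming a torchao checkout with this diff applied. The tensor shape below is illustrative only; the construction pattern mirrors what the patched conversion_utils.py does.

```python
# Sketch: after this change, ScaledGroupedMMTensor infers dtype from the wrapped
# tensor instead of taking it as a second constructor argument.
import torch
from torch import nn

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

weight = torch.randn(8, 16, 32, dtype=torch.bfloat16)  # illustrative expert-weight shape
param = nn.Parameter(ScaledGroupedMMTensor(weight))    # before this diff: ScaledGroupedMMTensor(weight, weight.dtype)
assert param.dtype == torch.bfloat16                    # dtype inherited from `weight`
```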
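Also not part of the patch: a hedged sketch of what exercises the new fsdp_pre_all_gather / fsdp_post_all_gather signatures. It assumes a PyTorch build that exports fully_shard and MixedPrecisionPolicy from torch.distributed.fsdp (FSDP2) and an initialized process group (e.g. launched with torchrun); the toy nn.Linear stands in for a converted MoE module.

```python
# Sketch: FSDP2 mixed precision drives the new hooks. On all-gather, fsdp_pre_all_gather
# casts the sharded _data to mp_policy.param_dtype; fsdp_post_all_gather rewraps the
# gathered tensor on step 0 and copies into the existing unsharded param on later steps.
import torch
from torch import nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# Toy stand-in for an MoE expert module; real usage swaps params via the conversion
# utilities patched above.
moe_module = nn.Linear(16, 16, bias=False)
moe_module.weight = nn.Parameter(ScaledGroupedMMTensor(moe_module.weight.data))

mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
fully_shard(moe_module, mp_policy=mp_policy)  # requires torch.distributed to be initialized
```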