Support store_param_remainders feature from Apex in TE Fused Adam #1443

Closed
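This PR ports Apex's store_param_remainders optimization to Transformer Engine's FusedAdam: when model parameters are BF16 and master weights are FP32, the optimizer keeps only the lower 16 bits of each FP32 master weight as an int16 "remainder", because the BF16 parameter already carries the upper 16 bits. Below is a minimal, self-contained sketch of the bit trick (not the PR's code), assuming the little-endian layout implied by the kernel's fp32_or_int162 union; the helper names are illustrative only.

# Illustrative sketch (not part of the PR): splitting one FP32 master weight
# into a BF16 parameter (upper 16 bits, rounded) plus a 16-bit remainder,
# and reconstructing it exactly -- mirroring the kernel's sign-bit rounding.
import struct

def split_fp32(value):
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    hi = (bits >> 16) & 0xFFFF   # BF16 payload (what p[] holds)
    lo = bits & 0xFFFF           # remainder (what p_remainder[] holds)
    if lo >= 0x8000:             # remainder negative as int16 -> "Round up"
        hi = (hi + 1) & 0xFFFF
    return hi, lo

def reconstruct_fp32(hi, lo):
    if lo >= 0x8000:             # "Undo rounding" before reassembly
        hi = (hi - 1) & 0xFFFF
    bits = (hi << 16) | lo
    return struct.unpack("<f", struct.pack("<I", bits))[0]

x = struct.unpack("<f", struct.pack("<f", 0.1234567))[0]   # an exact FP32 value
assert reconstruct_fp32(*split_fp32(x)) == x                # lossless round trip

Relative to keeping a separate FP32 master copy next to each BF16 parameter, this halves the extra master-weight storage (2 bytes per element instead of 4) while remaining bit-exact, as the round trip above shows.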
19 changes: 18 additions & 1 deletion tests/pytorch/test_fused_optimizer.py
@@ -185,6 +185,7 @@ def gen_precision_aware_test(
grad_dtype,
exp_avg_dtype,
exp_avg_sq_dtype,
store_param_remainders=False,
model_rtol=None,
model_atol=None,
master_rtol=None,
@@ -221,6 +222,7 @@ def gen_precision_aware_test(
"weight_decay": 0,
"amsgrad": False,
}

ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params,
@@ -229,6 +231,7 @@
exp_avg_dtype=exp_avg_dtype,
exp_avg_sq_dtype=exp_avg_sq_dtype,
use_decoupled_grad=True,
store_param_remainders=store_param_remainders,
**options,
)

@@ -238,7 +241,7 @@ def test_one_iteration(ref_optimizer, tst_optimizer):
p.decoupled_grad = p_ref.grad.clone().to(grad_dtype)
ref_optimizer.step()
tst_optimizer.step()
if use_master_weights:
if use_master_weights and not store_param_remainders:
master_weights_to_fp32 = [
tst_optim.get_unscaled_state(p, "master_param") for p in model_params
]
@@ -271,6 +274,7 @@ def test_one_iteration(ref_optimizer, tst_optimizer):
exp_avg_dtype=exp_avg_dtype,
exp_avg_sq_dtype=exp_avg_sq_dtype,
use_decoupled_grad=True,
store_param_remainders=store_param_remainders,
**options,
)
tst_optim.load_state_dict(state_dict)
@@ -301,6 +305,19 @@ def test_fp32_master(self):
exp_avg_sq_dtype=torch.float32,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp32_master_store_param_remainders(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
store_param_remainders=True,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_master(self):
self.gen_precision_aware_test(
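Grounded in the new test above, a hedged usage sketch of the feature from the optimizer API. Shapes and hyper-parameters are illustrative, and the master_weights/master_weight_dtype keywords are assumed from the collapsed part of the test's FusedAdam constructor.

# Usage sketch based on test_fp32_master_store_param_remainders: BF16 params,
# FP32 master weights kept only as 16-bit remainders.
import torch
import transformer_engine.pytorch as te

params = [torch.nn.Parameter(torch.randn(256, 256, dtype=torch.bfloat16, device="cuda"))]
optim = te.optimizers.FusedAdam(
    params,
    lr=1e-3,
    master_weights=True,                  # assumed keyword (collapsed in the diff)
    master_weight_dtype=torch.float32,    # assumed keyword (collapsed in the diff)
    exp_avg_dtype=torch.float32,
    exp_avg_sq_dtype=torch.float32,
    use_decoupled_grad=True,
    store_param_remainders=True,
)

for p in params:
    p.decoupled_grad = torch.randn_like(p, dtype=torch.float32)  # as in the test
optim.step()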
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -320,6 +320,12 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
const int step, const int mode, const int bias_correction,
const float weight_decay);

void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay);

void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
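The new declaration mirrors multi_tensor_adam_cuda but expects five tensor lists, ordered g, p, m, v, p_remainder (see the kernel below). A hedged sketch of driving the binding directly follows; the extension module name transformer_engine_torch and the int16 remainder storage are assumptions, and in practice FusedAdam issues this call internally.

# Hedged sketch: direct call into the new binding with the five tensor lists
# the kernel expects (g, p, m, v, p_remainder). Module name and remainder
# dtype (int16) are assumptions for illustration only.
import torch
import transformer_engine_torch as tex

n = 1024
g = torch.randn(n, dtype=torch.float32, device="cuda")    # grads
p = torch.randn(n, dtype=torch.bfloat16, device="cuda")   # BF16 params (required)
m = torch.zeros(n, dtype=torch.float32, device="cuda")    # exp_avg
v = torch.zeros(n, dtype=torch.float32, device="cuda")    # exp_avg_sq
p_rem = torch.zeros(n, dtype=torch.int16, device="cuda")  # remainder bits
noop_flag = torch.zeros(1, dtype=torch.int32, device="cuda")

tex.multi_tensor_adam_param_remainder(
    2048 * 32,                       # chunk_size
    noop_flag,
    [[g], [p], [m], [v], [p_rem]],   # g, p, m, v, p_remainder
    1e-3, 0.9, 0.999, 1e-8,          # lr, beta1, beta2, epsilon
    1,                               # step
    0,                               # mode (0 = ADAM_MODE_0, L2 regularization)
    1,                               # bias_correction
    0.0,                             # weight_decay
)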
@@ -179,6 +179,122 @@ struct AdamFunctorMaster {
}
};

template <typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctorMasterParamRemainder {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
TensorListMetadata<5> &tl, // NOLINT(*)
const float beta1, const float beta2,
const float beta1_correction,
const float beta2_correction, const float epsilon,
const float lr, adamMode_t mode, const float decay) {
index_t tensor_loc = tl.block_to_tensor[blockIdx.x];

index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc];

GRAD_T *g = reinterpret_cast<GRAD_T *>(tl.addresses[0][tensor_loc]);
g += chunk_idx * chunk_size;

int16_t *p = reinterpret_cast<int16_t *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size;

FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
m += chunk_idx * chunk_size;

FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
v += chunk_idx * chunk_size;

int16_t *p_remainder = reinterpret_cast<int16_t *>(tl.addresses[4][tensor_loc]);
p_remainder += chunk_idx * chunk_size;

n -= chunk_idx * chunk_size;

// see note in multi_tensor_scale_kernel.cu
for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
union fp32_or_int162 {
float fp32;
int16_t int16[2];
};
fp32_or_int162 local_master_param[ILP];
int16_t local_p[ILP];
int16_t local_p_rem[ILP];
MATH_T r_g[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = static_cast<MATH_T>(g[i]);
r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]);

local_p[ii] = static_cast<int16_t>(p[i]);
local_p_rem[ii] = static_cast<int16_t>(p_remainder[i]);
} else {
r_g[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);

local_p[ii] = int16_t(0);
local_p_rem[ii] = int16_t(0);
}
}
// Reconstruct FP32 params
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (local_p_rem[ii] < 0) local_p[ii]--; // Undo rounding
local_master_param[ii].int16[1] = local_p[ii];
local_master_param[ii].int16[0] = local_p_rem[ii];
}

MATH_T *r_p = reinterpret_cast<MATH_T *>(local_master_param);

#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}

// Split into BF16 params (rounded-to-nearest) and remainders
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
local_p[ii] = local_master_param[ii].int16[1];
local_p_rem[ii] = local_master_param[ii].int16[0];
if (local_p_rem[ii] < 0) local_p[ii]++; // Round up
}

#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p_remainder[i] = static_cast<int16_t>(local_p_rem[ii]);
p[i] = static_cast<int16_t>(local_p[ii]);

m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]);
}
}
}
}
};

template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
@@ -548,6 +664,42 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
AT_CUDA_CHECK(cudaGetLastError());
}

void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay) {
using namespace at;

// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}

const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size();

// case 5: g, p, m, v, p_master
TORCH_CHECK(tl_size == 5, "tensor list must contain 5");
TORCH_CHECK(p_in_type == at::ScalarType::BFloat16,
"Adam with BF16 param remainders requires BF16 params");

// g, p, m, v, p_master
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<scalar_t_1, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));

AT_CUDA_CHECK(cudaGetLastError());
}

void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
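As a sanity check on the functor's reconstruct/split logic above, here is a tensor-level sketch (illustrative, not part of the PR) showing that the BF16-plus-remainder encoding is a lossless re-encoding of the FP32 master weights, using the same sign-bit rounding adjustment.

# Emulates the kernel's "Round up" / "Undo rounding" steps with bit views;
# assumes a little-endian host, like the fp32_or_int162 union in the kernel.
import torch

master = torch.randn(4096, dtype=torch.float32)
halves = master.view(torch.int16).view(-1, 2)        # [:, 0] = low, [:, 1] = high
lo, hi = halves[:, 0].clone(), halves[:, 1].clone()

bf16_bits = torch.where(lo < 0, hi + 1, hi)          # stored BF16 param (raw bits)
remainder = lo                                       # stored int16 remainder

hi_back = torch.where(remainder < 0, bf16_bits - 1, bf16_bits)
rebuilt = torch.stack([remainder, hi_back], dim=1).contiguous().view(torch.float32).reshape(-1)

assert torch.equal(rebuilt, master)                  # exact FP32 round trip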
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -238,6 +238,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_param_remainder", &multi_tensor_adam_param_remainder_cuda,
"Compute and apply gradient update to parameters for Adam optimizer"
"where the master parameters only store the remainder bits",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
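A quick, hedged way to confirm the entry point registered above is exposed once the extension is built (the module name transformer_engine_torch is an assumption):

# Checks that the new binding is importable and prints its docstring.
import transformer_engine_torch as tex

assert hasattr(tex, "multi_tensor_adam_param_remainder")
print(tex.multi_tensor_adam_param_remainder.__doc__)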