From eeb004408d913723fcbaab92b0d6a8758c24c199 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 16 May 2024 09:04:28 -0700 Subject: [PATCH 001/244] Calibration fix Signed-off-by: Pawel Gadzinski --- qa/L0_pytorch_unittest/test.sh | 1 + tests/pytorch/test_torch_save_load.py | 37 +++++++++++++++++++++-- transformer_engine/pytorch/module/base.py | 19 ++++++++++++ 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 2c14664dce..2aa58e6018 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -17,3 +17,4 @@ NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_a pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py +pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py \ No newline at end of file diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index 85ec7685b3..211030fe6d 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -65,6 +65,9 @@ def __init__(self, precision, use_bias): self.inp_type = tex.DType.kFloat8E4M3 self.weights_type = tex.DType.kFloat8E4M3 self.outp_type = precision + + def get_fp8_weights_scratchpad(self, is_first_microbatch): + raise RuntimeError("Method get_fp8_weights_scratchpad is dummy and should not be invoked.") def forward(self, inp, weight): inp_fp8 = cast_to_fp8( @@ -145,14 +148,11 @@ def test_fp8_model_checkpoint( params_dtype=dtype, device=device, ) - # Keep track of model output x = torch.randn(dims, dtype=dtype, device=device) with te.fp8_autocast(): y_ref = model(x.detach().clone()).detach().clone() - # Keep track of weights and FP8 scaling factors - weight_ref = model.weight.float().detach().clone() fp8_meta_ref = { "scaling_fwd": {}, "scaling_bwd": {} } with te.fp8_autocast(), torch.no_grad(): fp8_meta_fwd = model.fp8_meta["scaling_fwd"] @@ -168,6 +168,18 @@ def test_fp8_model_checkpoint( fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"]) fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"]) del fp8_meta_fwd, fp8_meta_bwd + + # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] + # This line copies the fp8 scale_inv from the model metadata to the weight fp8 tensor. + # The sole purpose of the following lines is to set the scale_inv of the weight tensor, which is the simplest method. + # It is essential for these values to be equal, so setting scale_inv only in the model metadata is insufficient. + model.weight.data.copy_(model.weight.float().cuda()) + # After copying, the tensor computes the meta scale_inv based on the amax history; we then reset these values. + model.fp8_meta["scaling_fwd"].scale = fp8_meta_fwd_ref["scale"] + model.fp8_meta["scaling_fwd"].scale_inv = fp8_meta_fwd_ref["scale_inv"] + + # Keep track of weights and FP8 scaling factors + weight_ref = model.weight.float().detach().clone() # Save checkpoint byte_stream = io.BytesIO() @@ -214,6 +226,18 @@ def test_fp8_model_checkpoint( with pytest.raises(AssertionError): torch.testing.assert_close(y, y_ref, **tols) + + # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] + # When save_fp8_model=True, we load a model with weights in high precision, + # which does not include _scale_inv, + # but has the fp8 scaling factor in the meta data. 
This scenario can occur + # when using te.fp8_autocast(enabled=False, calibrating=True). + # + # In such cases, the default behavior of load_state_dict is incorrect - it loads tensors first, + # followed by the fp8 metadata. This results in an incorrect _scale_inv for the tensor. This behavior + # is corrected by overriding the _load_state_dict method from PyTorch in TransformerEngineBaseModule, + # to load the fp8 metadata before loading tensors. + # # Load checkpoint model.load_state_dict(torch.load(io.BytesIO(model_bytes))) del model_bytes @@ -232,3 +256,10 @@ def test_fp8_model_checkpoint( with te.fp8_autocast(): y = model(x.detach().clone()) torch.testing.assert_close(y, y_ref, **tols) + + if load_fp8_model: + # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] + # We need to ensure that the tensor's scale_inv parameter matches its meta data. + # This is crucial to avoid confusion about which value is correct. + meta_index = model.weight._fp8_meta_index + torch.testing.assert_close(model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item()) \ No newline at end of file diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0803b474f6..7cfcf4b6d5 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -858,3 +858,22 @@ def get_fp8_weights_scratchpad( is_first_microbatch: Union[bool, None], ) -> List[torch.Tensor]: """Needs override.""" + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """ + This function loads tensors and extra state including fp8 metadata. + This metadata is essential for copying fp8 tensors, as the copy_ function + uses the scale_inv parameter from fp8_meta to set the correct scaling factor + for the new tensor. + Hence, this extra state must be loaded before the tensor copying process, + not after, as is typically done in _load_from_state_dict. + Tensors are copied into fp8 tensors only when self.primary_weights_in_fp8=True, + otherwise, this behavior is not required. 
+ """ + if self.primary_weights_in_fp8: + extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) \ No newline at end of file From 8605435dc51ee92f2c8c787455af641100a71f50 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 16 May 2024 10:05:39 -0700 Subject: [PATCH 002/244] Lint fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 7cfcf4b6d5..31011be897 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -876,4 +876,4 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if extra_state_key in state_dict: self.set_extra_state(state_dict[extra_state_key]) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) \ No newline at end of file + missing_keys, unexpected_keys, error_msgs) From 953d2a9aea09bdbc121675db178be9b49972df24 Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Tue, 2 Apr 2024 20:18:44 -0700 Subject: [PATCH 003/244] Do not store input activations when not computing weight gradients (#739) * Do not store input activations when not computing weight gradients Signed-off-by: Sangkug Lym * fix userbuffer tp comm overlap case Signed-off-by: Sangkug Lym --------- Signed-off-by: Sangkug Lym Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/layernorm_linear.py | 4 ++-- transformer_engine/pytorch/module/layernorm_mlp.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 551c070eb9..18777cc9e3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -292,7 +292,7 @@ def forward( weight, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, weight_t_fp8, - ln_out, + ln_out if weight.requires_grad else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, ) @@ -369,7 +369,7 @@ def backward( if ctx.ub_bulk_dgrad: tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: + if tp_world_size == 1 or not weight.requires_grad: ctx.ub_bulk_dgrad = False if ctx.ub_bulk_dgrad: dim_size = list(ln_out.size()) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 979c3068f5..91683ea0a8 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -498,9 +498,9 @@ def forward( ln_weight, mu, rsigma, - ln_out, + ln_out if fc1_weight.requires_grad else None, fc1_out, - gelu_out, + gelu_out if fc2_weight.requires_grad else None, fc1_weight, fc1_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None, fc1_weight_t_fp8, @@ -600,7 +600,7 @@ def backward( if ctx.ub_bulk_dgrad: tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: + if tp_world_size == 1 or not fc1_weight.requires_grad: ctx.ub_bulk_dgrad = False if ctx.ub_bulk_dgrad: dim_size = list(ln_out.size()) 
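The patch above stores `ln_out` (and `gelu_out` in the MLP case) for backward only when the corresponding weight actually requires a gradient, and it likewise skips the bulk-DGRAD userbuffer overlap when there is no weight-gradient GEMM left to overlap with. The snippet below is a minimal, self-contained sketch of that "save only what backward will use" pattern with a plain matmul; it is not Transformer Engine code, and the names (LinearSkipWgradSave, x, w) are invented purely for illustration.

import torch

class LinearSkipWgradSave(torch.autograd.Function):
    """Toy linear op that skips stashing the activation when no wgrad is needed."""

    @staticmethod
    def forward(ctx, inp, weight):
        # Save the activation only if the weight gradient will actually be computed,
        # mirroring the `ln_out if weight.requires_grad else None` change above.
        ctx.save_for_backward(weight, inp if weight.requires_grad else None)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        weight, inp = ctx.saved_tensors
        grad_inp = grad_out @ weight                               # dgrad is always needed
        grad_w = grad_out.t() @ inp if inp is not None else None   # wgrad only if input was saved
        return grad_inp, grad_w

x = torch.randn(4, 8, requires_grad=True)
w = torch.nn.Parameter(torch.randn(16, 8), requires_grad=False)  # frozen weight: no wgrad
LinearSkipWgradSave.apply(x, w).sum().backward()  # only dgrad flows; activation was not kept

Because the activation is never saved when the weight is frozen, it can be freed as soon as forward returns, which is where the memory saving targeted by this patch comes from.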
From 95a5c22ee5a184954611668cf568b0317c5cf4c5 Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Tue, 2 Apr 2024 22:30:47 -0700 Subject: [PATCH 004/244] Atomic gemm for TP-AR and TP-RS overlap with P2P exchanges (#732) * Atomic gemm for TP-AR and TP-RS overlap with P2P exchanges Signed-off-by: Sangkug Lym * FP8 reduction for atomic TP-RS with p2p exchange Signed-off-by: Sangkug Lym * Fix Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Sangkug Lym Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- .../pytorch/cpp_extensions/gemm.py | 5 +- .../pytorch/csrc/comm_gemm_overlap.h | 168 ++++++++---------- .../pytorch/csrc/userbuffers/userbuffers.cu | 20 +++ .../pytorch/csrc/userbuffers/userbuffers.h | 1 + transformer_engine/pytorch/module/base.py | 36 +++- 5 files changed, 129 insertions(+), 101 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index c270fef652..46ce244ce6 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -139,7 +139,10 @@ def fp8_gemm( extra_output_tensor is not None ), 'ATOMIC_GEMM_RS_P2P requires extra output tensor' args = tuple(args + (extra_output_tensor,)) - _ = fn(*args) + if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P: + out = fn(*args) + else: + _ = fn(*args) return out, gelu_input diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 817a3ef366..4e3daf7512 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -623,26 +623,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { _ubuf_scale_inv_initialized = false; _atomic_gemm = atomic_gemm; + _self_chunk_id = _tp_id; if (_atomic_gemm) { auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); counter = torch::zeros({tp_size * 2}, counter_options); counter.index_put_({Slice(None, tp_size)}, 1); - _self_chunk_id = _tp_id; if (!is_reduce_scatter) { - const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC"); + const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); if (rank == 0 && env_p != nullptr) { if (env_p[0] == '1') { - printf("!!userbuffers_sendrecv_atomic\n"); - } else if (env_p[0] == '2') { - printf("!!userbuffers_sendrecv_multiatomic\n"); - } else if (env_p[0] == '3') { - printf("!!userbuffers_sendrecv_multiatomic_shuffle\n"); - _self_chunk_id = 0; - } else { - printf("!!userbuffers_sendrecv\n"); + printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); } } + _self_chunk_id = 0; counter.index_put_({_self_chunk_id}, 0); } } @@ -675,13 +669,17 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Get GEMM dimensions between TN and NN input layouts const int m = (transa) ? A.size(0) : A.size(1); const int k = (transa) ? 
A.size(1) : A.size(0); - const int n_chunk = _ubufs[0].size(0); + const int n = _ubuf.size(0); + const int n_chunk = n / _tp_size; // Get communication and GEMM output chunk sizes const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + // Create an GEMM output buffer with N+1 chunks in a contiguous memory + torch::Tensor D_buffer = torch::empty({n_chunk * (_tp_size + 1), m}, D.options()); + D = torch::from_blob(D_buffer.data_ptr(), {D.size(0), D.size(1)}, D.options()); + // Get output and workspace data pointers - char *output_ptr = reinterpret_cast(D.data_ptr()); char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); int *counter_ptr = reinterpret_cast(counter.data_ptr()); int workspace_size_chunk = workspaceSize / _stream_compute.size(); @@ -692,100 +690,75 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); - CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - assert(pre_gelu_out.numel() == 0); + // Catch up the default torch stream + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); - torch::Tensor output_chunk = torch::from_blob(output_ptr, {_ubuf.size(0), m}, D.options()); torch::Tensor workspace_chunk = torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - for (int i = 0; i < _tp_size; i++) { + + for (int i = 0; i < _tp_size - 1; i++) { // Set the userbuffer id. Buffer under send is the input for the current // GEMM chunk The initial input chunk is stored _ubuf[rank]. 
This is to // have the AG output in all ranks to be contiguous after the ring // exchanges - int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; - int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; + int send_chunk_id = i; + int recv_chunk_id = i + 1; int send_offset = comm_bytes * send_chunk_id; int recv_offset = comm_bytes * recv_chunk_id; - if (i < _tp_size - 1) { - const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC"); - if (env_p != nullptr && env_p[0] == '1') { - userbuffers_sendrecv_atomic(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes, - _ub_comm, _next_rank, _prev_rank, &counter_ptr[recv_chunk_id], - (cudaStream_t)_stream_recv); - } else if (env_p != nullptr && env_p[0] == '2') { - if (i == 0) { - userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, - _ub_comm, _next_rank, _prev_rank, _tp_size, - counter_ptr, false, (cudaStream_t)_stream_recv); - } - } else if (env_p != nullptr && env_p[0] == '3') { - if (i == 0) { - userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, - _ub_comm, _next_rank, _prev_rank, _tp_size, - counter_ptr, true, (cudaStream_t)_stream_recv); - } - } else { - // P2P communication - // userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, - // comm_bytes, _ub_comm, - // _next_rank, (cudaStream_t)_stream_send); - // userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, - // comm_bytes, _ub_comm, - // _prev_rank, (cudaStream_t)_stream_recv); - // CHECK_CUDA(cudaEventRecord(_stop_recv, - // (cudaStream_t)_stream_recv)); - // CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, - // _stop_recv, 0)); - userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes, _ub_comm, - _next_rank, _prev_rank, (cudaStream_t)_stream_recv); - producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv); - } + const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); + if (env_p != nullptr && env_p[0] == '1') { if (i == 0) { - at::cuda::setCurrentCUDAStream(_stream_compute[0]); - te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms, 0, _tp_size, false, counter); + userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, + _ub_comm, _next_rank, _prev_rank, _tp_size, + counter_ptr, true, (cudaStream_t)_stream_recv); } } else { - // GEMM - // userbuffers_send_multiatomic(_ub_reg, 0, _ub_reg, 0, comm_bytes, - // _ub_comm, - // _next_rank, _tp_size, comm_bytes, comm_bytes, - // (cudaStream_t)_stream_send); - // userbuffers_recv_multiatomic(_ub_reg, 0, _ub_reg, 0, comm_bytes, - // _ub_comm, - // _prev_rank, _tp_size, counter_ptr, - // (cudaStream_t)_stream_recv); - if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); - } + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, + _ub_comm, _next_rank, (cudaStream_t) _stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, 
recv_offset, comm_bytes, + _ub_comm, _prev_rank, (cudaStream_t) _stream_recv); + producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv); } - } - for (int i = 0; i < _tp_size; i++) { - if (i != _self_chunk_id) { - consumer(counter_ptr, i, (cudaStream_t)_stream_compute[0]); + if (i == 0) { + te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb, + D, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms, 0, _tp_size, false, counter); } } - at::cuda::setCurrentCUDAStream(stream_main); - CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0])); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - return D; + // Store the input activation for backprop + if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); + assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); + CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_self_chunk_id].data_ptr(), + _ubufs[_self_chunk_id].numel() * + _ubufs[_self_chunk_id].element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); + CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); + } + + // Reset atomic counters + consumer_batch(counter_ptr, 1, _tp_size, (cudaStream_t)stream_main); + + // Copy the first GEMM output chunk to the end chunk position of D_buffer + char *src_ptr = reinterpret_cast(D_buffer.data_ptr()); + CHECK_CUDA(cudaMemcpyAsync( + src_ptr + (D.numel() * D.element_size()), + src_ptr, + n_chunk * m * D.element_size(), + cudaMemcpyDeviceToDevice, + (cudaStream_t) stream_main)); + // Return the last N rows of D_buffer + torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n); + return D_return; } // atomic_gemm_overlap_ag /* @@ -1018,6 +991,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); // Atomic GEMM + // Process GEMM chunks in the order that AG+GEMM places the output chunks. 
torch::Tensor workspace_chunk = torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, @@ -1031,23 +1005,31 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { int recv_chunk_id = send_chunk_id + _tp_size; int send_offset = comm_bytes * send_chunk_id; int recv_offset = comm_bytes * recv_chunk_id; - int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; - int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp; consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv); userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, - _ub_comm, send_rank, (cudaStream_t) _stream_recv); + _ub_comm, send_rank, (cudaStream_t) _stream_recv); userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, - _ub_comm, recv_rank, (cudaStream_t) _stream_recv); + _ub_comm, recv_rank, (cudaStream_t) _stream_recv); } CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t) _stream_recv)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) stream_main, _stop_recv, 0)); // Reduce GEMM output chunks char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); - torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); - torch::sum_out(rs_output, reduce_buf, 0); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, + _tp_size, _ubufs[0].numel(), (cudaStream_t) stream_main); + } else { + torch::Tensor reduce_buf = torch::from_blob( + reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); + torch::sum_out(rs_output, reduce_buf, 0); + } } /* @@ -1174,7 +1156,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); if (_comm_type == COMM_TYPE::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? 
_ubuf.size(0) : _ubuf.size(0) / _tp_size; int output_c_dim1 = _ubuf.size(1); return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index b572c5b273..ab03039b3d 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -3671,6 +3671,20 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) { } } +// consumer_batch +static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i, int num_chunks) { + // Wait for producer to change the val to 0, which signal producer ready + if (blockIdx.x == 0 && threadIdx.x == 0) { + int old_val; + for (int i = first_chunk_i; i < num_chunks; i++) { + while (0 != (old_val = atomicCAS((unsigned int *)atomic_ptr + i, 0, 0))) { + } + ((unsigned int *)atomic_ptr)[i] = 1; + asm volatile("fence.sc.gpu;\n"); + } + } +} + void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { dim3 block(1); dim3 grid(1); @@ -3683,6 +3697,12 @@ void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { consumer_kernel<<>>(atomic_ptr, chunk_i); } +void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream) { + dim3 block(1); + dim3 grid(1); + consumer_batch_kernel<<>>(atomic_ptr, first_chunk_i, num_chunks); +} + template __global__ void __launch_bounds__(MAX_THREADS / 4) reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale, diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h index 407f9479c3..1306636881 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h @@ -151,6 +151,7 @@ typedef struct communicator communicator; void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream); void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream); +void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream); int create_communicator(communicator **comm); /* creates communicator, allocates all internal buffers if necessary */ diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 59e5949e06..31e305cc15 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -45,6 +45,7 @@ _ub_communicators = None _NUM_MAX_UB_STREAMS = 3 _amax_reduce_handle_bwd = None +layers_atomic_ring_exchange = [] def get_cublas_workspace_size_bytes() -> None: @@ -138,6 +139,12 @@ def initialize_ub( } layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] + # AG-RS overlap pairs of layers forming a tensor-parallel block + ag_rs_pairs = {"qkv_fprop":"proj_fprop", "fc1_fprop":"fc2_fprop"} + rs_ag_pairs = {v : k for k, v in ag_rs_pairs.items()} + global layers_atomic_ring_exchange + layers_atomic_ring_exchange = [] + def get_method(name): for method, names in methods.items(): if name in names: @@ -160,20 +167,35 @@ def add_ub( "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." ) assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." - if is_reduce_scatter and method == "ring_exchange": - raise ValueError( - "Atomic GEMM is not supported for ReduceScatter with `ring_exchange` method." 
- ) if method == 'bulk': warnings.warn( - "Atoimic GEMM not is supported for a bulk overlap." + f"At {name}, atoimic GEMM not is supported for a bulk overlap." "Defaulting to `atomic_gemm=False`." ) atomic_gemm = 0 if not is_reduce_scatter and method == 'pipeline': raise ValueError( - "`pipeline` overlap method is not supported for AllGather." + f"At {name}, `pipeline` overlap method is not supported for AllGather." + ) + # Check if both AG and RS overlaps use `atomic GEMM`` + `p2p ring-exchange`. + # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality. + global layers_atomic_ring_exchange + if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs: + layers_atomic_ring_exchange += [name, ag_rs_pairs[name]] + if name in rs_ag_pairs: + assert_message = ( + f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk " + "outputs, and RS-GEMM overlap un-suffle them. When one of the GEMM-AG and " + "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses " + "`atomic gemm` and `ring_exhcnage`, its pair must use the same overlap config " + "for functionality." ) + if name in layers_atomic_ring_exchange: + assert atomic_gemm and method == "ring_exchange", assert_message + else: + if atomic_gemm and method == "ring_exchange": + assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message + sample_buffer = torch.empty( shape, dtype=torch.uint8 if (use_fp8 and name in fp8_buf) else dtype, @@ -213,7 +235,7 @@ def add_ub( method = ub_cfg["method"] if "method" in ub_cfg else get_method(name) num_sm = ub_cfg["num_sm"] if "num_sm" in ub_cfg else 16 cga_size = ub_cfg["cga_size"] if "cga_size" in ub_cfg else 2 - num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 0 + num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 4 set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0 aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0 atomic_gemm = ub_cfg["atomic_gemm"] if "atomic_gemm" in ub_cfg else 0 From 797eb250a3ed920d6eb531d3d6db3e1fcc5f2ca2 Mon Sep 17 00:00:00 2001 From: Santosh Bhavani Date: Wed, 3 Apr 2024 00:48:37 -0500 Subject: [PATCH 005/244] Update README.rst (#733) * Update README.rst 1. Updated latest news with databricks blog 2. Fixed formatting issues 3. 
Added GTC 2024 video Signed-off-by: Santosh Bhavani * Update README.rst added back overview marker for docs generation Signed-off-by: Santosh Bhavani * Added MPT-13B convergence result Signed-off-by: Santosh Bhavani * Added Levanter/JAX to integrations section of README Signed-off-by: Santosh Bhavani * Update README.rst Signed-off-by: Kirthi Shankar Sivamani * Update README.rst Signed-off-by: Kirthi Shankar Sivamani * Update README.rst Signed-off-by: Kirthi Shankar Sivamani * Update README.rst Signed-off-by: Kirthi Shankar Sivamani * Update README.rst Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Santosh Bhavani Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- README.rst | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/README.rst b/README.rst index de3a331d10..190f8fc57c 100644 --- a/README.rst +++ b/README.rst @@ -11,8 +11,9 @@ Transformer Engine `Quickstart <#examples>`_ | `Installation <#installation>`_ | `User Guide `_ | `Examples `_ | `FP8 Convergence <#fp8-convergence>`_ | `Integrations <#integrations>`_ | `Release notes `_ Latest News -================== +=========== +* [03/2024] `Turbocharged Training: Optimizing the Databricks Mosaic AI stack with FP8 `_ * [03/2024] `FP8 Training Support in SageMaker Model Parallelism Library `_ * [12/2023] `New NVIDIA NeMo Framework Features and NVIDIA H200 `_ @@ -28,7 +29,7 @@ Latest News * [04/2023] `Benchmarking Large Language Models on NVIDIA H100 GPUs with CoreWeave (Part 1) `_ What is Transformer Engine? -================== +=========================== .. overview-begin-marker-do-not-remove Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including @@ -55,7 +56,7 @@ Modules provided by TE internally maintain scaling factors and other values need simplifying mixed precision training for users. Highlights ----------- +========== * Easy-to-use modules for building Transformer layers with FP8 support * Optimizations (e.g. fused kernels) for Transformer models @@ -63,7 +64,7 @@ Highlights * Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later Examples ----------- +======== PyTorch ^^^^^^^ @@ -142,7 +143,7 @@ Flax .. overview-end-marker-do-not-remove Installation ----------- +============ .. installation Pre-requisites @@ -188,7 +189,7 @@ It is a known issue that FlashAttention-2 compilation is resource-intensive and Note that NGC PyTorch 23.08+ containers include FlashAttention-2. FP8 Convergence -================== +=============== FP8 has been tested extensively across different model architectures and configurations and we found **no significant difference** between FP8 and BF16 training loss curves. FP8 has also been validated for accuracy on downstream LLM tasks (e.g. LAMBADA and WikiText). Below are examples of models tested for convergence across different frameworks. 
@@ -207,6 +208,8 @@ FP8 has been tested extensively across different model architectures and configu +------------+------------------+---------------------------------------------------------------------------------------------------------+ | T5-11B | JAX/T5x | Available on request | +------------+------------------+---------------------------------------------------------------------------------------------------------+ +| MPT-13B | Mosaic Composer | https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8 | ++------------+------------------+---------------------------------------------------------------------------------------------------------+ | GPT-22B | NeMo Framework | Available on request | +------------+------------------+---------------------------------------------------------------------------------------------------------+ | LLama2-70B | Alibaba Pai | https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ | @@ -215,7 +218,7 @@ FP8 has been tested extensively across different model architectures and configu +------------+------------------+---------------------------------------------------------------------------------------------------------+ Integrations -================== +============ Transformer Engine has been integrated with popular LLM frameworks such as: @@ -227,19 +230,20 @@ Transformer Engine has been integrated with popular LLM frameworks such as: * `NVIDIA Megatron-LM `_ * `NVIDIA NeMo Framework `_ * `Amazon SageMaker Model Parallel Library `_ +* `Levanter `_ * `Colossal-AI `_ - Coming soon! * `PeriFlow `_ - Coming soon! * `GPT-NeoX `_ - Coming soon! Contributing -================== +============ We welcome contributions to Transformer Engine! To contribute to Transformer Engine and make pull requests, follow the guidelines outlined in the ``_ guide. Papers -================== +====== * `Attention original paper `_ * `Megatron-LM tensor parallel `_ @@ -247,10 +251,11 @@ Papers * `FP8 Formats for Deep Learning `_ Videos -================== +====== -* `FP8 Training with Transformer Engine `_ -* `FP8 for Deep Learning `_ +* `What's New in Transformer Engine and FP8 Training | GTC 2024 `_ +* `FP8 Training with Transformer Engine | GTC 2023 `_ +* `FP8 for Deep Learning | GTC 2023 `_ * `Inside the Hopper Architecture `_ .. 
|License| image:: https://img.shields.io/badge/License-Apache%202.0-blue.svg From 55d11779e6e5ac947763abebb035871f3500d060 Mon Sep 17 00:00:00 2001 From: "Pavel Shamis (Pasha)" Date: Wed, 3 Apr 2024 21:35:07 -0500 Subject: [PATCH 006/244] Fixing potential integer overflow on sequence counter (#729) * Fixing potential integer overflow on sequence counter Current implementation may potential cause hangs or data corruption Signed-off-by: Pasha (Pavel) Shamis * Fixing typo in comments Addressing reviewers comments Signed-off-by: Pasha (Pavel) Shamis --------- Signed-off-by: Pasha (Pavel) Shamis Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- .../pytorch/csrc/userbuffers/userbuffers.cu | 66 ++++++++++--------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index ab03039b3d..bb62b55262 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -51,6 +51,10 @@ asm volatile("fence.sc.gpu;\n"); \ } +// Return true if producer > consumer, otherwise false while preventing integer overflow +// If we expect that producer will be 2B+ messages behind consumer +#define CHECK_IDS(producer, consumer) (((unsigned)(producer) - (unsigned)(consumer)) & (~INT_MAX)) + template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rw(const int op, const int flagoffset, const int firstrank, @@ -74,7 +78,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -128,7 +132,7 @@ __global__ void __launch_bounds__(MAX_THREADS) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > 2ull * TIMEOUT) { printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -162,7 +166,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -211,7 +215,7 @@ __global__ void __launch_bounds__(MAX_THREADS) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > 2ull * TIMEOUT) { printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -273,7 +277,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", 
blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -348,7 +352,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -422,7 +426,7 @@ __global__ void __launch_bounds__(MAX_THREADS) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -490,7 +494,7 @@ __global__ void __launch_bounds__(MAX_THREADS) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > 2ull * TIMEOUT) { printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -525,7 +529,7 @@ __global__ void __launch_bounds__(MAX_THREADS) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -610,7 +614,7 @@ __global__ void __launch_bounds__(MAX_THREADS) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&(myptr[targetgpu]); clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -740,7 +744,7 @@ __global__ void __launch_bounds__(MAX_THREADS) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > 2ull * TIMEOUT) { printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -800,7 +804,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -888,7 +892,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -975,7 +979,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); - while (*flag < reduce_id) { + while 
(CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -1072,7 +1076,7 @@ userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic_fp8( volatile int* flag = (volatile int*)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu+handleridx]); clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64()-s > TIMEOUT) { printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -1171,7 +1175,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -1270,7 +1274,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -1389,7 +1393,7 @@ __global__ void __launch_bounds__(MAX_THREADS) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > 2ull * TIMEOUT) { printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -1486,7 +1490,7 @@ __global__ void __launch_bounds__(MAX_THREADS) flagptr[physgpu] = reduce_id; volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); - while (*flag < reduce_id) { + while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > 2ull * TIMEOUT) { printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, *flag); @@ -1517,7 +1521,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } volatile int *flag = (volatile int *)&((reinterpret_cast( commbuff[myrank + firstrank]))[flagoffset + threadIdx.x + firstrank]); - while (*flag < basecounter) { + while (CHECK_IDS(*flag, basecounter)) { } } __syncthreads(); @@ -1635,7 +1639,7 @@ __global__ void __launch_bounds__(MAX_THREADS) const int end_aligned = start_elem + aligned_elem; if (mythreadIdx == 0) { - while (*flag < gathercounter) { + while (CHECK_IDS(*flag, gathercounter)) { } gathercounter++; } @@ -1694,7 +1698,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter; } volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank]; - while (*flag < basecounter) { + while (CHECK_IDS(*flag, basecounter)) { } } __syncthreads(); @@ -1864,7 +1868,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ const int end_aligned = start_elem + aligned_elem; if (mythreadIdx == 0) { - while (*flag < gathercounter) { + while (CHECK_IDS(*flag, gathercounter)) { } gathercounter++; } @@ -1908,7 +1912,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ flagptr[flagoffset + 
gpustep * myrank + firstrank] = basecounter; } volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank]; - while (*flag < basecounter) { + while (CHECK_IDS(*flag, basecounter)) { } } __syncthreads(); @@ -2114,7 +2118,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ const int end_aligned = start_elem + aligned_elem; if (mythreadIdx == 0) { - while (*flag < gathercounter) { + while (CHECK_IDS(*flag, gathercounter)) { } gathercounter++; } @@ -3013,7 +3017,7 @@ __global__ void __launch_bounds__(MAX_THREADS) const int signal_id = (*recv_id) + 1; volatile int *flag = (volatile int *)recv_flagptr; clock_t s = clock64(); - while (*flag < signal_id) { + while (CHECK_IDS(*flag, signal_id)) { if (clock64() - s > TIMEOUT) { printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, *flag); @@ -3073,7 +3077,7 @@ __global__ void __launch_bounds__(MAX_THREADS) const int signal_id = (*recv_id) + 1; volatile int *flag = (volatile int *)flagptr; clock_t s = clock64(); - while (*flag < signal_id) { + while (CHECK_IDS(*flag, signal_id)) { if (clock64() - s > TIMEOUT) { printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, *flag); @@ -3142,7 +3146,7 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *f if (*flag >= signal_id) return; clock_t s = clock64(); - while (atomicAdd_system(flagptr, 0) < signal_id) { + while (CHECK_IDS(*flag, signal_id)) { if (clock64() - s > TIMEOUT) { printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, *flag); return; @@ -3193,7 +3197,7 @@ __global__ void __launch_bounds__(MAX_THREADS) if (*flag >= signal_id) return; clock_t s = clock64(); - while (*flag < signal_id) { + while (CHECK_IDS(*flag, signal_id)) { if (clock64() - s > TIMEOUT) { printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, *flag); @@ -3245,7 +3249,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)recv_flagptr; // if(*flag>=signal_id) return; clock_t s = clock64(); - while (*flag < signal_id) { + while (CHECK_IDS(*flag, signal_id)) { if (clock64() - s > TIMEOUT) { printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, *flag); /*return;*/ @@ -3312,7 +3316,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)recv_flagptr; // if(*flag>=signal_id) return; clock_t s = clock64(); - while (*flag < signal_id) { + while (CHECK_IDS(*flag, signal_id)) { if (clock64() - s > TIMEOUT) { printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, *flag); /*return;*/ From 6338367c40b56e64032962c1dd0cca8445a8437a Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 4 Apr 2024 09:59:14 -0700 Subject: [PATCH 007/244] [PyTorch] Fix backward compatibility for checkpoint API (#748) * Args can be None Signed-off-by: Kirthi Shankar Sivamani * Fix other arg types Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/distributed.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 6a2a801efd..239cecf39b 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -516,12 +516,6 @@ def checkpoint( kwargs : dict dictionary of string 
keys for keyword arguments to :attr:`function`. """ - only_tensor_args = True - for arg in args: - if not isinstance(arg, torch.Tensor): - only_tensor_args = False - break - # Pop out te.distributed.checkpoint() arguments global _USE_REENTRANT_ACTIVATION_RECOMPUTE _USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True) @@ -530,23 +524,14 @@ def checkpoint( get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None) # Ensure backward compatibility. - if not only_tensor_args: + if (len(args) > 3 and isinstance(args[0], bool) and callable(args[1]) + and isinstance(args[2], None | dist_group_type)): warnings.warn( "Passing non-tensor non-keyword arguments is deprecated and support will be removed in " "future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and " "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.", DeprecationWarning, stacklevel=2, ) - assert len(args) > 3, "Incorrect number of arguments for deprecated `checkpoint` API." - assert ( - isinstance(args[0], bool) and callable(args[1]) - and isinstance(args[2], None | dist_group_type) - ), "Incorrect arguments for deprecated `checkpoint` API." - for arg in args[3:]: - assert ( - isinstance(arg, None | torch.Tensor) - ), f"Expected tensor argument, found {type(arg)}." - distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3] # pylint: disable=unbalanced-tuple-unpacking args = args[3:] From 48d54789440fdb51a7b9d6ec7fa63287a4e1a53d Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 4 Apr 2024 11:54:01 -0700 Subject: [PATCH 008/244] Compile tuned RMSNorm kernels for hidden size 8192 (#747) Signed-off-by: Tim Moon Signed-off-by: Pawel Gadzinski --- .../common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 4 ++++ .../common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index 982adc27d4..552cd1b4bc 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -201,6 +201,10 @@ REGISTER_BWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + // Create rmsnorm general launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... 
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL diff --git a/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index ee3595f934..bce89fafb1 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -170,6 +170,13 @@ REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + // Create rmsnorm general launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG From d8c19720fd21031dd3caf617761237b7b26670d0 Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Thu, 4 Apr 2024 14:03:11 -0700 Subject: [PATCH 009/244] userbuffer: support fp8 buffer for individual overlap instance (#750) * userbuffer fp8 reduction support for individual overlap Signed-off-by: Sangkug Lym * cleanup dict ub_cfg dict value load Signed-off-by: Sangkug Lym * cleanup Signed-off-by: Sangkug Lym * Remove unnecessary fence from producer From @erhoo82 Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Sangkug Lym Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- .../pytorch/csrc/userbuffers/userbuffers.cu | 1 - transformer_engine/pytorch/module/base.py | 27 ++++++++++--------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index bb62b55262..0cf1a091b9 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -48,7 +48,6 @@ #define ATOMIC_PRODUCER(chunk) \ if (counters) { \ ((unsigned int *)counters)[chunk] = 0; \ - asm volatile("fence.sc.gpu;\n"); \ } // Return true if producer > consumer, otherwise false while preventing integer overflow diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 31e305cc15..9f99fbb553 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -126,18 +126,16 @@ def initialize_ub( _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS) # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe - fp8_buf = [ + layers_all_gather_overlap = [ "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad" ] - if bool(int(os.getenv("NVTE_UB_FP8_RS", "0"))): - fp8_buf += ["proj_fprop", "fc2_fprop"] + layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] # Default overlap methods for layers methods = { "ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], "pipeline":["proj_fprop", "fc2_fprop"], "bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], } - layers_reduce_scatter_overlap = ["proj_fprop", 
"fc2_fprop", "qkv_wgrad", "fc1_wgrad"] # AG-RS overlap pairs of layers forming a tensor-parallel block ag_rs_pairs = {"qkv_fprop":"proj_fprop", "fc1_fprop":"fc2_fprop"} @@ -161,6 +159,7 @@ def add_ub( aggregate: int = 0, atomic_gemm: int = 0, is_reduce_scatter: int = 0, + fp8_buf: bool = False, ) -> None: if atomic_gemm: warnings.warn( @@ -198,7 +197,7 @@ def add_ub( sample_buffer = torch.empty( shape, - dtype=torch.uint8 if (use_fp8 and name in fp8_buf) else dtype, + dtype=torch.uint8 if (use_fp8 and fp8_buf) else dtype, device='cuda') if method == 'ring_exchange': ub_obj = tex.UbufP2PCommOverlap( @@ -232,14 +231,17 @@ def add_ub( for name in (methods["ring_exchange"]+methods["pipeline"]+methods["bulk"]): if ub_cfgs is not None and name in ub_cfgs: ub_cfg = ub_cfgs[name] - method = ub_cfg["method"] if "method" in ub_cfg else get_method(name) - num_sm = ub_cfg["num_sm"] if "num_sm" in ub_cfg else 16 - cga_size = ub_cfg["cga_size"] if "cga_size" in ub_cfg else 2 - num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 4 - set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0 - aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0 - atomic_gemm = ub_cfg["atomic_gemm"] if "atomic_gemm" in ub_cfg else 0 + method = ub_cfg.get("method", get_method(name)) + num_sm = ub_cfg.get("num_sm", 16) + cga_size = ub_cfg.get("cga_size", 2) + num_splits = ub_cfg.get("num_splits", 4) + set_sm_margin = ub_cfg.get("set_sm_margin", 0) + aggregate = ub_cfg.get("aggregate", 0) + atomic_gemm = ub_cfg.get("atomic_gemm", 0) is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0 + # Support FP8 userbuffer when (1) AllGather and (2) FP8-GEMM output ReduceScatter + fp8_buf = ((name in layers_all_gather_overlap) or + (ub_cfg.get("fp8_buf", False) and name in methods["pipeline"])) add_ub( name, method, @@ -250,6 +252,7 @@ def add_ub( aggregate, atomic_gemm, is_reduce_scatter, + fp8_buf, ) else: method = get_method(name) From 3b5fe44a64e5b6a8cd7f928f31af57a53fe60c08 Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Fri, 5 Apr 2024 22:42:31 -0700 Subject: [PATCH 010/244] Enable DGRAD RS overlap (#754) * Enable DGRAD RS overlap Signed-off-by: Jaemin Choi * fix lint; apply suggestions Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Jaemin Choi Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Vasudevan Rengasamy Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 3 + transformer_engine/pytorch/module/base.py | 9 ++ .../pytorch/module/layernorm_linear.py | 60 ++++++++++-- .../pytorch/module/layernorm_mlp.py | 93 ++++++++++++++++--- transformer_engine/pytorch/transformer.py | 4 + 5 files changed, 148 insertions(+), 21 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f5e7753e6a..f03350eb4e 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3171,6 +3171,7 @@ def __init__( qkv_weight_interleaved: bool = True, ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, + ub_overlap_rs_dgrad: bool = False, ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, bias: bool = True, @@ -3259,6 +3260,7 @@ def __init__( zero_centered_gamma=zero_centered_gamma, ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_dgrad=ub_bulk_dgrad, + ub_overlap_rs_dgrad=ub_overlap_rs_dgrad, ub_overlap_ag=ub_overlap_ag, normalization=normalization, ub_name="qkv", @@ -3290,6 +3292,7 @@ def __init__( 
zero_centered_gamma=zero_centered_gamma, ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_dgrad=ub_bulk_dgrad, + ub_overlap_rs_dgrad=ub_overlap_rs_dgrad, ub_overlap_ag=ub_overlap_ag, normalization=normalization, ub_name="qkv", diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 9f99fbb553..6ef6d4eb3b 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -130,6 +130,7 @@ def initialize_ub( "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad" ] layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] + dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] # Default overlap methods for layers methods = { "ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], @@ -228,6 +229,14 @@ def add_ub( ) _ub_communicators[name] = ub_obj + if ub_cfgs is not None: + for name in dgrad_reduce_scatter_overlap: + if name in ub_cfgs and 'method' in ub_cfgs[name] and ub_cfgs[name]['method'] != 'bulk': + wgrad_name = name.replace('dgrad','wgrad') + assert wgrad_name not in ub_cfgs + layers_reduce_scatter_overlap.remove(wgrad_name) + layers_reduce_scatter_overlap.append(name) + for name in (methods["ring_exchange"]+methods["pipeline"]+methods["bulk"]): if ub_cfgs is not None and name in ub_cfgs: ub_cfg = ub_cfgs[name] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 18777cc9e3..985d587e54 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -86,6 +86,7 @@ def forward( primary_weights_in_fp8: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, + ub_overlap_rs_dgrad: bool, ub_overlap_ag: bool, ub_name: str, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: @@ -316,6 +317,7 @@ def forward( ctx.zero_centered_gamma = zero_centered_gamma ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad + ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_name = ub_name ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization @@ -367,6 +369,12 @@ def backward( update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy", ) + if ctx.ub_overlap_rs_dgrad: + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1: + ctx.ub_overlap_rs_dgrad = False if ctx.ub_bulk_dgrad: tp_world_size = get_distributed_world_size(ctx.tp_group) if tp_world_size == 1 or not weight.requires_grad: @@ -416,9 +424,36 @@ def backward( if ctx.ub_bulk_wgrad: # allocate dgrad output ub_obj_dgrad = get_ub(ctx.ub_name+"_wgrad") dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + elif ctx.ub_overlap_rs_dgrad: + ub_obj_dgrad = get_ub(ctx.ub_name+"_dgrad") + dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output else: dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device) + if ctx.ub_bulk_dgrad: + ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_obj = ub_obj_lnout + elif ctx.ub_overlap_rs_dgrad: + dim_size = list(grad_output.size()) + dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = weight.size(1) + rs_out = torch.empty( + dim_size, dtype=ctx.activation_dtype, device=grad_output.device) + if ub_obj_dgrad.is_p2p_overlap(): + if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): + ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + 
else: + if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_obj = ub_obj_dgrad + else: + ub_algo = None + ub_obj = None + if ctx.fp8: fp8_dtype_forward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=True @@ -428,7 +463,7 @@ def backward( ) out_index, meta_tensor, out_te_type, out_type = ( None, None, None, ctx.activation_dtype) - if ctx.ub_bulk_wgrad and ub_obj_dgrad.is_fp8_ubuf(): + if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf(): out_index = tex.FP8BwdTensors.GRAD_INPUT1 meta_tensor = ctx.fp8_meta["scaling_bwd"] out_te_type = fp8_dtype_backward @@ -449,8 +484,9 @@ def backward( get_workspace(), out=dgrad, use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, - ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, out_index=out_index, fp8_meta_tensor = meta_tensor, D_dtype = out_te_type, @@ -466,8 +502,9 @@ def backward( out=dgrad, layout="NN", grad=True, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, - ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, ) if ctx.ub_bulk_dgrad: ln_out_total = ub_obj_lnout.get_ubuf_output(1) @@ -476,7 +513,7 @@ def backward( if ctx.parallel_mode == "column" and ctx.sequence_parallel: if not ctx.ub_bulk_dgrad and handle is not None: handle.wait() - if not ctx.ub_bulk_wgrad: + if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad: if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) dgrad, handle = reduce_scatter_along_first_dim( @@ -569,7 +606,10 @@ def backward( handle.wait() # LayerNorm gradient - dgrad = dgrad.view(inputmat.shape) + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out.view(inputmat.shape) + else: + dgrad = dgrad.view(inputmat.shape) # Residual gradient if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: @@ -645,6 +685,7 @@ def backward( None, None, None, + None, ) @@ -758,6 +799,7 @@ def __init__( ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs_dgrad: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -778,7 +820,8 @@ def __init__( self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_overlap_ag = ub_overlap_ag - if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag]): + self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad + if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag, ub_overlap_rs_dgrad]): assert ub_name is not None, "Userbuffer name [string] is not set." 
self.ub_name = ub_name @@ -1110,6 +1153,7 @@ def forward( self.primary_weights_in_fp8, self.ub_bulk_wgrad, self.ub_bulk_dgrad, + self.ub_overlap_rs_dgrad, self.ub_overlap_ag, self.ub_name, ) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 91683ea0a8..ad66e01e07 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -117,6 +117,7 @@ def forward( primary_weights_in_fp8: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, + ub_overlap_rs_dgrad: bool, ub_overlap_rs: bool, ub_overlap_ag: bool, gemm_gelu_fusion: bool, @@ -533,6 +534,7 @@ def forward( ctx.zero_centered_gamma = zero_centered_gamma ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad + ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_overlap_ag = ub_overlap_ag ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization @@ -598,6 +600,12 @@ def backward( activation_func = _act_func(ctx.activation)[1] + if ctx.ub_overlap_rs_dgrad: + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + tp_world_size = get_distributed_world_size(ctx.tp_group) + if tp_world_size == 1: + ctx.ub_overlap_rs_dgrad = False if ctx.ub_bulk_dgrad: tp_world_size = get_distributed_world_size(ctx.tp_group) if tp_world_size == 1 or not fc1_weight.requires_grad: @@ -773,19 +781,49 @@ def backward( None, None, None, ctx.activation_dtype) fc1_dgrad_size = list(dgelu.size()) fc1_dgrad_size[1] = fc1_weight.size(1) + # Get/alloc fc1_dgrad if ctx.ub_bulk_wgrad: # allocate dgrad output ub_obj_dgrad = get_ub("fc1_wgrad") fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - if ub_obj_dgrad.is_fp8_ubuf(): - out_index = tex.FP8BwdTensors.GRAD_INPUT2 - meta_tensor = ctx.fp8_meta["scaling_bwd"] - out_te_type = fp8_dtype_backward - out_type = torch.uint8 - ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) + elif ctx.ub_overlap_rs_dgrad: + ub_obj_dgrad = get_ub("fc1_dgrad") + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output else: fc1_dgrad = torch.empty( fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device ) + + # FP8 RS + if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf(): + out_index = tex.FP8BwdTensors.GRAD_INPUT2 + meta_tensor = ctx.fp8_meta["scaling_bwd"] + out_te_type = fp8_dtype_backward + out_type = torch.uint8 + ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) + + # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap + if ctx.ub_bulk_dgrad: + ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_obj = ub_obj_lnout + elif ctx.ub_overlap_rs_dgrad: + dim_size = list(dgelu.size()) + dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = fc1_weight_t_fp8.size(0) + rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) + if ub_obj_dgrad.is_p2p_overlap(): + if ub_obj_dgrad.is_atomic_gemm(): + ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + if ub_obj_dgrad.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_obj = ub_obj_dgrad + else: + ub_algo = None + ub_obj = None # FC1 DGRAD: Unconditional _ = tex.fp8_gemm( fc1_weight_t_fp8._data, @@ -800,8 +838,9 @@ def backward( get_workspace(), out=fc1_dgrad, use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, - 
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None, + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, out_index=out_index, fp8_meta_tensor = meta_tensor, D_dtype = out_te_type, @@ -859,11 +898,31 @@ def backward( if ctx.ub_bulk_wgrad: # allocate dgrad output ub_obj_dgrad = get_ub("fc1_wgrad") fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + elif ctx.ub_overlap_rs_dgrad: + ub_obj_dgrad = get_ub("fc1_dgrad") + fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output else: fc1_dgrad = torch.empty( fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device ) + # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap + if ctx.ub_bulk_dgrad: + ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_obj = ub_obj_lnout + elif ctx.ub_overlap_rs_dgrad: + dim_size = list(dgelu.size()) + dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = fc1_weight.size(1) + rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) + if ub_obj_dgrad.is_p2p_overlap(): + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_obj = ub_obj_dgrad + else: + ub_algo = None + ub_obj = None # FC1 DGRAD: Unconditional _ = tex.gemm( fc1_weight, @@ -873,8 +932,9 @@ def backward( out=fc1_dgrad, layout="NN", grad=True, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, - ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None + ub_algo=ub_algo, + ub=ub_obj, + extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, ) if ctx.ub_bulk_dgrad: @@ -883,7 +943,7 @@ def backward( if ctx.set_parallel_mode and ctx.sequence_parallel: if not ctx.ub_bulk_dgrad and handle is not None: handle.wait() - if not ctx.ub_bulk_wgrad: + if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad: if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad) fc1_dgrad, handle = reduce_scatter_along_first_dim( @@ -985,7 +1045,10 @@ def backward( handle.wait() # LayerNorm gradient - dgrad = fc1_dgrad.view(inputmat.shape) + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out.view(inputmat.shape) + else: + dgrad = fc1_dgrad.view(inputmat.shape) # Residual gradient if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: @@ -1087,6 +1150,7 @@ def backward( None, None, None, + None, ) @@ -1209,6 +1273,7 @@ def __init__( device: Union[torch.device, str] = "cuda", ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, + ub_overlap_rs_dgrad: bool = False, ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, ) -> None: @@ -1231,6 +1296,7 @@ def __init__( self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad + self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad self.ub_overlap_rs = ub_overlap_rs self.ub_overlap_ag = ub_overlap_ag # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap @@ -1238,7 +1304,7 @@ def __init__( (bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and self.activation == 'gelu' and not get_ub("fc1_fprop").is_atomic_gemm()) - if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_rs, ub_overlap_ag]): + if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_rs, ub_overlap_ag, ub_overlap_rs_dgrad]): assert ( tex.userbuf_comm_available() ), "Userbuffer communication backend not available." 
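The same algorithm-selection pattern for the new dgrad reduce-scatter overlap is repeated in the backward passes of layernorm_linear.py and layernorm_mlp.py above. The following standalone sketch is not taken from the patch: select_dgrad_ub_algo and _UbStub are illustrative names, and plain strings stand in for the real tex.UbufOverlapAlgo enum values; it only summarizes the decision tree the diff implements.

# Standalone sketch (illustrative, not part of the patch) of the userbuffer
# algorithm selection added above for the dgrad GEMM.
class _UbStub:
    """Hypothetical stand-in for a userbuffers communicator object."""
    def __init__(self, p2p=False, atomic=False):
        self._p2p, self._atomic = p2p, atomic
    def is_p2p_overlap(self):
        return self._p2p
    def is_atomic_gemm(self):
        return self._atomic

def select_dgrad_ub_algo(ub_bulk_dgrad, ub_overlap_rs_dgrad, ub_obj, fp8):
    if ub_bulk_dgrad:
        # Bulk path: all-gather the layernorm output while the dgrad GEMM runs.
        return "BULK_OVERLAP_AG"
    if ub_overlap_rs_dgrad:
        # Pipelined path: reduce-scatter the dgrad GEMM output into rs_out.
        if ub_obj.is_p2p_overlap():
            return "ATOMIC_GEMM_RS_P2P" if fp8 and ub_obj.is_atomic_gemm() else "SPLIT_PIPELINED_RS_P2P"
        return "ATOMIC_GEMM_RS" if fp8 and ub_obj.is_atomic_gemm() else "SPLIT_PIPELINED_RS"
    # No overlap requested: the GEMM is launched without a userbuffer object.
    return None

# Example: FP8 pipelined reduce-scatter overlap with a P2P, non-atomic communicator.
print(select_dgrad_ub_algo(False, True, _UbStub(p2p=True), fp8=True))  # SPLIT_PIPELINED_RS_P2P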
@@ -1492,6 +1558,7 @@ def forward( self.primary_weights_in_fp8, self.ub_bulk_wgrad, self.ub_bulk_dgrad, + self.ub_overlap_rs_dgrad, self.ub_overlap_rs, self.ub_overlap_ag, self.gemm_gelu_fusion, diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index a0fd231913..2e00333fa0 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -261,6 +261,7 @@ def __init__( ub_bulk_dgrad: bool = True, ub_overlap_ag: bool = True, ub_overlap_rs: bool = True, + ub_overlap_rs_dgrad: bool = False, bias: bool = True, activation: str = 'gelu', normalization: str = "LayerNorm", @@ -282,6 +283,7 @@ def __init__( ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad ub_overlap_ag = ub_tp_comm_overlap and ub_overlap_ag ub_overlap_rs = ub_tp_comm_overlap and ub_overlap_rs + ub_overlap_rs_dgrad = ub_tp_comm_overlap and ub_overlap_rs_dgrad bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1"))) self.layer_number = layer_number @@ -357,6 +359,7 @@ def __init__( "ub_bulk_dgrad" : ub_bulk_dgrad, "ub_overlap_ag" : ub_overlap_ag, "ub_overlap_rs" : ub_overlap_rs, + "ub_overlap_rs_dgrad" : ub_overlap_rs_dgrad, "qkv_format" : self.attn_input_format, } @@ -410,6 +413,7 @@ def __init__( zero_centered_gamma=zero_centered_gamma, ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_dgrad=ub_bulk_dgrad, + ub_overlap_rs_dgrad=ub_overlap_rs_dgrad, ub_overlap_rs=ub_overlap_rs, ub_overlap_ag=ub_overlap_ag, activation=activation, From 36b99c140fe155bbec561b883f38c309d95a5e1a Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Fri, 5 Apr 2024 22:44:27 -0700 Subject: [PATCH 011/244] Fix the default userbuffer communicator init settings (#755) fix the default userbuffer communicator init settings Signed-off-by: Sangkug Lym Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/base.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6ef6d4eb3b..56dd3c8fc4 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -153,13 +153,13 @@ def get_method(name): def add_ub( name: str, method: str, + is_reduce_scatter: int, num_sm: int = 16, cga_size: int = 2, set_sm_margin: int = 0, - num_splits: int = 4, + num_splits: int = 0, aggregate: int = 0, atomic_gemm: int = 0, - is_reduce_scatter: int = 0, fp8_buf: bool = False, ) -> None: if atomic_gemm: @@ -243,7 +243,7 @@ def add_ub( method = ub_cfg.get("method", get_method(name)) num_sm = ub_cfg.get("num_sm", 16) cga_size = ub_cfg.get("cga_size", 2) - num_splits = ub_cfg.get("num_splits", 4) + num_splits = ub_cfg.get("num_splits", 4 if method == "pipeline" else 0) set_sm_margin = ub_cfg.get("set_sm_margin", 0) aggregate = ub_cfg.get("aggregate", 0) atomic_gemm = ub_cfg.get("atomic_gemm", 0) @@ -254,21 +254,24 @@ def add_ub( add_ub( name, method, + is_reduce_scatter, num_sm, cga_size, set_sm_margin, num_splits, aggregate, atomic_gemm, - is_reduce_scatter, fp8_buf, ) else: method = get_method(name) - if method == "pipeline": - add_ub(name, method) - else: - add_ub(name, method, num_splits=0) + add_ub( + name, + method=method, + is_reduce_scatter=1 if name in layers_reduce_scatter_overlap else 0, + num_splits=4 if method == "pipeline" else 0, + fp8_buf=name in layers_all_gather_overlap, + ) def get_ub(name: str): From 67295b00b6c5443facf60927ac9df55569e1c2bb Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Sun, 7 Apr 2024 
00:38:50 +0800 Subject: [PATCH 012/244] [JAX] Adapt latest JAX/PAX image (#744) * value_and_grad requires same shape for input and gradients Signed-off-by: Reese Wang * Use high precision layernorm Signed-off-by: Reese Wang * Remove local_device_ids as it caused unexpected behaviors Signed-off-by: Reese Wang * Revert "Remove local_device_ids as it caused unexpected behaviors" This reverts commit c54349b2ce1e96ae696cf0d74f5210e55002cf72. Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang Signed-off-by: Pawel Gadzinski --- tests/jax/test_custom_call_compute.py | 5 +++-- tests/jax/utils.py | 7 +++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 212ddd6d07..8aa6c399f4 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -485,7 +485,8 @@ def primitive_bwd(ctx, g): primitive.defvjp(primitive_fwd, primitive_bwd) func = value_and_grad(lambda x, y, z, w: jnp.mean(primitive(x, y, z, w)), (0, 1, 2, 3)) - return func(inputs, no_use, no_use, no_use) + return func(inputs, jnp.transpose(inputs, (2, 0, 1)), + jnp.zeros(inputs.shape[-1], dtype=inputs.dtype), no_use) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)]) @@ -582,7 +583,7 @@ def primitive_bwd(ctx, g): primitive.defvjp(primitive_fwd, primitive_bwd) func = value_and_grad(lambda x, y, z: jnp.mean(primitive(x, y, z)), (0, 1, 2)) - return func(inputs, no_use, no_use) + return func(inputs, jnp.transpose(inputs, (1, 2, 0)), no_use) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)]) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 8eabafde57..c8e1b1b183 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -731,19 +731,18 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: axes=('embed',)) bias = jnp.asarray(bias, self.dtype) - y = jnp.asarray(y, self.dtype) if not self.zero_centered_gamma: z = y * scale + bias else: - z = y * (scale + 1) + bias + z = y * (scale + 1.) + bias else: assert self.layernorm_type == 'rmsnorm' assert not self.zero_centered_gamma mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) - y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) + y = x * lax.rsqrt(mean2 + self.epsilon) z = y * scale - return z + return jnp.asarray(z, self.dtype) class RelativePositionBiases(nn.Module): From 1cecc03a1d0efe5ba19944c3fbd478fe85e2aafb Mon Sep 17 00:00:00 2001 From: Jinze Xue <155670984+jinzex@users.noreply.github.com> Date: Wed, 10 Apr 2024 13:30:03 -0700 Subject: [PATCH 013/244] Fix undefined symbol issue for transformer_engine::getenv (#763) Signed-off-by: Jinze Xue Co-authored-by: Jinze Xue Signed-off-by: Pawel Gadzinski --- setup.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/setup.py b/setup.py index d50a9b8706..d442aec872 100644 --- a/setup.py +++ b/setup.py @@ -445,6 +445,12 @@ def setup_pytorch_extension() -> setuptools.Extension: sources = [ src_dir / "common.cu", src_dir / "ts_fp8_op.cpp", + # We need to compile system.cpp because the pytorch extension uses + # transformer_engine::getenv. This is a workaround to avoid direct + # linking with libtransformer_engine.so, as the pre-built PyTorch + # wheel from conda or PyPI was not built with CXX11_ABI, and will + # cause undefined symbol issues. 
+ root_path / "transformer_engine" / "common" / "util" / "system.cpp", ] + \ _all_files_in_dir(extensions_dir) From edc73cdf1b057e46bb7cdf9d8dc8f971c736148a Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 12 Apr 2024 06:39:22 -0700 Subject: [PATCH 014/244] [PyTorch] cuda graph support (#575) * FP8 cuda graphs Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Vasudevan Rengasamy Co-authored-by: Charlene Yang * Fix numerics Signed-off-by: Kirthi Shankar Sivamani * exclude torch compile from numerics tests Signed-off-by: Kirthi Shankar Sivamani * More numerics fixes Signed-off-by: Kirthi Shankar Sivamani * Fix tests Signed-off-by: Kirthi Shankar Sivamani * Fix CI Signed-off-by: Kirthi Shankar Sivamani * rm fusion from unfused path Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Vasudevan Rengasamy Co-authored-by: Charlene Yang Signed-off-by: Pawel Gadzinski --- docs/api/pytorch.rst | 2 + qa/L0_pytorch_unittest/test.sh | 5 +- tests/pytorch/fused_attn/test_fused_attn.py | 60 +- tests/pytorch/test_cuda_graphs.py | 215 +++++++ tests/pytorch/test_float8tensor.py | 88 +-- tests/pytorch/test_numerics.py | 58 +- tests/pytorch/test_onnx_export.py | 6 + tests/pytorch/test_sanity.py | 11 +- .../transformer_engine/cast_transpose_noop.h | 35 ++ .../include/transformer_engine/recipe.h | 39 ++ .../common/layer_norm/ln_api.cpp | 37 +- transformer_engine/common/recipe/__init__.py | 10 + .../common/recipe/delayed_scaling.cu | 244 +++++++- .../common/rmsnorm/rmsnorm_api.cpp | 34 +- .../common/transpose/cast_transpose.cu | 41 ++ .../common/transpose/rtc/transpose.cu | 3 + .../common/transpose/transpose.cu | 39 ++ transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/attention.py | 21 +- .../pytorch/cpp_extensions/transpose.py | 13 +- .../pytorch/csrc/comm_gemm_overlap.h | 64 +- transformer_engine/pytorch/csrc/common.h | 1 + transformer_engine/pytorch/csrc/extensions.h | 39 +- .../pytorch/csrc/extensions/pybind.cpp | 11 +- .../pytorch/csrc/extensions/recipe.cu | 62 +- .../pytorch/csrc/extensions/transpose.cu | 65 ++ transformer_engine/pytorch/distributed.py | 115 +++- transformer_engine/pytorch/float8_tensor.py | 150 ++--- transformer_engine/pytorch/fp8.py | 586 ++++++++---------- transformer_engine/pytorch/graph.py | 548 ++++++++++++++++ transformer_engine/pytorch/module/base.py | 223 +++---- .../pytorch/module/layernorm_linear.py | 42 +- .../pytorch/module/layernorm_mlp.py | 58 +- transformer_engine/pytorch/module/linear.py | 47 +- transformer_engine/pytorch/transformer.py | 12 +- 35 files changed, 2196 insertions(+), 789 deletions(-) create mode 100644 tests/pytorch/test_cuda_graphs.py create mode 100644 transformer_engine/common/include/transformer_engine/cast_transpose_noop.h create mode 100644 transformer_engine/pytorch/graph.py diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 9b291e6d0a..c9504c20af 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -41,4 +41,6 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.onnx_export +.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables + .. 
autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 0b94a8b77e..50f54cd714 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -9,9 +9,10 @@ set -e pip install pytest==6.2.5 onnxruntime==1.13.1 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py -PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py -pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py +NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 65c3b8269b..b2c8f69ef3 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -5,7 +5,6 @@ import functools from importlib.metadata import version import os -import math from typing import Any, Dict, List, Tuple, Union from pkg_resources import packaging @@ -28,15 +27,9 @@ fused_attn_bwd, fused_attn_fwd, ) -from transformer_engine.pytorch.distributed import ( - _set_cuda_rng_state, - CudaRNGStatesTracker, -) +from transformer_engine.pytorch.distributed import CudaRNGStatesTracker import transformer_engine.pytorch.fp8 as fp8 -from transformer_engine.pytorch.module.base import ( - TransformerEngineBaseModule, - _prepare_backward, -) +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.utils import ( get_device_compute_capability, init_method_normal, @@ -58,10 +51,18 @@ _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) + def reset_rng_states() -> None: """Revert back to initial RNG state""" torch.set_rng_state(_cpu_rng_state) - _set_cuda_rng_state(_cuda_rng_state) + torch.cuda.set_rng_state(_cuda_rng_state) + + +@pytest.fixture(autouse=True) +def reset_global_fp8_state(): + yield + fp8.FP8GlobalStateManager.reset() + @functools.cache def _cudnn_version() -> Tuple[int, int, int]: @@ -71,6 +72,7 @@ def _cudnn_version() -> Tuple[int, int, int]: minor, patch = divmod(encoded_version, 100) return (major, minor, patch) + class ModelConfig: def __init__( self, @@ -103,6 +105,7 @@ def __init__( self.num_layers = num_layers self.bias_shape = bias_shape + def _is_fused_attention_supported( config: ModelConfig, dtype: torch.dtype, @@ -151,24 +154,28 @@ def _is_fused_attention_supported( return True, backends return False, backends + @functools.cache def _is_flash_attention_2_available() -> bool: """Check if flash-attn 2.0+ is available""" Version = packaging.version.Version return Version(version("flash-attn")) >= Version("2") + @functools.cache def _is_flash_attention_2_1() -> bool: """Check if flash-attn 2.1+ is available""" Version = packaging.version.Version return Version(version("flash-attn")) >= Version("2.1") + @functools.cache def _is_flash_attention_2_3() -> bool: """Check if flash-attn 2.3+ is available""" Version = 
packaging.version.Version return Version(version("flash-attn")) >= Version("2.3") + def _is_flash_attention_supported(config: ModelConfig) -> bool: """Check if FlashAttention supports a model configuration""" if get_device_compute_capability() < (8, 0): @@ -184,6 +191,7 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool: return False return True + def _is_unfused_attention_supported(config: ModelConfig) -> bool: """Check if UnfusedDotProductAttention supports a model configuration""" if ("padding" in config.attn_mask_type): @@ -192,6 +200,7 @@ def _is_unfused_attention_supported(config: ModelConfig) -> bool: return False return True + model_configs_base = { # test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0 @@ -200,11 +209,13 @@ def _is_unfused_attention_supported(config: ModelConfig) -> bool: "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1 } + param_types = [torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) param_types_lean = [torch.bfloat16] + def get_swa(seq_q, seq_kv, w=None): """Generate a random sliding window size (left, right) if w is None, and create its equivalent attention mask in [seq_q, seq_kv] shape""" @@ -216,6 +227,7 @@ def get_swa(seq_q, seq_kv, w=None): ml = ~ ml return w, ml + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model_configs", [model_configs_base]) @@ -313,6 +325,7 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace for i,_ in enumerate(fused_attn_bwd): torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols) + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model_configs", [model_configs_base]) @@ -321,6 +334,7 @@ def test_dpa_checkpoint(dtype, model_configs, model): """Test DotProductAttention module with checkpointing""" test_dot_product_attention(dtype, model_configs, model, True, True, None, False) + model_configs_mask = { # test: b, h, hg, d, sq, skv, p, mask, bias "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), @@ -337,6 +351,7 @@ def test_dpa_checkpoint(dtype, model_configs, model): "mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), } + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_mask]) @@ -345,6 +360,7 @@ def test_dpa_mask(dtype, model_configs, model): """Test DotProductAttention module with different mask types""" test_dot_product_attention(dtype, model_configs, model, False, True, None, False) + model_configs_bias = { # test: b, h, hg, d, sq, skv, p, mask, bias "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), @@ -373,6 +389,7 @@ def test_dpa_mask(dtype, model_configs, model): "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped } + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_bias]) @@ -381,6 +398,7 @@ def test_dpa_bias(dtype, model_configs, 
model): """Test DotProductAttention module with different bias types""" test_dot_product_attention(dtype, model_configs, model, False, True, None, False) + model_configs_bias_shapes = { # test: b, h, hg, d, sq, skv, p, "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, @@ -398,6 +416,7 @@ def test_dpa_bias(dtype, model_configs, model): "causal", "alibi", bias_shape='bhss', alibi_type='custom'), } + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_bias_shapes]) @@ -413,6 +432,8 @@ def test_dpa_bias_shapes(dtype, model_configs, model): "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), } + + @pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_swa]) @@ -428,6 +449,8 @@ def test_dpa_sliding_window(dtype, model_configs, model): "alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type= "custom"), "alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type= "custom"), } + + @pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes]) @@ -436,6 +459,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): """Test DotProductAttention module with ALiBi slopes""" test_dot_product_attention(dtype, model_configs, model, False, True, None, False) + qkv_layouts = [ 'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd', 'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd', @@ -443,6 +467,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): #'t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd', ] + model_configs_layout = { # test: b, h, hg, d, sq, skv, p, mask, bias "layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), @@ -455,6 +480,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): "layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), } + @pytest.mark.skipif(_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_layout]) @@ -464,6 +490,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with different QKV layouts""" test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False) + def _run_dot_product_attention( dtype: torch.dtype, config: ModelConfig, @@ -646,6 +673,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: return out, (inp[0].grad, inp[1].grad, inp[2].grad) + model_configs_te_layer = { # test: b, h, hg, d, sq, skv, p, mask, bias "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), @@ -658,6 +686,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"), } + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @@ -742,6 +771,7 @@ 
def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols) + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @@ -755,6 +785,7 @@ def test_te_layer_misc(dtype, model_configs, model, qkv_format): test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE) + @pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @@ -780,6 +811,7 @@ def find_factors(x): test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE) + def _run_transformer_layer( dtype: torch.dtype, config: ModelConfig, @@ -912,8 +944,10 @@ def _run_transformer_layer( "fp8_1": ModelConfig(1, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), } + param_types_fp8 = [torch.float16] + @pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.") @@ -946,6 +980,7 @@ def test_dpa_fp8(dtype, model): torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols) + def _run_dpa_fp8(dtype, config, backend): """Run FusedAttention FP8 backend, i.e. fused_attn_fwd/bwd_qkvpacked from cpp_extensions""" @@ -989,6 +1024,7 @@ def _run_dpa_fp8(dtype, config, backend): dqkv.view(config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim).transpose(0,1).contiguous()) + def _run_dpa_fp8_ref(dtype, config, backend): """Run UnfusedDotProductAttention as a reference, i.e. plain PyTorch implementation in FP16 and inputs/outputs @@ -1188,8 +1224,7 @@ def forward( def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: - - with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"): + with torch.cuda.nvtx.range("_DPA"): ( inputmat_t, qkv_weight_t_fp8, @@ -1298,6 +1333,7 @@ def backward( None, None) + class DPA_FP8(TransformerEngineBaseModule): def __init__( self, diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py new file mode 100644 index 0000000000..2b1dcb3aa3 --- /dev/null +++ b/tests/pytorch/test_cuda_graphs.py @@ -0,0 +1,215 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from typing import List, Tuple +import pytest + +import torch +from transformer_engine.pytorch import ( + DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, make_graphed_callables, + MultiheadAttention, TransformerLayer, fp8_autocast, fp8_model_init, +) +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.utils import is_bf16_compatible + + +# Only run FP8 tests on H100. +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) +# Record initial RNG state from script run. 
+_cpu_rng_state = torch.get_rng_state() +_cuda_rng_state = torch.cuda.get_rng_state() + + +class ModelConfig: + def __init__(self, hidden_size, nheads, kv, seq_len): + self.h = hidden_size + self.nheads = nheads + self.kv = kv + self.s = seq_len + +model_configs = { + "small": ModelConfig(64, 2, 32, 32), +} + +modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"] + +optimizers = [torch.optim.SGD, torch.optim.Adam] + +all_boolean = [True, False] + +dtypes = [torch.float32, torch.float16] +if is_bf16_compatible(): # bf16 requires sm_80 or higher + dtypes.append(torch.bfloat16) + + +def reset_rng_states() -> None: + """revert back to initial RNG state.""" + torch.set_rng_state(_cpu_rng_state) + torch.cuda.set_rng_state(_cuda_rng_state) + + +@pytest.fixture(autouse=True) +def reset_global_fp8_state(): + yield + FP8GlobalStateManager.reset() + + +def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool: + """Ensures two lists are equal.""" + assert len(l1) == len(l2), "Unequal number of outputs." + failed = False + failed_tensors = "" + for i, (t1, t2) in enumerate(zip(l1, l2)): + with torch.no_grad(): + t1.masked_fill_(t1.isnan(), 1.0) + t2.masked_fill_(t2.isnan(), 1.0) + if not torch.equal(t1, t2): + failed = True + failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n" + assert not failed, "Output mismatches in:\n" + failed_tensors + + +def generate_data( + s: int, b: int, h: int, nheads: int, kv: int, dtype: torch.dtype, + dpa: bool = False, warmup: bool = False, gen_labels: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate synthetic data.""" + gen_func = torch.ones if warmup else torch.randn + if dpa: + inputs = [gen_func(s, b, nheads, kv, device="cuda", requires_grad=True, dtype=dtype) for _ in range(3)] + else: + inputs = [gen_func(s, b, h, device="cuda", requires_grad=True, dtype=dtype)] + + if not gen_labels: + return inputs + + target = torch.randn(s, b, h, device="cuda", dtype=dtype) + return inputs, target + + +def get_outputs(model, output): + """Return grads and params for comparsion.""" + values = [] + for param in model.parameters(): + values.append(param) + if param.grad is not None: + values.append(param.grad) + values.append(output) + return values + + +def _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, graph, module, optimizer, graph_mode=""): + """Helper function for test.""" + reset_rng_states() + FP8GlobalStateManager.reset() + dpa = module == "dpa" + + with fp8_model_init(enabled=fp8_params): + # Create modules. + if module == "transformer": + modules = [TransformerLayer( + config.h, + config.h, + config.nheads, + hidden_dropout=0.0, + attention_dropout=0.0, + fuse_qkv_params=True, + params_dtype=dtype, + ) for _ in range(num_layers)] + elif module == "layernorm_mlp": + modules = [LayerNormMLP( + config.h, config.h, params_dtype=dtype + ) for _ in range(num_layers)] + elif module == "layernorm_linear": + modules = [LayerNormLinear( + config.h, config.h, params_dtype=dtype + ) for _ in range(num_layers)] + elif module == "mha": + modules = [MultiheadAttention( + config.h, + config.nheads, + attention_dropout=0.0, + params_dtype=dtype, + fuse_qkv_params=True, + ) for _ in range(num_layers)] + elif dpa: + assert config.h % config.nheads == 0, "Err." + assert num_layers == 1, "Err." 
+ modules = [DotProductAttention( + config.nheads, config.kv, attention_dropout=0.0 + ) for _ in range(num_layers)] + else: + modules = [Linear( + config.h, config.h, device="cuda", params_dtype=dtype + ) for _ in range(num_layers)] + + # Generate model and wrap API to return graphed version. + if graph: + # Graph entire module at once. + if graph_mode == "full": + model = modules[0] if dpa else torch.nn.Sequential(*modules) + model = make_graphed_callables( + model, + generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True), + num_warmup_iters=10, + fp8_enabled=fp8) + else: + modules = [make_graphed_callables( + module, + generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True), + num_warmup_iters=10, + fp8_enabled=fp8) for module in modules] + model = modules[0] if dpa else torch.nn.Sequential(*modules) + else: + model = modules[0] if dpa else torch.nn.Sequential(*modules) + + # Loss function and optimizer. + loss_fn = torch.nn.MSELoss() + if not dpa: + optimizer = optimizer(model.parameters(), lr=0.001) + + # Launch. + for _ in range(10): + inputs, target = generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, gen_labels=True) + with fp8_autocast(enabled=fp8): + output = model(*inputs) + loss = loss_fn(output, target) + loss.backward() + if not dpa: + optimizer.step() + optimizer.zero_grad() + + return get_outputs(model, output) + + +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("bs", [1, 2]) +@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("num_layers", [1, 10]) +@pytest.mark.parametrize("fp8", all_boolean) +@pytest.mark.parametrize("fp8_params", all_boolean) +@pytest.mark.parametrize("module", modules) +@pytest.mark.parametrize("optimizer", optimizers) +def test_gpt_make_graphed_callables(dtype, bs, model, num_layers, fp8, fp8_params, module, optimizer): + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8_params and not fp8: + pytest.skip("FP8 needed for FP8 parameters.") + if module == "dpa" and num_layers > 1: + pytest.skip("Max 1 layer for DPA.") + + config = model_configs[model] + + outputs = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, False, module, optimizer) + graph_outputs_mode1 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="full") + graph_outputs_mode2 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="individual") + + # Check that results match + assert_all_equal(outputs, graph_outputs_mode1) + assert_all_equal(outputs, graph_outputs_mode2) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 935519ca84..c4c39f9309 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -257,12 +257,10 @@ def test_inplace_ops( with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8, x_ref, **tols) - @pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]]) - @pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)]) + @pytest.mark.parametrize("dims", [[33, 41], [7, 11]]) def test_transpose( self, dims: DimsType, - transpose_dims: Tuple[int, int], fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, scale: float = 0.5, dtype: torch.dtype = torch.float32, @@ -271,74 +269,44 @@ def test_transpose( # Initialize random data dims = _to_list(dims) - x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") 
- 1 + x = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 x_fp8 = Float8Tensor.to_float8( - x_ref, + x, fp8_dtype=fp8_dtype, scale=torch.full([1], scale), ) - x_ref = x_fp8.from_float8() + x = x_fp8.from_float8() # Perform transpose - y_fp8 = x_fp8.transpose(*transpose_dims) - y_ref = x_ref.transpose(*transpose_dims) + x_fp8_t = x_fp8.transpose_2d() + x_t = x.transpose(0, 1) + x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8_t) # Check results tols = dict(rtol=0, atol=0) - torch.testing.assert_close(y_fp8, y_ref, **tols) + torch.testing.assert_close(x_fp8_t, x_t, **tols) # Make sure we are not trivially passing the test - if transpose_dims[0] != transpose_dims[1]: - with pytest.raises(AssertionError): - torch.testing.assert_close( - y_fp8, - x_ref, - **tols, - ) - - # Check transpose caching - if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]: - - # Check that cached transpose is returned when expected - # Note: Sneakily destroy data so that recalculating - # transpose would give wrong answer. - x_fp8 += 0.5 - x_ref = x_fp8.from_float8() - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, update_cache="lazy"), - x_ref.transpose(*transpose_dims), - **tols, - ) - x_fp8_data = x_fp8._data.clone() - x_fp8._data.zero_() - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims), - x_ref.transpose(*transpose_dims), - **tols, - ) - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, update_cache="lazy"), - x_ref.transpose(*transpose_dims), - **tols, - ) - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, update_cache="force"), - torch.zeros_like(x_ref.transpose(*transpose_dims)), - rtol=0, - atol=0, - ) - x_fp8._data.copy_(x_fp8_data) - x_fp8._reset_caches() - - # Make sure cache is reset after in-place operation - x_fp8.transpose(*transpose_dims, update_cache="force") - x_fp8 += 0.5 - x_ref = x_fp8.from_float8() - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims), - x_ref.transpose(*transpose_dims), - **tols, - ) + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8_t, x, **tols) + + # Caching test. + assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching." + x_fp8 += 0.5 + x = x_fp8.from_float8() + x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(cache=True)) + x_t = x.transpose(0, 1) + torch.testing.assert_close(x_fp8_t, x_t, **tols) + assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." + + # Inplace update test. + x_fp8 += 0.5 + assert x_fp8._transpose_invalid, "Transpose cache not invalidated properly." + x = x_fp8.from_float8() + x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(cache=True)) + x_t = x.transpose(0, 1) + torch.testing.assert_close(x_fp8_t, x_t, **tols) + assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." def test_serialization( self, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index c2eb2c01a5..ddb3ecf49f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -4,7 +4,6 @@ import math import os -import sys from typing import List, Optional import pytest import copy @@ -25,7 +24,6 @@ MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, InferenceParams ) from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint -from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker # Only run FP8 tests on H100. 
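For reference, the make_graphed_callables API introduced in this patch is exercised by the new tests/pytorch/test_cuda_graphs.py added earlier. The sketch below mirrors that usage; it is illustrative rather than an excerpt from the patch, the shapes and dtype are arbitrary, and it assumes a CUDA device (fp8_enabled=True additionally requires FP8-capable hardware).

# Minimal usage sketch of make_graphed_callables, mirroring test_cuda_graphs.py.
# Illustrative only; not part of the patch.
import torch
import transformer_engine.pytorch as te

model = te.Linear(64, 64, device="cuda", params_dtype=torch.float16)
# Warmup sample input in the (seq, batch, hidden) layout used by the test.
sample_args = (torch.ones(32, 2, 64, device="cuda", dtype=torch.float16, requires_grad=True),)

graphed_model = te.make_graphed_callables(
    model,
    sample_args,
    num_warmup_iters=10,
    fp8_enabled=False,  # set True on FP8-capable GPUs to capture FP8 execution
)

# The graphed callable is used like the original module.
x = torch.randn(32, 2, 64, device="cuda", dtype=torch.float16, requires_grad=True)
out = graphed_model(x)
out.sum().backward()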
@@ -54,6 +52,14 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048), } +model_configs_inference = { + # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len + "126m": ModelConfig(768, 1e-5, 12, 64, 12, 16), +} +backends_inference = ["FlashAttention", "UnfusedAttention"] +module_inference = ["TransformerLayer", "MultiheadAttention"] +input_formats_inference = ["sbhd", "bshd"] + param_types = [torch.float32, torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) @@ -104,7 +110,13 @@ def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) def reset_rng_states() -> None: """revert back to initial RNG state.""" torch.set_rng_state(_cpu_rng_state) - _set_cuda_rng_state(_cuda_rng_state) + torch.cuda.set_rng_state(_cuda_rng_state) + + +@pytest.fixture(autouse=True) +def reset_global_fp8_state(): + yield + FP8GlobalStateManager.reset() class TorchScaledMaskedSoftmax(nn.Module): @@ -373,10 +385,10 @@ def __init__(self, hidden_size: int, eps: float, num_attention_heads: int, paral def forward( self, x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: a = self.ln(x) - b = self.causal_attn(a, attn_mask) + b = self.causal_attn(a, attention_mask) if self.parallel_attention_mlp: n = self.ln_mlp(x) x = x + nn.functional.dropout(b + n, p=0.1, training=self.training) @@ -396,13 +408,6 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) - - def get_dummy_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _DUMMY_CUDA_RNG_STATE_TRACKER - with fp8_model_init(enabled=fp8 and fp8_model_params): block = ( TransformerLayer( @@ -417,7 +422,6 @@ def get_dummy_cuda_rng_tracker(): kv_channels=config.embed, apply_residual_connection_post_layernorm=False, output_layernorm=False, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, params_dtype=dtype, fuse_qkv_params=True, ) @@ -476,13 +480,6 @@ def _test_e2e_full_recompute( init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) - - def get_dummy_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _DUMMY_CUDA_RNG_STATE_TRACKER - with fp8_model_init(enabled=fp8 and fp8_model_params): block = ( TransformerLayer( @@ -497,7 +494,6 @@ def get_dummy_cuda_rng_tracker(): kv_channels=config.embed, apply_residual_connection_post_layernorm=False, output_layernorm=False, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, params_dtype=dtype, fuse_qkv_params=True, ) @@ -520,7 +516,6 @@ def get_dummy_cuda_rng_tracker(): checkpoint_core_attention=False, distribute_saved_activations=False, tp_group=None, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, use_reentrant=use_reentrant, ) else: @@ -683,7 +678,7 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): inp_hidden_states.retain_grad() inp_attn_mask = get_causal_attn_mask(config.seq_len) - out = block(inp_hidden_states, inp_attn_mask) + out = block(inp_hidden_states, attention_mask=inp_attn_mask) loss = out.sum() 
loss.backward() @@ -1261,13 +1256,6 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) - - def get_dummy_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _DUMMY_CUDA_RNG_STATE_TRACKER - with fp8_model_init(enabled=fp8_model_params): block = ( TransformerLayer( @@ -1282,7 +1270,6 @@ def get_dummy_cuda_rng_tracker(): kv_channels=config.embed, apply_residual_connection_post_layernorm=False, output_layernorm=False, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, params_dtype=dtype, fuse_qkv_params=True, ) @@ -1321,6 +1308,7 @@ def test_gpt_fp8_parameters(dtype, bs, model): outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) assert_all_equal(outputs, outputs_fp8_params) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @@ -1399,14 +1387,6 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()]) -model_configs_inference = { - # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len - "126m": ModelConfig(768, 1e-5, 12, 64, 12, 16), -} -backends_inference = ["FlashAttention", "UnfusedAttention"] -module_inference = ["TransformerLayer", "MultiheadAttention"] -input_formats_inference = ["sbhd", "bshd"] - @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model_key", model_configs_inference.keys()) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 822b1450ec..7707264c7f 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -86,6 +86,12 @@ def set_max_seq_len(max_seq_len=128): os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}" +@pytest.fixture(autouse=True) +def reset_global_fp8_state(): + yield + FP8GlobalStateManager.reset() + + def create_fp8_recipe(): return recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 217eacc9b3..e91e464fa4 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -48,6 +48,7 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor: """Custom func to test recipe.""" return torch.min(amax_history, dim=0).values + @dataclass class ModelConfig: """Transformer model configuration""" @@ -115,6 +116,12 @@ def _disable_wgrads(block): p.requires_grad = False +@pytest.fixture(autouse=True) +def reset_global_fp8_state(): + yield + FP8GlobalStateManager.reset() + + def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): # Initialize loss function and optimizer. 
loss_fn = torch.nn.MSELoss() @@ -137,7 +144,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): with torch.cuda.stream(s): for _ in range(3): optimizer.zero_grad(set_to_none=True) - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True): out = block(static_input) loss = loss_fn(out, static_target) loss.backward() @@ -148,7 +155,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): g = torch.cuda.CUDAGraph() optimizer.zero_grad(set_to_none=True) with torch.cuda.graph(g): - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True): static_output = block(static_input) static_loss = loss_fn(static_output, static_target) static_loss.backward() diff --git a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h new file mode 100644 index 0000000000..f9097679a6 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h @@ -0,0 +1,35 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file transpose_with_noop.h + * \brief Functions handling transposes with no-op. + */ + +#ifndef TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ +#define TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void nvte_transpose_with_noop(const NVTETensor input, + const NVTETensor noop, + NVTETensor output, + cudaStream_t stream); + +void nvte_cast_transpose_with_noop(const NVTETensor input, + const NVTETensor noop, + NVTETensor cast_output, + NVTETensor transposed_output, + cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index ddb64be5e7..49cc9af914 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -56,6 +56,45 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his float margin, cudaStream_t stream); + +/*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction. + * + * Operations performed include, updating the most recent amax history + * with the relevant segment of global reduction buffer if it's not 0, + * rotating the amax history based on the rule below, and updating the + * scales and scale_invs. + * + * The amax history is rotated by -1 (e.g. the first entry shifts to + * the last, the last entry shifts to the second to last) and the + * first entry is set to zero. The scaling factor is estimated so the + * FP8 tensor's maximum absolute value is + * @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$. + * + * \param[in] amax_reduction_buffer The contiguous buffer used for amax reduction. + * Shape: [num_scales * num_tensors] + * \param[in,out] amax_histories List of amax histories of maximum absolute values. 
+ * Shape: num_tensors x [history_length, num_scales] + * \param[in,out] scales List of scaling factors for casting to FP8. + * Shape: num_tensors x [num_scales] + * \param[in,out] scale_invs List of scaling factors for casting from FP8. + * Shape: num_tensors x [num_scales] + * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and + * "most_recent". + * \param[in] fp8_dtype FP8 datatype. + * \param[in] margin Scaling factor margin. + * \param[in] stream CUDA stream. + */ +void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( + const NVTETensor amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + const char *amax_compute_algo, + NVTEDType fp8_dtype, + float margin, + cudaStream_t stream); + + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/layer_norm/ln_api.cpp b/transformer_engine/common/layer_norm/ln_api.cpp index f5eb1896c4..7a01cf0345 100644 --- a/transformer_engine/common/layer_norm/ln_api.cpp +++ b/transformer_engine/common/layer_norm/ln_api.cpp @@ -229,19 +229,29 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size // Query the kernel-specific launch parameters. launcher(launch_params, true); + if (launch_params.workspace_bytes == 0) { + launch_params.workspace_bytes = 1; + } + if (workspace->data.dptr == nullptr) { NVTE_CHECK(barrier->data.dptr == nullptr); workspace->data.dtype = layer_norm::DType::kByte; - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } workspace->data.shape = { launch_params.workspace_bytes }; barrier->data.dtype = layer_norm::DType::kInt32; barrier->data.shape = { launch_params.barrier_size }; return; + } else { + NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); + } + + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); } // Tensor checks are delayed here in order to recover workspace sizes with null data @@ -368,6 +378,27 @@ void layernorm_bwd(const Tensor& dz, barrier->data.shape = { launch_params.barrier_size }; return; + } else { + NVTE_CHECK(dbeta_part->data.dptr != nullptr); + auto pdw_shape = std::vector{ + static_cast(launch_params.params.ctas_per_col), hidden_size}; + + NVTE_CHECK(dgamma_part->data.dtype == ctype); + NVTE_CHECK(dgamma_part->data.shape == pdw_shape); + NVTE_CHECK(dbeta_part->data.dtype == ctype); + NVTE_CHECK(dbeta_part->data.shape == pdw_shape); + } + + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); + } + + if (launch_params.workspace_bytes > 0) { + NVTE_CHECK(workspace->data.dptr != nullptr); + NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); } // Tensor checks are delayed here in order to recover workspace sizes with null data diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 55a706492f..9abbb69cbe 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -133,3 +133,13 @@ def __post_init__(self) -> None: 
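For orientation, here is a minimal PyTorch sketch of the per-tensor update rule that the comment above describes: roll the history by -1 and zero the first row, let a non-zero entry of the reduction buffer override the newest amax, then aim the scale at 2**(-margin) times the FP8 maximum. The function name and the E4M3 default of 448 are illustrative assumptions; this is a reading aid for the kernel, not the kernel itself.

import torch

def delayed_scaling_update_reference(amax_history, scale, scale_inv,
                                     reduced_amax=None, amax_compute_algo="max",
                                     fp8_max=448.0, margin=0.0):
    """amax_history: [history_length, num_scales]; scale, scale_inv: [num_scales]."""
    # Newest amax, overridden wherever the globally reduced value is non-zero.
    last_amax = amax_history[0].clone()
    if reduced_amax is not None:
        last_amax = torch.where(reduced_amax != 0, reduced_amax, last_amax)

    # Amax used for the new scale, computed from the pre-roll history.
    if amax_compute_algo == "most_recent" or amax_history.shape[0] == 1:
        amax = last_amax
    else:  # "max"
        amax = torch.maximum(amax_history[1:].amax(dim=0), last_amax)

    # Rotate by -1: row i takes row i+1, the last row takes last_amax, row 0 is zeroed.
    rolled = torch.roll(amax_history, shifts=-1, dims=0)
    rolled[-1] = last_amax
    rolled[0] = 0.0
    amax_history.copy_(rolled)

    # Scale targets 2**(-margin) * fp8_max; keep the previous scale for bad amaxes.
    scaled_max = fp8_max * 2.0 ** (-margin)
    new_scale = torch.where((amax > 0) & torch.isfinite(amax), scaled_max / amax, scale)
    scale.copy_(new_scale)
    scale_inv.copy_(1.0 / new_scale)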
(False, False, False), (False, False, True), ), "Only wgrad GEMM override is currently supported." + + def __repr__(self) -> str: + return ( + f"margin={self.margin}, " + f"interval={self.interval}, " + f"format={str(self.fp8_format).split('.')[1]}, " + f"amax_history_len={self.amax_history_len}, " + f"wgrad_override={self.override_linear_precision.wgrad}, " + f"reduce_amax={self.reduce_amax}" + ) diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 3fa64920df..6e07b1ce9f 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -11,6 +11,7 @@ #include "../common.h" #include "../util/logging.h" +#include "../util/cuda_runtime.h" namespace transformer_engine { namespace delayed_scaling_recipe { @@ -38,6 +39,36 @@ inline float fp8_dtype_max(DType dtype) { return 0; } +// struct for amax parameters +struct AmaxParam { + int num_scale = 0; + float* amax_history = nullptr; + float* scale = nullptr; + float* scale_inv = nullptr; +}; + +// dummy struct for kernel_bulk's other params +struct OtherParams { + float* a; + size_t b; + AmaxComputeAlgo c; + float d; +}; + +#if CUDART_VERSION >= 12010 +constexpr size_t max_constant_memory_per_kernel = 32000; +constexpr size_t AMAX_PARAMS_LIMIT = ( + max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); +#else +constexpr size_t max_constant_memory_per_kernel = 4000; +constexpr size_t AMAX_PARAMS_LIMIT = ( + max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); +#endif + +struct AmaxParams { + AmaxParam param[AMAX_PARAMS_LIMIT]; +}; + namespace amax_and_scale_update_impl { // CUDA block size @@ -133,11 +164,96 @@ kernel(const float* amax_history_ptr, } } -} // namespace amax_and_scale_update_impl +/* CUDA kernel to bulk-update amax history and FP8 scaling factors + * + * Block dims: bsize x 1 x 1 + * + * Grid dims: num_tensors x 1 x 1 + */ +__global__ void __launch_bounds__(bsize) +kernel_bulk( + float* amax_reduction_buffer, + AmaxParams p, + size_t amax_history_length, + AmaxComputeAlgo amax_compute_algo, + float scaled_max) { + const size_t bid = blockIdx.x; + const size_t tid = threadIdx.x; + const int num_scale = p.param[bid].num_scale; + + int offset_in_buffer = 0; + for (int j = 0; j < bid; j++) { + offset_in_buffer += p.param[j].num_scale; + } + for (int count = 0; count < num_scale; count++) { + // Update amax + float amax = 0; + { + // Roll amax history + const auto& length = amax_history_length; + const auto& stride = p.param[bid].num_scale; + auto* amax_history = p.param[bid].amax_history+count; + const auto last_amax = ((amax_reduction_buffer != nullptr) + && (amax_reduction_buffer[offset_in_buffer+count] != 0.0f)) ? + amax_reduction_buffer[offset_in_buffer+count] : amax_history[0]; + for (size_t off = 0; off < length; off += bsize) { + const size_t i = off + tid; + float a = 0; + if (i < length) { + a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax; + amax = fmaxf(amax, a); + } + __syncthreads(); // Inplace roll + if (i < length) { + amax_history[i*stride] = (i > 0) ? 
a : 0; + } + } + + // Compute amax to use for scaling factor + switch (amax_compute_algo) { + case AmaxComputeAlgo::MOST_RECENT: + amax = last_amax; + break; + case AmaxComputeAlgo::MAX: + { + __shared__ float shared_amax[bsize]; + shared_amax[tid] = amax; + __syncthreads(); +#pragma unroll + for (size_t off = bsize / 2; off > 0; off /= 2) { + if (tid < off) { + shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]); + } + __syncthreads(); + } + amax = shared_amax[tid]; + } + break; + default: + amax = 0; + } + } + + // Update scale and scale inverse + if (tid == 0) { + float scale; + if (isfinite(amax) && amax > 0) { + scale = scaled_max / amax; + } else { + scale = p.param[bid].scale[count]; + } + p.param[bid].scale[count] = scale; + p.param[bid].scale_inv[count] = 1 / scale; + } + } +} + +} // namespace amax_and_scale_update_impl } // namespace + void amax_and_scale_update(const Tensor &amax_history, const Tensor &scale, const Tensor &scale_inv, @@ -238,9 +354,105 @@ void amax_and_scale_update(const Tensor &amax_history, NVTE_CHECK_CUDA(cudaGetLastError()); } + +void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + const std::string &amax_compute_algo, + DType fp8_dtype, + float margin, + cudaStream_t stream) { + using namespace transformer_engine; + + // amax value to use for updating scaling factor + AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; + if (amax_compute_algo == "max") { + amax_compute_algo_ = AmaxComputeAlgo::MAX; + } else if (amax_compute_algo == "most_recent") { + amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT; + } else { + NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")"); + } + + // Expected maximum value after scale is applied + const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin); + + // Number of elements in tensor + auto numel = [] (const Tensor *tensor) -> size_t { + size_t acc = 1; + for (const auto& dim : tensor->data.shape) { + acc *= dim; + } + return acc; + }; + + // Number of tensors in the bulk + const size_t num_tensors = amax_histories.size(); + const int num_kernels = (num_tensors+AMAX_PARAMS_LIMIT-1)/AMAX_PARAMS_LIMIT; + size_t amax_history_length = 0; + if (num_tensors > 0) { + amax_history_length = amax_histories[0]->data.shape[0]; + } + + // amax parameters + float* amax_buffer = static_cast(amax_reduction_buffer.data.dptr); + AmaxParams p; + for (int iter = 0; iter < num_kernels; iter++) { + size_t kernel_num_scales = 0; + size_t kernel_num_tensors = (iter == (num_kernels -1)) + ? 
num_tensors % AMAX_PARAMS_LIMIT: AMAX_PARAMS_LIMIT; + for (size_t pi = 0; pi < kernel_num_tensors; pi++) { + size_t i = iter * AMAX_PARAMS_LIMIT + pi; + + // Check tensors + int num_scale = amax_histories[i]->data.shape[1]; + NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32, + "Found ", dtype_name(amax_histories[i]->data.dtype), "."); + NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, + "Found ", amax_histories[i]->data.shape.size(), " dims"); + NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale, + "Expected ", amax_history_length * num_scale, " elements, ", + "but found ", numel(amax_histories[i]), "."); + NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, + "Found ", dtype_name(scales[i]->data.dtype), "."); + NVTE_CHECK(scales[i]->data.shape.size() == 1, + "Found ", scales[i]->data.shape.size(), " dims"); + NVTE_CHECK(numel(scales[i]) == num_scale, + "Expected ", num_scale, " elements, ", + "Found ", numel(scales[i]), "."); + + // amax parameters + kernel_num_scales += num_scale; + p.param[pi].num_scale = num_scale; + p.param[pi].amax_history = static_cast(amax_histories[i]->data.dptr); + p.param[pi].scale = static_cast(scales[i]->data.dptr); + p.param[pi].scale_inv = static_cast(scale_invs[i]->data.dptr); + } + + // Launch CUDA kernel + size_t grid_size = kernel_num_tensors; + const size_t block_size = amax_and_scale_update_impl::bsize; + amax_and_scale_update_impl::kernel_bulk + <<>>( + amax_buffer, + p, + amax_history_length, + amax_compute_algo_, + scaled_max); + NVTE_CHECK_CUDA(cudaGetLastError()); + + // shift amax buffer pointer + if (amax_buffer != nullptr) { + amax_buffer += kernel_num_scales; + } + } +} + } // namespace delayed_scaling_recipe } // namespace transformer_engine + void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv, @@ -267,3 +479,33 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his margin, stream); } + + +void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( + const NVTETensor amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + const char *amax_compute_algo, + NVTEDType fp8_dtype, + float margin, + cudaStream_t stream) { + NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction); + using namespace transformer_engine; + size_t num_tensors = amax_histories.size(); + std::vector t_amax_histories, t_scales, t_scale_invs; + for (size_t i = 0; i < num_tensors; i++) { + t_amax_histories.push_back(reinterpret_cast(amax_histories[i])); + t_scales.push_back(reinterpret_cast(scales[i])); + t_scale_invs.push_back(reinterpret_cast(scale_invs[i])); + } + delayed_scaling_recipe::amax_and_scale_update_after_reduction( + *reinterpret_cast(amax_reduction_buffer), + t_amax_histories, + t_scales, + t_scale_invs, + amax_compute_algo, + static_cast(fp8_dtype), + margin, + stream); +} diff --git a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp index 86ffc64c25..5ccfae1922 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp @@ -153,21 +153,32 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens // Query the kernel-specific launch parameters. 
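Because the per-launch parameter struct lives in constant memory, the host loop above can hand at most AMAX_PARAMS_LIMIT tensors to a single kernel_bulk launch, and it walks the flat reduction buffer forward by the number of scales each launch consumed. A stripped-down, pure-Python illustration of that partitioning; the helper name and the printed example are assumptions, not part of the API.

def plan_bulk_launches(num_scales_per_tensor, params_limit):
    """Group tensors into launches of at most `params_limit` and track buffer offsets."""
    launches, offset = [], 0
    for start in range(0, len(num_scales_per_tensor), params_limit):
        chunk = num_scales_per_tensor[start:start + params_limit]
        launches.append({
            "tensors": list(range(start, start + len(chunk))),
            "buffer_offset": offset,      # where this launch reads its reduced amaxes
            "num_scales": sum(chunk),     # how far the buffer pointer advances afterwards
        })
        offset += sum(chunk)
    return launches

# Five single-scale tensors with a limit of two parameters per launch -> three launches.
print(plan_bulk_launches([1, 1, 1, 1, 1], params_limit=2))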
launcher(launch_params, true); + if (launch_params.workspace_bytes == 0) { + launch_params.workspace_bytes = 1; + } + if (workspace->data.dptr == nullptr) { NVTE_CHECK(barrier->data.dptr == nullptr); workspace->data.dtype = DType::kByte; - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } workspace->data.shape = {launch_params.workspace_bytes}; barrier->data.dtype = DType::kInt32; barrier->data.shape = {launch_params.barrier_size}; return; + } else { + NVTE_CHECK(workspace->data.dtype == DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); } + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); + } + + // Tensor checks are delayed here in order to recover workspace sizes with null data CheckInputTensor(x, "x"); CheckInputTensor(gamma, "gamma"); @@ -265,6 +276,23 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const barrier->data.shape = {launch_params.barrier_size}; return; + } else { + auto pdw_shape = std::vector{ + static_cast(launch_params.params.ctas_per_col), hidden_size}; + NVTE_CHECK(dgamma_part->data.dtype == ctype); + NVTE_CHECK(dgamma_part->data.shape == pdw_shape); + } + + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); + } + + if (launch_params.workspace_bytes > 0) { + NVTE_CHECK(workspace->data.dptr != nullptr); + NVTE_CHECK(workspace->data.dtype == DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); } // Tensor checks are delayed here in order to recover workspace sizes with null data diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index 9f1a18de7a..347aeb9b15 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. 
************************************************************************/ +#include #include #include #include @@ -56,6 +57,7 @@ template ; using OVec = Vec; @@ -163,6 +167,7 @@ template ; using OVec = Vec; @@ -294,6 +301,7 @@ cast_transpose_kernel_notaligned(const IType * const input, } void cast_transpose(const Tensor &input, + const Tensor &noop, Tensor *cast_output, Tensor *transposed_output, cudaStream_t stream) { @@ -301,6 +309,22 @@ void cast_transpose(const Tensor &input, CheckOutputTensor(*cast_output, "cast_output"); CheckOutputTensor(*transposed_output, "transposed_output"); + // Number of elements in tensor + auto numel = [] (const Tensor &tensor) -> size_t { + size_t acc = 1; + for (const auto& dim : tensor.data.shape) { + acc *= dim; + } + return acc; + }; + + if (noop.data.dptr != nullptr) { + NVTE_CHECK(numel(noop) == 1, + "Expected 1 element, ", + "but found ", numel(noop), "."); + NVTE_CHECK(noop.data.dtype == DType::kFloat32); + NVTE_CHECK(noop.data.dptr != nullptr); + } NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); @@ -332,6 +356,7 @@ void cast_transpose(const Tensor &input, (THREADS_PER_WARP + 1) * sizeof(Vec), \ stream>>>( \ reinterpret_cast(input.data.dptr), \ + reinterpret_cast(noop.data.dptr), \ reinterpret_cast(cast_output->data.dptr), \ reinterpret_cast(transposed_output->data.dptr), \ reinterpret_cast(cast_output->scale.dptr), \ @@ -417,7 +442,23 @@ void nvte_cast_transpose(const NVTETensor input, cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose); using namespace transformer_engine; + auto noop = Tensor(); + cast_transpose(*reinterpret_cast(input), + noop, + reinterpret_cast(cast_output), + reinterpret_cast(transposed_output), + stream); +} + +void nvte_cast_transpose_with_noop(const NVTETensor input, + const NVTETensor noop, + NVTETensor cast_output, + NVTETensor transposed_output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_with_noop); + using namespace transformer_engine; cast_transpose(*reinterpret_cast(input), + *reinterpret_cast(noop), reinterpret_cast(cast_output), reinterpret_cast(transposed_output), stream); diff --git a/transformer_engine/common/transpose/rtc/transpose.cu b/transformer_engine/common/transpose/rtc/transpose.cu index 72a1621763..f21014866b 100644 --- a/transformer_engine/common/transpose/rtc/transpose.cu +++ b/transformer_engine/common/transpose/rtc/transpose.cu @@ -22,9 +22,12 @@ constexpr size_t block_size = __BLOCK_SIZE__; __global__ void __launch_bounds__(block_size) transpose_optimized_kernel(const Type * __restrict__ const input, + const float * const noop, Type * __restrict__ const output, const size_t row_length, const size_t num_rows) { + if (noop != nullptr && noop[0] == 1.0f) return; + // Vectorized load/store sizes constexpr size_t nvec_in = load_size / sizeof(Type); constexpr size_t nvec_out = store_size / sizeof(Type); diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index f1b8d7a228..3ab83b944b 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ b/transformer_engine/common/transpose/transpose.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. 
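The new kernels above gate all of their work on a single device-side float: a null pointer or a value of 0.0 means the transpose runs, while 1.0 makes the launch return immediately, so a kernel recorded into a CUDA graph can become a no-op at replay time without re-capturing. Below is a host-side Python emulation of that contract, purely to illustrate the semantics; the real decision is made on the device, not with .item().

import torch

def transpose_with_noop_reference(inp, out, noop):
    """Mimics the kernels' early return when the noop flag reads 1.0."""
    if noop.numel() == 1 and float(noop.item()) == 1.0:
        return            # leave `out` untouched
    out.copy_(inp.t())

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
y = torch.empty(3, 2)
transpose_with_noop_reference(x, y, torch.zeros(1))  # flag 0.0: transpose is performed
transpose_with_noop_reference(x, y, torch.ones(1))   # flag 1.0: call is a no-op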
************************************************************************/ +#include #include #include #include @@ -30,9 +31,12 @@ template __global__ void __launch_bounds__(block_size) transpose_general_kernel(const Type * __restrict__ const input, + const fp32 * const noop, Type * __restrict__ const output, const size_t row_length, const size_t num_rows) { + if (noop != nullptr && noop[0] == 1.0f) return; + // Vectorized load/store sizes constexpr size_t nvec_in = load_size / sizeof(Type); constexpr size_t nvec_out = store_size / sizeof(Type); @@ -124,6 +128,7 @@ transpose_general_kernel(const Type * __restrict__ const input, } void transpose(const Tensor &input, + const Tensor &noop, Tensor *output_, cudaStream_t stream) { Tensor &output = *output_; @@ -140,6 +145,23 @@ void transpose(const Tensor &input, NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match."); + // Number of elements in tensor + auto numel = [] (const Tensor &tensor) -> size_t { + size_t acc = 1; + for (const auto& dim : tensor.data.shape) { + acc *= dim; + } + return acc; + }; + + if (noop.data.dptr != nullptr) { + NVTE_CHECK(numel(noop) == 1, + "Expected 1 element, ", + "but found ", numel(noop), "."); + NVTE_CHECK(noop.data.dtype == DType::kFloat32); + NVTE_CHECK(noop.data.dptr != nullptr); + } + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.data.dtype, Type, constexpr const char *type_name = TypeInfo::name; constexpr size_t type_size = sizeof(Type); @@ -239,6 +261,7 @@ void transpose(const Tensor &input, rtc_manager.launch(kernel_label, num_blocks(load_size, store_size), block_size, 0, stream, static_cast(input.data.dptr), + static_cast(noop.data.dptr), static_cast(output.data.dptr), row_length, num_rows); } else { // Statically-compiled general kernel @@ -250,6 +273,7 @@ void transpose(const Tensor &input, * DIVUP(num_rows, col_tile_size)); transpose_general_kernel<<>>( static_cast(input.data.dptr), + static_cast(noop.data.dptr), static_cast(output.data.dptr), row_length, num_rows); } @@ -263,7 +287,22 @@ void nvte_transpose(const NVTETensor input, cudaStream_t stream) { NVTE_API_CALL(nvte_transpose); using namespace transformer_engine; + auto noop = Tensor(); + transpose(*reinterpret_cast(input), + noop, + reinterpret_cast(output), + stream); +} + + +void nvte_transpose_with_noop(const NVTETensor input, + const NVTETensor noop, + NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_transpose_with_noop); + using namespace transformer_engine; transpose(*reinterpret_cast(input), + *reinterpret_cast(noop), reinterpret_cast(output), stream); } diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index e3abfa00fc..4c513339a0 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -14,6 +14,7 @@ from .transformer import TransformerLayer from .fp8 import fp8_autocast from .fp8 import fp8_model_init +from .graph import make_graphed_callables from .export import onnx_export from .distributed import checkpoint from .distributed import CudaRNGStatesTracker diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f03350eb4e..f57b58d736 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -52,9 +52,14 @@ get_distributed_world_size, get_distributed_rank, checkpoint, + set_all_rng_states, + CudaRNGStatesTracker, + graph_safe_rng_available, ) from transformer_engine.pytorch.export import is_in_onnx_export_mode from 
transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo +from transformer_engine.pytorch.graph import is_graph_capturing + _flash_attn_version = packaging.version.Version(version("flash-attn")) _flash_attn_version_required = packaging.version.Version("2.0.6") @@ -2401,10 +2406,13 @@ def __init__( assert (num_attention_heads % self.num_gqa_groups == 0 ), "The number of attention heads must be divisible by the number of GQA groups!" + self.rng_states_tracker = None if sequence_parallel or get_rng_state_tracker is None: attention_dropout_ctx = nullcontext else: - attention_dropout_ctx = get_rng_state_tracker().fork + self.rng_states_tracker = get_rng_state_tracker() + set_all_rng_states(self.rng_states_tracker.get_states()) + attention_dropout_ctx = self.rng_states_tracker.fork norm_factor = math.sqrt(self.hidden_size_per_attention_head) @@ -2648,6 +2656,14 @@ def forward( assert (attn_mask_type in AttnMaskTypes ), f"Attention mask type {attn_mask_type} is not supported!" + if self.rng_states_tracker is not None and is_graph_capturing(): + assert ( + isinstance(self.rng_states_tracker, CudaRNGStatesTracker) + ), "Unsupported RNG states tracker." + assert ( + graph_safe_rng_available() + ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." + if window_size is None: window_size = self.window_size @@ -3695,7 +3711,8 @@ def forward( # =================== projection_output = self.proj( - context_layer, is_first_microbatch=is_first_microbatch + context_layer, + is_first_microbatch=is_first_microbatch, ) if self.return_bias: diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index ce18dffca0..3671f2e064 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -22,19 +22,26 @@ def fp8_cast_transpose_fused( otype: tex.DType, cast_out: Optional[torch.Tensor] = None, transpose_out: Optional[torch.Tensor] = None, + noop_flag: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor, torch.Tensor], None]: """Cast + Transpose with FP8 output""" return_outputs = False - if cast_out is None or transpose_out is None: - cast_out = torch.empty_like(inp, dtype=torch.uint8) + if transpose_out is None: transpose_out = torch.empty( inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8 ) return_outputs = True + if cast_out is None: + cast_out = torch.empty_like(inp, dtype=torch.uint8) + return_outputs = True + + if noop_flag is None: + noop_flag = torch.Tensor() - tex.fused_cast_transpose( + tex.fused_cast_transpose_noop( inp, + noop_flag, fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.scale_inv[fp8_tensor], diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 4e3daf7512..3c039b9a88 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -157,7 +157,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); @@ -238,13 +238,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { int ori_sms = 
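Two conventions in the updated wrapper above are worth spelling out: an empty torch.Tensor() stands in for "no flag" (its null data pointer is what the C++ side treats as a missing no-op flag), and the cast/transpose outputs can be preallocated so the call performs no allocations once it has been captured into a CUDA graph. A short sketch; the commented call only mirrors the keyword names of the wrapper above and is not meant to run as written.

import torch

null_flag = torch.Tensor()                 # numel() == 0 -> interpreted as "no flag"
assert null_flag.numel() == 0

# Persistent buffers reused across graph replays (shapes are illustrative):
# cast_out      = torch.empty_like(inp, dtype=torch.uint8)
# transpose_out = torch.empty(inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8)
# noop_flag     = torch.zeros(1, dtype=torch.float32, device="cuda")
# fp8_cast_transpose_fused(inp, fp8_meta_tensor, fp8_tensor, otype,
#                          cast_out=cast_out, transpose_out=transpose_out,
#                          noop_flag=noop_flag)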
_ub_comm->sms; // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); - for (int i = 0; i < _stream_compute.size(); i++) { - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _stop_comm, 0)); - } + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; @@ -350,11 +347,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { int ori_sms = _ub_comm->sms; // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); for (int i = 0; i < _stream_compute.size(); i++) { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); } + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; @@ -469,13 +467,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); } } + for (int i = 0; i < _stream_compute.size(); i++) { + CHECK_CUDA( + cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + } _ub_comm->sms = ori_sms; - int last_compute_stream_id = - (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); - CHECK_CUDA( - cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id])); CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); at::cuda::setCurrentCUDAStream(stream_main); @@ -506,7 +504,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } } - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(), @@ -805,14 +803,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); + for (int i = 0; i < _stream_compute.size(); i++) { + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); + } if (_aggregate2) { - // Catch up the default torch stream - 
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - const int num_steps = _tp_size / 2; char *input_b_ptr = reinterpret_cast(_ubuf.data_ptr()); @@ -877,21 +876,9 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); } } - at::cuda::setCurrentCUDAStream(stream_main); - int last_compute_stream_id = - (num_steps + _stream_compute.size() - 1) % _stream_compute.size(); - CHECK_CUDA( - cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id])); } else { - // Catch up the default torch stream - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); - for (int i = 0; i < _tp_size; i++) { // Set the userbuffer id. Buffer under send is the input for the current // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to @@ -936,16 +923,19 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); } } - at::cuda::setCurrentCUDAStream(stream_main); - int last_compute_stream_id = (_tp_size + _stream_compute.size() - 1) % _stream_compute.size(); + } + for (int i = 0; i < _stream_compute.size(); i++) { CHECK_CUDA( - cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id])); + cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); } - CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); + CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); + at::cuda::setCurrentCUDAStream(stream_main); return D; } // split_overlap_ag diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 4096280d17..f6d6bad57f 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -43,6 +43,7 @@ #include #include #include +#include namespace transformer_engine { diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d3872c5b75..0887054665 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -223,6 +223,17 @@ void fused_cast_transpose(at::Tensor input, ); +void fused_cast_transpose_noop(at::Tensor input, + at::Tensor noop, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + at::Tensor input_cast, + at::Tensor 
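The userbuffer changes above also switch every sync point from the default stream to the current stream and funnel all producer/consumer ordering through events, which keeps the ordering valid when CUDA graph capture is running on a non-default stream. The same record/wait idiom in plain PyTorch, for orientation only:

import torch

main = torch.cuda.current_stream()         # not necessarily the default stream
side = torch.cuda.Stream()
start, stop = torch.cuda.Event(), torch.cuda.Event()

start.record(main)                         # "catch up" point on the main stream
side.wait_event(start)                     # side work starts only after `start`
with torch.cuda.stream(side):
    pass                                   # communication / chunked GEMM work goes here
stop.record(side)
main.wait_event(stop)                      # main stream resumes after the side work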
input_transpose, + transformer_engine::DType otype +); + + std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, at::Tensor amax, @@ -263,6 +274,17 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype ); +void fp8_transpose_noalloc(at::Tensor input, + at::Tensor output, + transformer_engine::DType otype +); + +void fp8_transpose_noalloc_noop(at::Tensor input, + at::Tensor output, + at::Tensor noop, + transformer_engine::DType otype +); + /*************************************************************************************************** * Activations **************************************************************************************************/ @@ -559,16 +581,13 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads * FP8 recipe **************************************************************************************************/ -void fused_amax_and_scale_update(const at::Tensor &amax_history, - const at::Tensor &scale, - const at::Tensor &scale_inv, - const at::Tensor &scale_inv_mask, - at::Tensor updated_amax_history, - at::Tensor updated_scale, - at::Tensor updated_scale_inv, - const std::string& amax_compute_algo, - transformer_engine::DType fp8_dtype, - float margin); +void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + const std::string &amax_compute_algo, + transformer_engine::DType fp8_dtype, + float margin); /*************************************************************************************************** * Rotary positional embedding diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 328bf1dcb4..4a7d51cada 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -42,6 +42,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD"); m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD"); m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose"); + m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop, + "Fused Cast + Transpose with noop option"); m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD"); m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, @@ -67,6 +69,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_attn_bwd", &fused_attn_bwd, "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); + m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O"); + m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop, + "Transpose with FP8 I/O with noop option."); m.def("gelu", &gelu, "GeLU with FP8 output"); m.def("relu", &relu, "ReLU with FP8 output"); m.def("geglu", &geglu, "GeGLU with FP8 output"); @@ -82,9 +87,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention"); m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention"); m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); - m.def("fused_amax_and_scale_update", - &fused_amax_and_scale_update, - "Update amax history and FP8 scale"); + m.def("fused_amax_and_scale_update_after_reduction", + 
&fused_amax_and_scale_update_after_reduction, + "Update amax history and FP8 scale/scale_inv after reduction"); // fused apply rope m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD"); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cu b/transformer_engine/pytorch/csrc/extensions/recipe.cu index f97d24a011..d5d8e2f7c8 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cu +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cu @@ -11,24 +11,50 @@ #include #include -void fused_amax_and_scale_update(const at::Tensor &amax_history, - const at::Tensor &scale, - const at::Tensor &scale_inv, - const at::Tensor &scale_inv_mask, - at::Tensor updated_amax_history, - at::Tensor updated_scale, - at::Tensor updated_scale_inv, - const std::string& amax_compute_algo, - transformer_engine::DType fp8_dtype, - float margin) { - nvte_delayed_scaling_recipe_amax_and_scale_update( - makeTransformerEngineTensor(amax_history).data(), - makeTransformerEngineTensor(scale).data(), - makeTransformerEngineTensor(scale_inv).data(), - makeTransformerEngineTensor(scale_inv_mask).data(), - makeTransformerEngineTensor(updated_amax_history).data(), - makeTransformerEngineTensor(updated_scale).data(), - makeTransformerEngineTensor(updated_scale_inv).data(), + +void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + const std::string &amax_compute_algo, + transformer_engine::DType fp8_dtype, + float margin) { + using namespace transformer_engine; + size_t num_tensors = amax_histories.size(); + std::vector t_amax_histories(num_tensors); + std::vector t_scales(num_tensors); + std::vector t_scale_invs(num_tensors); + std::vector te_amax_histories(num_tensors); + std::vector te_scales(num_tensors); + std::vector te_scale_invs(num_tensors); + for (size_t i = 0; i < num_tensors; i++) { + t_amax_histories[i].data.dptr = amax_histories[i].data_ptr(); + auto amax_sizes = amax_histories[i].sizes().vec(); + std::vector amax_shape{amax_sizes.begin(), amax_sizes.end()}; + t_amax_histories[i].data.shape = amax_shape; + t_amax_histories[i].data.dtype = DType::kFloat32; + + t_scales[i].data.dptr = scales[i].data_ptr(); + auto scale_sizes = scales[i].sizes().vec(); + std::vector scale_shape{scale_sizes.begin(), scale_sizes.end()}; + t_scales[i].data.shape = scale_shape; + t_scales[i].data.dtype = DType::kFloat32; + + t_scale_invs[i].data.dptr = scale_invs[i].data_ptr(); + auto scale_inv_sizes = scale_invs[i].sizes().vec(); + std::vector scale_inv_shape{scale_inv_sizes.begin(), scale_inv_sizes.end()}; + t_scale_invs[i].data.shape = scale_inv_shape; + t_scale_invs[i].data.dtype = DType::kFloat32; + + te_amax_histories[i] = reinterpret_cast(&t_amax_histories[i]); + te_scales[i] = reinterpret_cast(&t_scales[i]); + te_scale_invs[i] = reinterpret_cast(&t_scale_invs[i]); + } + nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( + makeTransformerEngineTensor(amax_reduction_buffer).data(), + te_amax_histories, + te_scales, + te_scale_invs, amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu index 038e82d955..fc178adeb4 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -32,6 +32,35 @@ void fused_cast_transpose(at::Tensor input, } +void 
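A hypothetical Python-side driver for the binding registered above, shown only to make the argument order concrete: reduction buffer, then the amax-history/scale/scale_inv lists, then the algorithm string, FP8 dtype, and margin. The tensor shapes, the extension import name, and the enum member are assumptions of this sketch, not part of this patch.

import torch
import transformer_engine_extensions as tex   # extension import name may differ

amax_histories = [torch.zeros(16, 3, device="cuda") for _ in range(2)]  # [history, num_scales]
scales         = [torch.ones(3, device="cuda") for _ in range(2)]
scale_invs     = [torch.ones(3, device="cuda") for _ in range(2)]
reduction_buf  = torch.rand(6, device="cuda")  # concatenated newest amaxes after all-reduce

tex.fused_amax_and_scale_update_after_reduction(
    reduction_buf, amax_histories, scales, scale_invs,
    "max", tex.DType.kFloat8E4M3, 0.0)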
fused_cast_transpose_noop(at::Tensor input, + at::Tensor noop, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + at::Tensor input_cast, + at::Tensor input_transpose, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t M = static_cast(input.size(0)); + size_t N = static_cast(input.size(1)); + + auto input_cu = makeTransformerEngineTensor(input); + auto noop_cu = makeTransformerEngineTensor(noop); + auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + + nvte_cast_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cast_cu.data(), + output_transpose_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + + std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, at::Tensor amax, @@ -319,3 +348,39 @@ at::Tensor fp8_transpose(at::Tensor input, return output; } + + +void fp8_transpose_noalloc(at::Tensor input, + at::Tensor output, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t M = static_cast(input.size(0)); + size_t N = static_cast(input.size(1)); + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); + + nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); +} + + +void fp8_transpose_noalloc_noop(at::Tensor input, + at::Tensor output, + at::Tensor noop, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t M = static_cast(input.size(0)); + size_t N = static_cast(input.size(1)); + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); + auto noop_cu = makeTransformerEngineTensor(noop); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); + + nvte_transpose_with_noop( + input_cu.data(), noop_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); +} diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 239cecf39b..8d499d88d6 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -5,10 +5,10 @@ """Methods needed for distributed training (DP/TP).""" import warnings from contextlib import contextmanager, AbstractContextManager, ContextDecorator -from typing import Any, Dict, Union, Optional, Callable, Tuple +from typing import Any, Dict, List, Union, Optional, Callable, Tuple import torch -from torch.cuda import _lazy_call +from torch.cuda import _lazy_call, _lazy_init from torch.utils.checkpoint import detach_variable, noop_context_fn from .utils import safely_set_viewless_tensor_data @@ -31,15 +31,60 @@ _FP8_ACTIVATION_RECOMPUTE_PHASE = False -def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -> None: - """Sets the random number generator state of the current GPU. 
+_ALL_ACTIVE_RNG_STATES = {} + + +def get_all_rng_states() -> bool: + """Returns all generator states used by `CudaRNGStatesTracker`.""" + return _ALL_ACTIVE_RNG_STATES + + +def set_all_rng_states(states: List) -> None: + """Updates all generator states used by `CudaRNGStatesTracker`.""" + global _ALL_ACTIVE_RNG_STATES + _ALL_ACTIVE_RNG_STATES = states + + +def graph_safe_rng_available() -> bool: + """Returns whether cuda graph safe RNG state manipulation is supported.""" + return (hasattr(torch.cuda.CUDAGraph, "register_generator_state") + and hasattr(torch.Generator, "graphsafe_set_state") + and hasattr(torch.Generator, "graphsafe_get_state") + and hasattr(torch.Generator, "clone_state")) + + +def _get_cuda_rng_state( + device: Union[int, str, torch.device] = "cuda", + clone: bool = False, + graph_safe: bool = True, +) -> torch.Tensor: + """Return the random number generator state of the specified GPU.""" + + _lazy_init() + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("cuda", device) + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + default_generator = torch.cuda.default_generators[idx] + if graph_safe_rng_available() and graph_safe: + if clone: + # Reference to the cloned generator state + return default_generator.clone_state() + # Reference to the current generator state + return default_generator.graphsafe_get_state() + return default_generator.get_state() + + +def _set_cuda_rng_state( + new_state: torch.Tensor, + device: Union[int, str] = -1, + graph_safe = True, +) -> None: + """Sets the random number generator state of the current GPU.""" - Arguments: - new_state (torch.ByteTensor): The desired state - This function is adapted from PyTorch repo (torch.cuda.set_rng_state) - with a single change: the input state is not cloned. Cloning caused - major performance issues for +4 GPU cases. - """ if device == -1: device = torch.device("cuda") elif isinstance(device, str): @@ -52,6 +97,9 @@ def cb() -> None: if idx is None: idx = torch.cuda.current_device() default_generator = torch.cuda.default_generators[idx] + if graph_safe_rng_available() and graph_safe: + default_generator.graphsafe_set_state(new_state) + return default_generator.set_state(new_state) _lazy_call(cb) @@ -206,7 +254,7 @@ def forward( # Copy the rng states. ctx.fwd_cpu_rng_state = torch.get_rng_state() - ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() + ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False) if get_rng_state_tracker is not None: ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() @@ -271,13 +319,13 @@ def backward( # Store the current states. bwd_cpu_rng_state = torch.get_rng_state() - bwd_cuda_rng_state = torch.cuda.get_rng_state() + bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False) if get_rng_state_tracker is not None: bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states() # Set the states to what it used to be before the forward pass. torch.set_rng_state(ctx.fwd_cpu_rng_state) - _set_cuda_rng_state(ctx.fwd_cuda_rng_state) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False) if get_rng_state_tracker is not None: get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) @@ -291,7 +339,7 @@ def backward( # Set the states back to what it was at the start of this function. 
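The helpers added above publish every named generator state to a global dictionary so that graph capture can later register them with the CUDA graph, and they fall back to the legacy ByteTensor snapshots when the graph-safe generator API is missing. A small usage sketch of the API introduced here; it needs a CUDA build of Transformer Engine, and the seed and state name are arbitrary.

import torch
from transformer_engine.pytorch.distributed import (
    CudaRNGStatesTracker, get_all_rng_states, graph_safe_rng_available)

tracker = CudaRNGStatesTracker()
tracker.add("model-parallel-rng", seed=1234)         # clones a generator state when graph-safe
with tracker.fork("model-parallel-rng"):
    sample = torch.rand(8, device="cuda")            # dropout-style randomness under that state
assert "model-parallel-rng" in get_all_rng_states()  # states are also published globally
print("graph-safe RNG available:", graph_safe_rng_available())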
torch.set_rng_state(bwd_cpu_rng_state) - _set_cuda_rng_state(bwd_cuda_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False) if get_rng_state_tracker is not None: get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker) @@ -317,6 +365,7 @@ def backward( ) return (None, None, None, None, None, None) + grads + class _CheckpointFrame: """ Storage frame for forward RNG states and detached activations from the forward recompute. @@ -338,7 +387,7 @@ def cache_rng_states(self, forward=True): """Cache fwd/bwd RNG states in the frame to restore later.""" rng_states = ( torch.get_rng_state(), - torch.cuda.get_rng_state(), + _get_cuda_rng_state(graph_safe=False), ) if self.get_rng_state_tracker is not None: rng_states += (self.get_rng_state_tracker().get_states(), ) @@ -356,7 +405,7 @@ def restore_rng_states(self, forward=True): rng_states = self.bwd_rng_states torch.set_rng_state(rng_states[0]) - _set_cuda_rng_state(rng_states[1]) + _set_cuda_rng_state(rng_states[1], graph_safe=False) if self.get_rng_state_tracker is not None: self.get_rng_state_tracker().set_states(rng_states[2]) @@ -604,6 +653,7 @@ def recompute_fn(*args, **kwargs): return out + class CudaRNGStatesTracker: """ For model parallelism, multiple RNG states need to simultaneously exist in order @@ -664,13 +714,23 @@ def add(self, name: str, seed: int) -> None: # Check that state is not already defined. if name in self.states_: raise Exception(f"cuda rng state {name} already exists") - # Get the current rng state. - orig_rng_state = torch.cuda.get_rng_state() - # Set the new state and store it. - torch.cuda.manual_seed(seed) - self.states_[name] = torch.cuda.get_rng_state() - # Reset rng state to what it was. - _set_cuda_rng_state(orig_rng_state) + + if graph_safe_rng_available(): + new_state = _get_cuda_rng_state(clone=True) + new_state.manual_seed(seed) + self.states_[name] = new_state + # Update global states. + set_all_rng_states(self.states_) + else: + # Get the current rng state. + orig_rng_state = _get_cuda_rng_state() + # Set the new state and store it. + torch.cuda.manual_seed(seed) + self.states_[name] = _get_cuda_rng_state(clone=True) + # Reset rng state to what it was. + _set_cuda_rng_state(orig_rng_state) + # Update global states. + set_all_rng_states(self.states_) @contextmanager def fork(self, name: str = "model-parallel-rng"): @@ -684,16 +744,17 @@ def fork(self, name: str = "model-parallel-rng"): # Check if we have added the state if name not in self.states_: raise Exception(f"cuda rng state {name} is not added") - # Store current rng state. - orig_cuda_rng_state = torch.cuda.get_rng_state() + # Get the reference to current rng state. + orig_cuda_rng_state = _get_cuda_rng_state() # Set rng state to the desired one _set_cuda_rng_state(self.states_[name]) # Do the stuff we wanted to do. try: yield finally: - # Update the current rng state for later use. - self.states_[name] = torch.cuda.get_rng_state() + # this is redundant with graph-safe API + if not graph_safe_rng_available(): + self.states_[name] = _get_cuda_rng_state() # And set the state to the original state we started with. 
_set_cuda_rng_state(orig_cuda_rng_state) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 8092d2fccd..9923d24a42 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -16,6 +16,7 @@ aten = torch.ops.aten c10d = torch.ops.c10d +updated_fp8_params = {} def _make_fp8_attr_property_funcs(name: str) -> Any: @@ -67,6 +68,31 @@ def backward(ctx, grad): return grad, None +def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: + """Amax scale and update when there is at least 1 trainable FP8 parameter.""" + param_id = id(param._data) + + if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: + return + + autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] + + if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: + return + + if autocast_key in updated_fp8_params: + updated_fp8_params[autocast_key].add(param_id) + else: + updated_fp8_params[autocast_key] = {param_id} + + current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] + # All FP8 trainable parameters have been updated. + if updated_fp8_params[autocast_key] == current_fp8_params_set: + FP8GlobalStateManager.reduce_and_update_fp8_tensors( + forward=True, fp8_weights=True) + del updated_fp8_params[autocast_key] + + class _ToFloat8Func(torch.autograd.Function): """Cast to FP8 from other dtype""" @staticmethod @@ -167,6 +193,7 @@ def backward(ctx, grad): # Assume that we want gradients in full precision return grad, None, None, None, None, None, None, None + class _IdentityFunc(torch.autograd.Function): """Identity function @@ -307,8 +334,9 @@ def __new__( ), f"Unsupported fp8_dtype {fp8_dtype}." self._fp8_dtype: tex.DType = fp8_dtype - # Cached transpose + # Transposed version of `_data`. self._transpose: Optional[Float8Tensor] = None + self._transpose_invalid: bool = True # FP8 scale-inverse self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv @@ -435,80 +463,51 @@ def expand_as(self, other: torch.Tensor): return _IdentityFunc.apply(self) return super().expand_as(other) - def transpose( + def transpose_2d( self, - dim0: int = 0, - dim1: int = 1, *, - update_cache: str | bool = "reuse_only", + cache: bool = False, + noop_flag: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - Swap tensor dimensions - - For basic 2D matrix transposes, an optimized transpose kernel - is applied and a Float8Tensor is returned. + 2D transpose with caching support. Parameters ---------- - dim0: int, default = 0 - The first dimension to be transposed - dim1: int, default = 1 - The second dimension to be transposed - update_cache: str or bool, default = "reuse_only" - Memoization behavior. Options are - "reuse_only"/`False` (reuse cached value if - available, otherwise calculate transpose without - caching), "force"/`True` (calculate transpose - and cache), "lazy" (reuse cached value if - available, otherwise calculate transpose and - cache if possible). Caching is only supported - for basic 2D transposes and the cache is reset - after any in-place operations. - + cache: bool, default = `False` + Whether or not to cache the transpose. + noop_flag: Optional[torch.Tensor], default = `None` + Only used if argument `cache` is `True`, ignored otherwise. + A single element fp32 tensor with a value of 1.0 or 0.0 + which is treated as a boolean. `1.0` forces recompute + and `0.0` executes a noop using the same kernel. 
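The optimizer-step hook above is essentially a barrier implemented by bookkeeping: every FP8 parameter that gets overwritten marks itself against its autocast region, and the forward amax reduction fires exactly once, when the set of updated parameters matches the set registered for that region. A stripped-down sketch of the same pattern with illustrative names; the real state lives on FP8GlobalStateManager.

registered = {"autocast0": {"w1", "w2", "w3"}}      # filled while the autocast region runs
updated = {}

def mark_updated(autocast_key, param_id, on_complete):
    updated.setdefault(autocast_key, set()).add(param_id)
    if updated[autocast_key] == registered[autocast_key]:
        on_complete()                               # reduce amaxes / refresh scales once
        del updated[autocast_key]

for p in ("w1", "w2", "w3"):
    mark_updated("autocast0", p, on_complete=lambda: print("reduce and update fp8 tensors"))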
""" + assert self.dim() == 2, f"{self.dim()}-D transpose not supported." - # Check caching mode - if not isinstance(update_cache, str): - update_cache = "force" if update_cache else "reuse_only" - if update_cache not in ("force", "reuse_only", "lazy"): - raise ValueError( - "Supported values for update_cache are " - '"force" (True), "reuse_only" (False), "lazy" ' - f"(got {update_cache})" - ) + # Case: no caching. + if not cache: + return tex.fp8_transpose(self._data, self._fp8_dtype) - # Handle non-2D transposes - if -self.dim() <= dim0 < 0: - dim0 += self.dim() - if -self.dim() <= dim1 < 0: - dim1 += self.dim() - if self.dim() != 2 or dim0 == dim1: - if update_cache == "force": - raise ValueError( - "Transpose caching is only supported for basic 2D transposes " - f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})" - ) - return super().transpose(dim0, dim1) - - # Clear cache if needed - if update_cache == "force": - self._transpose = None - - # Compute transpose if needed - out = self._transpose - if out is None: - out = Float8Tensor.make_like( - self, - data=tex.fp8_transpose( - self._data.contiguous(), - self._fp8_dtype, - ), - ) + # Case: reuse cache without calling a kernel. + if not self._transpose_invalid and noop_flag is None: + assert self._transpose is not None, "Tranpose cache is empty." + return self._transpose - # Update cache if needed - if update_cache in ("force", "lazy"): - self._transpose = out - return out + # Allocate transpose if needed. + data_2d = self._data.reshape(-1, self._data.shape[-1]) + if self._transpose is None: + shape = (data_2d.shape[1], data_2d.shape[0]) + self._transpose = torch.empty(shape, dtype=torch.uint8, device=self._data.device) + + # Case: recompute transpose and store cache. + if noop_flag is None: + tex.fp8_transpose_noalloc(data_2d, self._transpose, self._fp8_dtype) + else: + # Case: cuda graph capture. + tex.fp8_transpose_noalloc_noop(data_2d, self._transpose, noop_flag, self._fp8_dtype) + + self._transpose_invalid = False + return self._transpose @torch.no_grad() def reset_fp8_meta_scale_inv(self) -> None: @@ -519,13 +518,11 @@ def reset_fp8_meta_scale_inv(self) -> None: the tensor. """ - if self._fp8_meta is None: - return + assert self._fp8_meta is not None, "FP8 meta tensors not found." fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( forward=self._fp8_meta_forward, ) - scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] - scale_inv.view(1).copy_(self._scale_inv.view(1)) + self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0]) def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: """Create `Float8Tensor` with given nominal dtype @@ -541,12 +538,11 @@ def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: ) def _reset_caches(self) -> None: - """Reset cached values - + """ + Set transpose cache as invalid. Should be called after any in-place operation. 
- """ - self._transpose = None + self._transpose_invalid = True @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -574,7 +570,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Directly copy FP8 data if possible if dst._fp8_dtype == src._fp8_dtype: dst._data.copy_(src._data) - dst._scale_inv = src._scale_inv.clone() + dst._scale_inv.copy_(src._scale_inv.detach()) if dst._fp8_meta is not None: if src._fp8_meta is None: src_min, src_max = src.from_float8().aminmax() @@ -600,7 +596,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst.copy_(src.from_float8()) elif dst_is_fp8 and not src_is_fp8: - # Make sure input is in expected format src = src.expand(dst.size()) src = src.to( @@ -619,7 +614,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_meta_index = dst._fp8_meta_index scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index] amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - dst._scale_inv = scale.detach().view(1).reciprocal() + dst._scale_inv.copy_(scale.detach().reciprocal()) # Cast to FP8 if not dst._data.is_contiguous(): @@ -633,6 +628,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst._fp8_dtype, ) + # This branch is where the FP8 parameters are updated in-place during optimization. + # Handle forward amax reduction. + post_optimizer_step_fwd_amax_reduction(dst) else: # Invalid case @@ -641,6 +639,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Nothing to return for in-place ops if dst_is_fp8: dst._reset_caches() + return None # Slice op @@ -764,6 +763,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) _transpose = property(**_make_fp8_attr_property_funcs("transpose")) + _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) # Do not force the Float8Tensor type on the returned tensor diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index bbeea13af3..e821bfe11d 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -51,6 +51,17 @@ def get_fp8_te_dtype( return tex.DType.kFloat8E5M2 +def get_fp8_max( + fp8_recipe: DelayedScaling, fprop_tensor: bool = True +) -> tex.DType: + """Get max representible FP8 value.""" + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): + return Format.E4M3.value.max_fwd + return Format.E5M2.value.max_fwd + + class FP8GlobalStateManager: """Class to keep track of and manipulate the global FP8 state at different stages of execution. 
@@ -61,20 +72,21 @@ class FP8GlobalStateManager: FP8_DISTRIBUTED_GROUP = None FP8_PARAMETERS = False IS_FIRST_FP8_MODULE = False - FP8_AUTOCAST_COUNTER = 0 - FP8_CURRENT_CONTEXT_ID = 0 + FP8_GRAPH_CAPTURING = False FP8_AUTOCAST_DEPTH = 0 - global_fp8_buffer = {} + global_amax_buffer = {} + global_amax_history_buffer = {} + global_scale_buffer = {} + global_scale_inv_buffer = {} fp8_tensors_recompute_buffer = [] - amax_forward_global_reduce_func = None - buffer_delete_key_fwd = None - buffer_delete_key_bwd = None - amax_reduce_handle_fwd = None fp8_available = None reason_for_no_fp8 = "" - dp_amax_reduce_interval = None - dp_amax_reduce_forward_idx = 0 - dp_amax_reduce_backward_idx = 0 + multi_grad_hook_tensors = [] + bwd_amax_update_hook_registered = False + autocast_arguments = {} + autocast_to_fp8_params = {} + fp8_param_to_autocast = {} + skip_fp8_weight_update_tensor = None @classmethod def reset(cls) -> None: @@ -83,21 +95,35 @@ def reset(cls) -> None: cls.FP8_CALIBRATION = False cls.FP8_RECIPE = None cls.FP8_DISTRIBUTED_GROUP = None + cls.FP8_PARAMETERS = False cls.IS_FIRST_FP8_MODULE = False - cls.FP8_AUTOCAST_COUNTER = 0 - cls.FP8_CURRENT_CONTEXT_ID = 0 + cls.FP8_GRAPH_CAPTURING = False cls.FP8_AUTOCAST_DEPTH = 0 - cls.global_fp8_buffer = {} + cls.global_amax_buffer = {} + cls.global_amax_history_buffer = {} + cls.global_scale_buffer = {} + cls.global_scale_inv_buffer = {} cls.fp8_tensors_recompute_buffer = [] - cls.amax_forward_global_reduce_func = None - cls.buffer_delete_key_fwd = None - cls.buffer_delete_key_bwd = None - cls.amax_reduce_handle_fwd = None cls.fp8_available = None cls.reason_for_no_fp8 = "" - cls.dp_amax_reduce_interval = None - cls.dp_amax_reduce_forward_idx = 0 - cls.dp_amax_reduce_backward_idx = 0 + cls.multi_grad_hook_tensors = [] + cls.bwd_amax_update_hook_registered = False + cls.autocast_arguments = {} + cls.autocast_to_fp8_params = {} + cls.fp8_param_to_autocast = {} + cls.skip_fp8_weight_update_tensor = None + + @classmethod + def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: + """`skip_fp8_weight_update_tensor` inplace setter.""" + if cls.skip_fp8_weight_update_tensor is None: + cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda") + cls.skip_fp8_weight_update_tensor.fill_(skip) + + @classmethod + def get_skip_fp8_weight_update_tensor(cls) -> None: + """`skip_fp8_weight_update_tensor` getter.""" + return cls.skip_fp8_weight_update_tensor @classmethod def is_fp8_available(cls) -> Tuple[bool, str]: @@ -106,44 +132,6 @@ def is_fp8_available(cls) -> Tuple[bool, str]: cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support() return cls.fp8_available, cls.reason_for_no_fp8 - @classmethod - def get_global_fp8_state_checkpoint(cls) -> Dict[str, Union[int, str]]: - """Returns global fp8 state variables.""" - # Convert attributes to dictionary to make future proof against - # changes in global state variables in order to make setting the - # checkpoint backwards compatible. 
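`set_skip_fp8_weight_update_tensor` above keeps the skip flag in a one-element float tensor on the GPU rather than in a Python bool, so the value can change between CUDA graph replays without re-capturing. A hedged sketch of that pattern in isolation (requires a CUDA device; the variable name is illustrative, and the precise no-op semantics live in the tex kernels that later consume this tensor as their `noop_flag`):

import torch

skip_flag = torch.empty(1, dtype=torch.float32, device="cuda")
skip_flag.fill_(True)    # 1.0: signal downstream kernels to skip the weight re-cast
skip_flag.fill_(False)   # 0.0: weights are re-cast on the next forward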
- global_fp8_state = {} - global_fp8_state["FP8_AUTOCAST_COUNTER"] = cls.FP8_AUTOCAST_COUNTER - global_fp8_state["FP8_CURRENT_CONTEXT_ID"] = cls.FP8_CURRENT_CONTEXT_ID - global_fp8_state["FP8_AUTOCAST_DEPTH"] = cls.FP8_AUTOCAST_DEPTH - global_fp8_state["buffer_delete_key_fwd"] = cls.buffer_delete_key_fwd - global_fp8_state["buffer_delete_key_bwd"] = cls.buffer_delete_key_bwd - global_fp8_state["dp_amax_reduce_interval"] = cls.dp_amax_reduce_interval - global_fp8_state["dp_amax_reduce_forward_idx"] = cls.dp_amax_reduce_forward_idx - global_fp8_state["dp_amax_reduce_backward_idx"] = cls.dp_amax_reduce_backward_idx - return global_fp8_state - - @classmethod - def set_global_fp8_state_checkpoint(cls, state: Dict[str, Union[int, str]]) -> None: - """Sets global fp8 state variables.""" - for k, v in state.items(): - if hasattr(cls, k): - setattr(cls, k, v) - - @classmethod - def get_global_fp8_buffer_checkpoint(cls) -> Dict[str, List[torch.Tensor]]: - """Returns global fp8 amax buffer.""" - return cls.global_fp8_buffer - - @classmethod - def set_global_fp8_buffer_checkpoint(cls, buffer: Dict[str, List[torch.Tensor]]) -> None: - """Sets global fp8 amax buffer.""" - # Map all tensors back to GPU. - for k, v in buffer.items(): - buffer[k] = [tensor.cuda() for tensor in v] - - cls.global_fp8_buffer = buffer - @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" @@ -152,121 +140,102 @@ def get_meta_tensor_key(forward: bool = True) -> str: return "scaling_bwd" @staticmethod - def get_buffer_position_key(forward: bool = True) -> str: - """Returns module position key in `fp8_meta`.""" - if forward: - return "global_fp8_buffer_pos_fwd" - return "global_fp8_buffer_pos_bwd" - - @staticmethod - def get_autocast_key(forward: bool = True) -> str: - """Returns module position key in `fp8_meta`.""" - if forward: - return "autocast_id_fwd" - return "autocast_id_bwd" - - @staticmethod - def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str: - """Return a key in `_global_fp8_buffer` for the AMAX storage.""" - if forward: - return f"FWD_AMAX_{fp8_meta['autocast_id_fwd']}" - return f"BWD_AMAX_{fp8_meta['autocast_id_bwd']}" + def get_fwd_bwd_key(forward: bool = True) -> str: + """Convert bool `forward` to string.""" + return "forward" if forward else "backward" @classmethod - def get_amax_reduce_handle_fwd(cls) -> Union[bool, None]: - """Return AMAX reduction wait handle of forward prop.""" - return cls.amax_reduce_handle_fwd + def get_buffer_info(cls) -> str: + """ + Returns a key for `fp8_meta` that stores the module's index + in the global buffers along with autocast information. 
+ """ + return "buffer_index_and_autocast_key" @classmethod - def setup_amax_forward_global_reduce_func(cls, f: Callable) -> None: - """Sets up the function to call during autocast exit.""" - cls.amax_forward_global_reduce_func = f + def get_key_in_buffer( + cls, + forward: bool, + fp8_weights: bool, + fp8_recipe: DelayedScaling, + fp8_group: dist_group_type, + ) -> str: + """Returns a key into the global FP8 buffers.""" + autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) + fwd_bwd_key = cls.get_fwd_bwd_key(forward) + return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}" @classmethod - def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = True) -> None: - """Append 1D tensor `amax` to global buffer.""" - buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward) - fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) - buffer_position_key = cls.get_buffer_position_key(forward=forward) - - if buffer_key not in cls.global_fp8_buffer: - cls.global_fp8_buffer[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - else: - cls.global_fp8_buffer[buffer_key].append( - fp8_meta[fp8_meta_tensor_key].amax_history[0] - ) - - if buffer_position_key not in fp8_meta: - fp8_meta[buffer_position_key] = len(cls.global_fp8_buffer[buffer_key]) - 1 - - # Catch incorrect fp8_autocast usage. - assert fp8_meta[buffer_position_key] == len(cls.global_fp8_buffer[buffer_key]) - 1, \ - "Same module is being invoked more than once inside an `fp8_autocast` " \ - "region when using FP8 with amax reduction. This behavior is currently" \ - " unsupported. For more details and correct usage, please see " \ - "https://github.com/NVIDIA/TransformerEngine/pull/93." + def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]: + """Splits buffer key into relevant parts.""" + forward, fp8_weights, autocast_key = key.split("_", 2) + forward = forward == "forward" + fp8_weights = fp8_weights == "True" + return forward, fp8_weights, autocast_key @classmethod - def copy_amax_from_global_buffer( - cls, fp8_meta: Dict[str, Any], forward: bool = True + def add_fp8_tensors_to_global_buffer( + cls, + fp8_meta: Dict[str, Any], + fp8_weights: Optional[List[torch.Tensor]] = None, ) -> None: - """Populate current amax with the correct location from buffer.""" - fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) - buffer_position_key = cls.get_buffer_position_key(forward=forward) - if buffer_position_key not in fp8_meta: - return - - amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward) - assert amax_buffer_key in cls.global_fp8_buffer, "TE internal error." - - fp8_meta[fp8_meta_tensor_key].amax_history[0] = cls.global_fp8_buffer[amax_buffer_key][ - fp8_meta[buffer_position_key] - ] + """ + The amax reduction process happens completely outside the FP8 modules. + To participate in the reduction, the only role played by a module is + to call this function in order to append it's FP8 tensor into a global + buffer. There are 5 global buffers maintained, one each for amax, amax + history, scale, scale-inverse, and non-weight-mask. Each buffer has + keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix + to indicate the type of FP8 tensor, since the forward and backward + reductions happen separately. + + Note: For CG capture, this method is called from the graphed + wrapper. For non CG case, it's called from within the module. 
+ """ - @classmethod - def set_amax_buffer_key_deletion( - cls, fp8_meta: Dict[str, Any], forward: bool = True - ) -> None: - """Delete this amax key from global buffer during autocast end.""" - if cls.get_autocast_key(forward=forward) not in fp8_meta: + # Every module must call this function exactly once since + # the amax tensors are static. Ensures that compatibility + # with non-graphed modules is maintained. + index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors. + if index_in_buffer in fp8_meta: return - if forward: - cls.buffer_delete_key_fwd = cls.get_amax_buffer_key(fp8_meta, forward=forward) - else: - cls.buffer_delete_key_bwd = cls.get_amax_buffer_key(fp8_meta, forward=forward) - - @classmethod - def delete_key_from_amax_buffer(cls, forward: bool = True) -> None: - """Delete the key from global amax buffer.""" - if forward: - if ( - cls.buffer_delete_key_fwd is not None - and cls.buffer_delete_key_fwd in cls.global_fp8_buffer - ): - del cls.global_fp8_buffer[cls.buffer_delete_key_fwd] - else: - if ( - cls.buffer_delete_key_bwd is not None - and cls.buffer_delete_key_bwd in cls.global_fp8_buffer - ): - del cls.global_fp8_buffer[cls.buffer_delete_key_bwd] - @classmethod - def get_fp8_context_id(cls) -> int: - """Returns an ID for the current FP8 context.""" - return cls.FP8_CURRENT_CONTEXT_ID - - @classmethod - def set_fp8_context_id(cls, ctx_id: int) -> None: - """Sets the current FP8 context.""" - cls.FP8_CURRENT_CONTEXT_ID = ctx_id - - @classmethod - def new_fp8_context_id(cls) -> int: - """Returns global autocast counter as a proxy to be used - as the autocast ID for FP8 modules. - """ - return cls.FP8_AUTOCAST_COUNTER + fp8_meta[index_in_buffer] = [] + for forward in (True, False): + # This algorithm creates a two-way map with `autocast_to_fp8_params` and + # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights + # in an autocasted region and cross reference them in `float8_tensor.py` + # to perform the forward amax reduction. + if forward and fp8_weights is not None: + autocast_key = cls.get_unique_autocast_key( + fp8_meta["recipe"], fp8_meta["fp8_group"]) + fp8_weight_set = {id(w._data) for w in fp8_weights} + if autocast_key not in cls.autocast_to_fp8_params: + cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set + else: + cls.autocast_to_fp8_params[autocast_key] = ( + cls.autocast_to_fp8_params[autocast_key].union(fp8_weight_set)) + # Identify correct autocast key for a given param. 
+ for w in fp8_weight_set: + cls.fp8_param_to_autocast[w] = autocast_key + + key = cls.get_key_in_buffer( + forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"]) + fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) + + if key not in cls.global_amax_buffer: + cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] + cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] + cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] + cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] + else: + cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) + cls.global_amax_history_buffer[key].append( + fp8_meta[fp8_meta_tensor_key].amax_history) + cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) + cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) + fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) + fp8_meta[index_in_buffer].append(key) @classmethod def is_fp8_enabled(cls) -> bool: @@ -283,6 +252,11 @@ def with_fp8_parameters(cls) -> bool: """Should the parameters be stored as FP8""" return cls.FP8_PARAMETERS + @classmethod + def fp8_graph_capturing(cls) -> bool: + """Is CUDA graph capture under way?""" + return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing() + @classmethod def is_first_fp8_module(cls): """Returns `True` only the first time when called multiple @@ -310,7 +284,8 @@ def get_fp8_autocast_state(cls) -> Tuple[bool, bool, DelayedScaling, dist_group_ cls.FP8_CALIBRATION, cls.FP8_RECIPE, cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE) + cls.IS_FIRST_FP8_MODULE, + cls.FP8_GRAPH_CAPTURING) @classmethod def set_fp8_autocast_state( @@ -322,80 +297,100 @@ def set_fp8_autocast_state( cls.FP8_CALIBRATION, cls.FP8_RECIPE, cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE) = fp8_state + cls.IS_FIRST_FP8_MODULE, + cls.FP8_GRAPH_CAPTURING) = fp8_state @staticmethod def reduce_tensor_across_group_op_max( - tensor: torch.Tensor, group: dist_group_type, async_op: bool + tensor: torch.Tensor, group: dist_group_type ) -> None: """Reduce tensor across given group.""" if torch.distributed.is_initialized(): - wait_handle = torch.distributed.all_reduce( + torch.distributed.all_reduce( tensor, op=torch.distributed.ReduceOp.MAX, group=group, - async_op=async_op, + async_op=False, ) - return wait_handle - return None @classmethod - def global_amax_reduction( + def reduce_and_update_fp8_tensors( cls, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, forward: bool = True, + fp8_weights: bool = False, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" - amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward) - - # Key already deleted. - if amax_buffer_key not in cls.global_fp8_buffer: - return None - - # Reduce AMAX in DP-domain at an interval. - # `NVTE_DP_AMAX_REDUCE_INTERVAL` should be set as an integer value larger than 0. If - # `NVTE_DP_AMAX_REDUCE_INTERVAL` is set to 0, AMAX is reduced only in TP domain. 
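For reference, the per-tensor bookkeeping that the new reduce-and-update path performs after the (optional) MAX all-reduce boils down to the delayed-scaling update below. The numbers are toy values, and the real helpers additionally guard against zero or non-finite amax:

import torch

amax_history = torch.tensor([[2.0], [8.0], [4.0]])   # [history_len, num_fp8_tensors]
scale = torch.ones(1)
scale_inv = torch.ones(1)
fp8_max, margin = 448.0, 0                           # from get_fp8_max() and recipe.margin

amax = amax_history.max(dim=0).values                # "max" amax_compute_algo
amax_history.copy_(torch.roll(amax_history, -1, 0))  # rotate history ...
amax_history[0].fill_(0.0)                           # ... and clear the slot for the next step
scale.copy_((fp8_max / amax) / (2 ** margin))        # 448 / 8 -> 56.0
scale_inv.copy_(1.0 / scale)                         # kept consistent with scale, in place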
- if cls.dp_amax_reduce_interval is None: - cls.dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1")) - - if cls.dp_amax_reduce_interval == 0: - tp_amax_reduce = True - else: - tp_amax_reduce = False - if forward: - if cls.dp_amax_reduce_forward_idx == 0: - reduce_group = fp8_meta["fp8_group"] - else: - tp_amax_reduce = True - cls.dp_amax_reduce_forward_idx = ( - (cls.dp_amax_reduce_forward_idx + 1) % cls.dp_amax_reduce_interval) + for buffer_key, amax_buffer in cls.global_amax_buffer.items(): + # Check for forward or backward reduction. + fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key) + if fwd_update != forward: + continue + # Only skip a forward update when `fp8_weights` is explicitly set to `True` + # (inside optimizer) and the current key is not an `fp8_weight_update` key. + # For other cases, we need to reduce because of activation tensors. + # TODO(ksivaman) consider separate weight and activation fp8_tensors. + if fwd_update and fp8_weights and not fp8_weights_update: + continue + if len(amax_buffer) == 0: + continue + + # Retrieve autocast specific args and concat amaxes. + recipe, group = cls.autocast_arguments[autocast_key] + contiguous_amax = torch.cat(amax_buffer) + + # Reduction. + if (recipe.reduce_amax + and torch.distributed.is_initialized() + and torch.distributed.get_world_size(group=group) > 1): + cls.reduce_tensor_across_group_op_max(contiguous_amax, group) + + # Amax and scale update. + unfused_update = (bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0"))) + or callable(recipe.amax_compute_algo) + or callable(recipe.scaling_factor_compute_algo)) + + if not unfused_update: + tex.fused_amax_and_scale_update_after_reduction( + contiguous_amax, + cls.global_amax_history_buffer[buffer_key], + cls.global_scale_buffer[buffer_key], + cls.global_scale_inv_buffer[buffer_key], + recipe.amax_compute_algo, + get_fp8_te_dtype(recipe, forward), + recipe.margin, + ) else: - if cls.dp_amax_reduce_backward_idx == 0: - reduce_group = fp8_meta["fp8_group"] - else: - tp_amax_reduce = True - cls.dp_amax_reduce_backward_idx = ( - (cls.dp_amax_reduce_backward_idx + 1) % cls.dp_amax_reduce_interval) + split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) - if tp_amax_reduce: - if tp_size > 1: - reduce_group = tp_group - else: - return None + for amax_history, scale, scale_inv in zip( + cls.global_amax_history_buffer[buffer_key], + cls.global_scale_buffer[buffer_key], + cls.global_scale_inv_buffer[buffer_key], + ): + _amax_and_scale_update( + amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe) - chunk_sizes = [x.numel() for x in cls.global_fp8_buffer[amax_buffer_key]] - contiguous_amax = torch.cat(cls.global_fp8_buffer[amax_buffer_key]) + @classmethod + def add_tensor_for_bwd_reduction_multi_grad_hook(cls, tensor): + """Add tensor to list for multi grad hook.""" + cls.multi_grad_hook_tensors.append(tensor) - wait_handle = cls.reduce_tensor_across_group_op_max( - contiguous_amax, - reduce_group, - fp8_meta["async_amax_reduction"], - ) + @classmethod + def hook_for_bwd_amax_reduction(cls, grads: Tuple[torch.Tensor]) -> None: # pylint: disable=unused-argument + """Executes at the end of backward pass.""" + cls.reduce_and_update_fp8_tensors(forward=False) - cls.global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes)) - return wait_handle + @classmethod + def get_unique_autocast_key( + cls, + recipe: Optional[DelayedScaling] = None, + group: Optional[dist_group_type] = None, + 
): + """ + For FP8, each autocast can be uniquely identified by the recipe and fp8 group. + Safely using `hash` as we never cross checkpoint boundaries. + """ + return f"{str(recipe)}:{hash(group)}" @classmethod def fp8_autocast_enter( @@ -404,21 +399,29 @@ def fp8_autocast_enter( calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, fp8_group: Optional[dist_group_type] = None, + _graph: bool = False, ) -> None: """Set state and tracking variables for entry into FP8 region.""" - if cls.FP8_AUTOCAST_DEPTH == 0: - if callable(cls.amax_forward_global_reduce_func): - cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable - cls.delete_key_from_amax_buffer(forward=True) + + fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) + cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) + + if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): + if not cls.bwd_amax_update_hook_registered and len(cls.multi_grad_hook_tensors) > 0: + # This hook does not fire for graphed modules. + torch.autograd.graph.register_multi_grad_hook( + tuple(cls.multi_grad_hook_tensors), cls.hook_for_bwd_amax_reduction) + cls.bwd_amax_update_hook_registered = True cls.FP8_ENABLED = enabled cls.FP8_CALIBRATION = calibrating - cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + cls.FP8_RECIPE = fp8_recipe cls.FP8_DISTRIBUTED_GROUP = fp8_group + cls.FP8_GRAPH_CAPTURING = _graph if cls.FP8_AUTOCAST_DEPTH == 0: cls.IS_FIRST_FP8_MODULE = True - cls.FP8_AUTOCAST_COUNTER += 1 cls.FP8_AUTOCAST_DEPTH += 1 if enabled: @@ -426,9 +429,14 @@ def fp8_autocast_enter( assert fp8_available, reason_for_no_fp8 @classmethod - def fp8_autocast_exit(cls): + def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: """Set state and tracking variables for exit from FP8 region.""" cls.FP8_AUTOCAST_DEPTH -= 1 + # Reduce only the non-FP8 weight modules here. + # FP8 weight modules are reduced at the end of the optimizer + # step after the weight amax is populated. + if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): + cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False) @classmethod def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: @@ -525,6 +533,7 @@ def fp8_autocast( calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, fp8_group: Optional[dist_group_type] = None, + _graph: bool = False, ) -> None: """ Context manager for FP8 usage. 
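Since `fp8_autocast_enter` now records the `calibrating` flag alongside the recipe and group, the calibration workflow looks like the sketch below: run forward passes in higher precision while FP8 statistics are recorded, then checkpoint so the saved FP8 metadata can later be loaded into an FP8 run (layer sizes and the file name are illustrative):

import torch
import transformer_engine.pytorch as te

model = te.Linear(1024, 1024).cuda()
inp = torch.randn(32, 1024, device="cuda")

with te.fp8_autocast(enabled=False, calibrating=True):
    _ = model(inp)          # executes in higher precision but records amax/scale statistics

torch.save(model.state_dict(), "calibrated.pt")   # extra state carries the FP8 metadata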
@@ -568,23 +577,25 @@ def fp8_autocast( FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled, calibrating=calibrating, fp8_recipe=fp8_recipe, - fp8_group=fp8_group) + fp8_group=fp8_group, + _graph=_graph) yield finally: FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment - FP8GlobalStateManager.fp8_autocast_exit() + FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph) def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: """Update amax history and set next amax to zero.""" if amax_history.shape[0] > 1: - amax_history = torch.roll(amax_history, -1, 0) + new_amax_history = torch.roll(amax_history, -1, 0) + amax_history.copy_(new_amax_history) amax_history[0].fill_(0.0) return amax_history @torch.jit.script -def _default_get_amax( +def _default_get_amax_and_update_history( amax_history: torch.Tensor, amax_compute_algo: str, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -609,63 +620,23 @@ def _default_sf_compute( sf = (fp8_max / amax) / (2 ** margin) sf = torch.where(amax > 0.0, sf, scale) sf = torch.where(torch.isfinite(amax), sf, scale) - return sf - - -@jit_fuser -def _compute_scaling_factor_inverse( - scale: torch.Tensor, - scale_inv: torch.Tensor, - non_weight_mask: torch.Tensor, - update_weight_scale_inv: bool, -) -> torch.Tensor: - """Compute inverse of scaling factor.""" - if update_weight_scale_inv: - return 1.0 / scale - return torch.where(non_weight_mask, 1.0 / scale, scale_inv) - - -def _fused_amax_and_scale_update( - amax_history: torch.Tensor, - scale: torch.Tensor, - scale_inv: torch.Tensor, - fp8_dtype: tex.DType, - margin: int, - amax_compute_algo: str, - non_weight_mask: torch.Tensor, - update_weight_scale_inv: bool, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Update amax history and FP8 scaling factors""" - if update_weight_scale_inv: - non_weight_mask = torch.Tensor() - tex.fused_amax_and_scale_update( - amax_history, - scale, - scale_inv, - non_weight_mask, - amax_history, - scale, - scale_inv, - amax_compute_algo, - fp8_dtype, - margin, - ) - return amax_history, scale, scale_inv + scale.copy_(sf) + return scale -def _compute_amax( +def _compute_amax_and_update_history( amax_history: torch.Tensor, - recipe: DelayedScaling, + amax_compute_algo: Union[Callable, str], ) -> Tuple[torch.Tensor, torch.Tensor]: """Obtain the amax from the history.""" - if callable(recipe.amax_compute_algo): - amax = recipe.amax_compute_algo(amax_history) + if callable(amax_compute_algo): + amax = amax_compute_algo(amax_history) amax_history = _update_amax_history(amax_history) return amax_history, amax - return _default_get_amax( + return _default_get_amax_and_update_history( amax_history, - recipe.amax_compute_algo, + amax_compute_algo, ) @@ -687,46 +658,29 @@ def _compute_scaling_factor( return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) -def amax_and_scale_update( - fp8_meta: Dict[str, Any], - fwd_update: bool, - update_weight_scale_inv: bool = True, +def _amax_and_scale_update( + amax_history: torch.Tensor, + scale: torch.Tensor, + scale_inv: torch.Tensor, + fp8_max: float, + recipe: DelayedScaling, ) -> None: - """Updates fp8 amaxes/scales for fwd | bwd.""" - amax_compute = fp8_meta["recipe"].amax_compute_algo - sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo - fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd" - fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd" - - if not callable(amax_compute) and sf_compute is None: - ( - 
fp8_meta[fp8_meta_tensor_key].amax_history, - fp8_meta[fp8_meta_tensor_key].scale, - fp8_meta[fp8_meta_tensor_key].scale_inv, - ) = _fused_amax_and_scale_update( - fp8_meta[fp8_meta_tensor_key].amax_history, - fp8_meta[fp8_meta_tensor_key].scale, - fp8_meta[fp8_meta_tensor_key].scale_inv, - get_fp8_te_dtype(fp8_meta["recipe"], fwd_update), - fp8_meta["recipe"].margin, - fp8_meta["recipe"].amax_compute_algo, - fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"], - update_weight_scale_inv, - ) - else: - fp8_meta[fp8_meta_tensor_key].amax_history, amax = _compute_amax( - fp8_meta[fp8_meta_tensor_key].amax_history, - fp8_meta["recipe"], - ) - fp8_meta[fp8_meta_tensor_key].scale = _compute_scaling_factor( - amax, - fp8_meta[fp8_meta_tensor_key].scale, - fp8_meta[fp8_max_key], - fp8_meta["recipe"], - ) - fp8_meta[fp8_meta_tensor_key].scale_inv = _compute_scaling_factor_inverse( - fp8_meta[fp8_meta_tensor_key].scale, - fp8_meta[fp8_meta_tensor_key].scale_inv, - fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"], - update_weight_scale_inv, - ) + """Updates FP8 meta tensors.""" + new_amax_history, amax = _compute_amax_and_update_history( + amax_history, + recipe.amax_compute_algo, + ) + new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe) + scale.copy_(new_scale) + scale_inv.copy_(1.0 / new_scale) + amax_history.copy_(new_amax_history) + + +def split_and_copy( + buffer: torch.Tensor, + outputs: List[torch.Tensor], + chunk_sizes: List[int], +) -> None: + """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" + splits = buffer.split(chunk_sizes) + torch._foreach_copy_(outputs, splits) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py new file mode 100644 index 0000000000..5de3b7a342 --- /dev/null +++ b/transformer_engine/pytorch/graph.py @@ -0,0 +1,548 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Functions for CUDA Graphs support in FP8""" +import torch +from torch.utils._pytree import tree_flatten as _tree_flatten +from torch.utils._pytree import tree_unflatten as _tree_unflatten +from torch._C import _graph_pool_handle + +from .fp8 import ( + fp8_autocast, + FP8GlobalStateManager, + get_default_fp8_recipe, +) +from .distributed import get_all_rng_states, graph_safe_rng_available +from .module.base import TransformerEngineBaseModule + + +__all__ = ["make_graphed_callables"] + + +_IS_GRAPH_CAPTURING = False + + +def set_capture_start() -> None: + """Record beginning of `make_graphed_callables`.""" + global _IS_GRAPH_CAPTURING + _IS_GRAPH_CAPTURING = True + + +def set_capture_end() -> None: + """Record end of `make_graphed_callables`.""" + global _IS_GRAPH_CAPTURING + _IS_GRAPH_CAPTURING = False + + +def is_graph_capturing() -> None: + """Return whether within `make_graphed_callables`.""" + return _IS_GRAPH_CAPTURING + + +def graph_pool_handle(): + """ + Returns an opaque token representing the id of a graph memory pool. + """ + return _graph_pool_handle() + + +def _make_graphed_callables( + callables, + sample_args, + num_warmup_iters=3, + allow_unused_input=False, + fp8_weight_caching=False, + _order=None, +): + """ + Helper method for `make_graphed_callables` + """ + + if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled(): + raise RuntimeError( + "make_graphed_callables does not support the autocast " + "caching. Please set `cache_enabled=False`." 
+ ) + + just_one_callable = False + + if not isinstance(callables, tuple): + just_one_callable = True + callables = (callables,) + sample_args = (sample_args,) + + flatten_sample_args = [] + if _order is not None: + # order is a list containing 1..model_chunk values in the order of microbatch schedule + num_model_chunks = max(_order) + num_microbatches = len(_order) // num_model_chunks // 2 + assert num_model_chunks * num_microbatches * 2 == len(_order) + assert ( + len(sample_args)*2 >= len(_order) + and (len(sample_args)*2 % len(_order) == 0) + ), f'{len(sample_args)} >= {len(_order)} and {len(sample_args)} % {len(_order)} == 0' + num_layers = len(sample_args) // num_model_chunks // num_microbatches + assert ( + len(callables) == num_model_chunks*num_layers + ), (f"Callables should have ({num_model_chunks * num_layers}) " + + f"entries when order input is provided but got {len(callables)}." + ) + assert ( + len(sample_args) == num_model_chunks * num_microbatches * num_layers + ), (f"Expected {num_model_chunks * num_microbatches}" + + f"args tuple, but got {len(sample_args)}." + ) + + if fp8_weight_caching: + FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) + + for c in callables: + if isinstance(c, torch.nn.Module): + assert ( + len(c._backward_hooks) == 0 + and len(c._forward_hooks) == 0 + and len(c._forward_pre_hooks) == 0 + ), ( + "Modules must not have hooks registered at the time they are passed. " + + "However, registering hooks on modules after passing them " + + "through make_graphed_callables is allowed." + ) + assert all(b.requires_grad is False for b in c.buffers()), ( + "In any :class:`~torch.nn.Module` passed to " + + ":func:`~make_graphed_callables`, only parameters may be trainable. " + + "All buffers must have ``requires_grad=False``." + ) + for args in sample_args: + flatten_arg, _ = _tree_flatten(args) + flatten_sample_args.append(tuple(flatten_arg)) + assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( + "In the beta API, sample_args " + + "for each callable must contain only Tensors. Other types are not allowed." + ) + + # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly + # passes to forward (ie, its sample_args) AND the module's parameter attributes. + per_callable_len_user_args = [len(args) for args in flatten_sample_args] + if _order is None: + per_callable_module_params = [ + tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () + for c in callables + ] + per_callable_static_input_surfaces = [ + flatten_sample_args[i] + per_callable_module_params[i] + for i in range(len(callables)) + ] + else: + per_callable_module_params = [] + for c in callables: + for i in range(num_microbatches): + per_callable_module_params.append( + tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () + ) + assert len(per_callable_module_params) == len(flatten_sample_args) + per_callable_static_input_surfaces = [ + flatten_sample_args[i] + per_callable_module_params[i] + for i in range(len(flatten_sample_args)) + ] + + fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] + bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] + graph_callables = [None for _ in range(len(flatten_sample_args))] + # For cases with multiple active RNG states, e.g. TP. 
+ if graph_safe_rng_available(): + for _, state in get_all_rng_states().items(): + for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs): + fwd_graph.register_generator_state(state) + bwd_graph.register_generator_state(state) + + mempool = graph_pool_handle() + + # Warmup + # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work + # from ending up in any captures. + torch.cuda.synchronize() + with torch.cuda.stream(torch.cuda.Stream()): + for c_i, func in enumerate(callables): + args = sample_args[c_i] + static_input_surface = per_callable_static_input_surfaces[c_i] + for _ in range(num_warmup_iters): + outputs, _ = _tree_flatten(func(*args)) + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple( + torch.empty_like(o) for o in outputs if o.requires_grad + ), + only_inputs=True, + allow_unused=allow_unused_input, + ) + del outputs, grad_inputs + torch.cuda.synchronize() + + # All captures here share a mempool. To avoid replays corrupting each other's memory, + # the safest approach is to capture all passes in the same order they'll run: + # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1. + + if _order is not None: # pylint: disable=too-many-nested-blocks + per_callable_static_outputs = [None] * len(flatten_sample_args) + per_callable_output_unflatten_spec = [None] * len(flatten_sample_args) + per_callable_static_grad_outputs = [None] * len(flatten_sample_args) + per_callable_static_grad_inputs = [None] * len(flatten_sample_args) + fwd_idx = [0] * num_model_chunks + bwd_idx = [0] * num_model_chunks + for c_id in _order: + if c_id > 0: + # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] + m_chunk = c_id-1 + for l_no in range(num_layers): + func = callables[m_chunk*num_layers + l_no] + per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) \ + + (fwd_idx[m_chunk] * num_layers + l_no) + args = sample_args[per_callable_fwd_idx] + fwd_graph = fwd_graphs[per_callable_fwd_idx] + with torch.cuda.graph(fwd_graph, pool=mempool): + outputs = func(*args) + flatten_outputs, spec = _tree_flatten(outputs) + per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs) + per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec + graph_callables[per_callable_fwd_idx] = func + fwd_idx[m_chunk] += 1 + else: + # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1] + m_chunk = -c_id-1 + for l_no in list(reversed(range(num_layers))): + per_callable_bwd_idx = (m_chunk * num_microbatches * num_layers) \ + + (bwd_idx[m_chunk] * num_layers + l_no) + static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx] + static_outputs = per_callable_static_outputs[per_callable_bwd_idx] + bwd_graph = bwd_graphs[per_callable_bwd_idx] + # For now, assumes all static_outputs require grad + static_grad_outputs = tuple( + torch.empty_like(o) if o.requires_grad else None for o in static_outputs + ) + with torch.cuda.graph(bwd_graph, pool=mempool): + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in static_outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(o for o in static_grad_outputs if o is not None), + only_inputs=True, + allow_unused=allow_unused_input, + ) + # Constructs a tuple suitable for returning from Graphed.backward: + # Pads out the actually-needed grads with Nones in gradient slots for 
inputs + # that don't require grad. I couldn't think of a one-liner for this pattern. + static_grad_inputs = [] + grad_idx = 0 + for arg in static_input_surface: + if arg.requires_grad: + static_grad_inputs.append(grad_inputs[grad_idx]) + grad_idx += 1 + else: + static_grad_inputs.append(None) # type: ignore[arg-type] + static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment] + + per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs + per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs + bwd_idx[m_chunk] += 1 + else: + # Capture forward graphs + per_callable_static_outputs = [] + per_callable_output_unflatten_spec = [] + graph_id = 0 + for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs): + with torch.cuda.graph(fwd_graph, pool=mempool): + outputs = func(*args) + graph_callables[graph_id] = func + graph_id += 1 + + flatten_outputs, spec = _tree_flatten(outputs) + per_callable_static_outputs.append(tuple(flatten_outputs)) + per_callable_output_unflatten_spec.append(spec) + + # Capture backward graphs in reverse order + per_callable_static_grad_outputs = [] + per_callable_static_grad_inputs = [] + for static_input_surface, static_outputs, bwd_graph in zip( + reversed(per_callable_static_input_surfaces), + reversed(per_callable_static_outputs), + reversed(bwd_graphs), + ): + # For now, assumes all static_outputs require grad + static_grad_outputs = tuple( + torch.empty_like(o) if o.requires_grad else None for o in static_outputs + ) + with torch.cuda.graph(bwd_graph, pool=mempool): + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in static_outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(o for o in static_grad_outputs if o is not None), + only_inputs=True, + allow_unused=allow_unused_input, + ) + # Constructs a tuple suitable for returning from Graphed.backward: + # Pads out the actually-needed grads with Nones in gradient slots for inputs that + # don't require grad. I couldn't think of a slick one-liner for this pattern. + static_grad_inputs = [] + grad_idx = 0 + for arg in static_input_surface: + if arg.requires_grad: + static_grad_inputs.append(grad_inputs[grad_idx]) + grad_idx += 1 + else: + static_grad_inputs.append(None) # type: ignore[arg-type] + static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment] + + per_callable_static_grad_outputs.append(static_grad_outputs) + per_callable_static_grad_inputs.append(static_grad_inputs) + + # Reverses the most recent two lists + per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs)) + per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs)) + # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable. + + def make_graphed_autograd_function( + fwd_graph, + bwd_graph, + module_params, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + static_grad_outputs, + static_grad_inputs, + ): + class Graphed(torch.autograd.Function): + """Autograd function for graph replay.""" + @staticmethod + def forward(ctx, skip_fp8_weight_update, *inputs): + # At this stage, only the user args may (potentially) be new tensors. 
+ ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() + if ctx.is_first_module and skip_fp8_weight_update is not None: + FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update) + + for i in range(len_user_args): + if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + static_input_surface[i].copy_(inputs[i]) + fwd_graph.replay() + assert isinstance(static_outputs, tuple) + return tuple(o.detach() for o in static_outputs) + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, *grads): + assert len(grads) == len(static_grad_outputs) + for g, grad in zip(static_grad_outputs, grads): + if g is not None: + # don't copy if autograd gods have been kind and the + # incoming grad is already in the right place + if g.data_ptr() != grad.data_ptr(): + g.copy_(grad) + bwd_graph.replay() + + if ctx.is_first_module: + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + + # Input args that didn't require grad expect a None gradient. + assert isinstance(static_grad_inputs, tuple) + return (None,) + tuple( + b.detach() if b is not None else b for b in static_grad_inputs + ) + + def functionalized(*user_args, **user_kwargs): + # Runs the autograd function with inputs == all + # inputs to the graph that might require grad + # (explicit user args + module parameters) + # Assumes module params didn't change since capture. + skip_fp8_weight_update = None + if fp8_weight_caching: + assert ( + ("is_first_microbatch" in user_kwargs + and isinstance(user_kwargs["is_first_microbatch"], bool)) + ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." + + skip_fp8_weight_update = not user_kwargs["is_first_microbatch"] + + flatten_user_args, _ = _tree_flatten(user_args) + out = Graphed.apply(skip_fp8_weight_update, *(tuple(flatten_user_args) + module_params)) + return _tree_unflatten(out, output_unflatten_spec) + + return functionalized + + # Put together the final graphed callables + ret = [] + for i in range(len(sample_args)): + graphed = make_graphed_autograd_function( + fwd_graphs[i], + bwd_graphs[i], + per_callable_module_params[i], + per_callable_len_user_args[i], + per_callable_output_unflatten_spec[i], + per_callable_static_input_surfaces[i], + per_callable_static_outputs[i], + per_callable_static_grad_outputs[i], + per_callable_static_grad_inputs[i], + ) + + func = graph_callables[i] + if isinstance(func, torch.nn.Module): + + def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): + def new_fwd(*user_args, **user_kwargs): + # If the module's training-or-eval state matches what we graphed, + # run the graph, otherwise run the original forward method + if func.training == graph_training_state: + # Set the FP8 group from global amax reduction. 
+ for m in func.modules(): + if (isinstance(m, TransformerEngineBaseModule) + and FP8GlobalStateManager.is_fp8_enabled()): + m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + m.fp8_meta, fp8_weights=m._get_fp8_params()) + return graphed(*user_args, **user_kwargs) + return orig_fwd(*user_args, **user_kwargs) + return new_fwd + + forward = make_graphed_forward(func, func.training, graphed, func.forward) + if _order is None: + func.forward = forward + ret.append(func) + else: + ret.append(forward) + else: + ret.append(graphed) + + if just_one_callable: + return ret[0] + + return tuple(ret) + + +def save_fp8_tensors(modules, amax_history_len): + """ + Returns the FP8 tensors for all modules + with adjusted amax history sizes. + """ + saved_fp8_meta_tensors = [] + for module in modules: + for m in module.modules(): + if isinstance(m, TransformerEngineBaseModule): + if m.primary_weights_in_fp8: + m.adjust_amax_history_length(amax_history_len) + saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors()) + return saved_fp8_meta_tensors + + +def restore_fp8_tensors(modules, fp8_tensors): + """Restore FP8 tensors.""" + for module in modules: + for m in module.modules(): + if isinstance(m, TransformerEngineBaseModule): + m.reset_fp8_meta_tensors(fp8_tensors.pop(0)) + assert len(fp8_tensors) == 0, "TE internal error." + + +def make_graphed_callables( + modules, + sample_args, + num_warmup_iters=3, + allow_unused_input=False, + fp8_enabled=False, + fp8_calibrating=False, + fp8_recipe=None, + fp8_weight_caching=False, + _order=None, +): + """ + A version of PyTorch's `make_graphed_callables` utility function with support for + TransformerEngine modules and FP8. Please see the original version in upstream PyTorch + `here `_ + for extensive documentation. The documentation for additional parameters which are + specific to FP8 are given below. + + FP8 specific parameters + ----------------------- + fp8_enabled: bool, default = `True` + whether or not to enable fp8 + fp8_calibrating: bool, default = `False` + calibration mode allows collecting statistics such as amax and scale + data of fp8 tensors even when executing without fp8 enabled. This is + useful for saving an inference ready fp8 checkpoint while training + using a higher precision. + fp8_recipe: recipe.DelayedScaling, default = `None` + recipe used for FP8 training. + fp8_weight_caching: bool, default = `False` + Whether or not to cache FP8 weights across microbatches. if set to `True`, + the `is_first_microbatch` boolean argument must be passed into the forward + method for TransformerEngine modules. When storing primary weights in FP8 + using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg + must be set to `False` if calculating weight transposes' outside TE, e.g., + in the optimizer step. + """ + set_capture_start() + + fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + + # Handle single module. + just_one_callable = False + if not isinstance(modules, tuple): + just_one_callable = True + modules = (modules,) + + # Store FP8 tensors to reset later. + saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len) + + # FP8 wrapper. 
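Putting the pieces of `make_graphed_callables` together, a minimal FP8 capture with cached FP8 weights might look like the sketch below (layer sizes and recipe are illustrative, FP8-capable hardware is assumed, and `is_first_microbatch` must be passed because `fp8_weight_caching=True`):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.graph import make_graphed_callables

model = te.Linear(1024, 1024).cuda()
sample_args = (torch.randn(32, 1024, device="cuda", requires_grad=True),)

graphed_model = make_graphed_callables(
    model,
    sample_args,
    fp8_enabled=True,
    fp8_recipe=DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16),
    fp8_weight_caching=True,
)

out = graphed_model(torch.randn(32, 1024, device="cuda", requires_grad=True),
                    is_first_microbatch=True)
out.sum().backward()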
+ def wrap_autocast(block): + old_forward = block.forward + def forward_func(*args, **kwargs): + with fp8_autocast(enabled=fp8_enabled, + calibrating=fp8_calibrating, + fp8_recipe=fp8_recipe, + _graph=True): + outputs = old_forward(*args, **kwargs) + return outputs + block.forward = forward_func + + forward_funcs = [] + for module in modules: + assert isinstance(module, torch.nn.Module), f"Graphing for {type(module)} is not supported." + wrap_autocast(module) + forward_funcs.append(module) + + if just_one_callable: + forward_funcs = forward_funcs[0] + else: + forward_funcs = tuple(forward_funcs) + + # Save RNG state. + if graph_safe_rng_available(): + generators = [torch.cuda.default_generators[torch.cuda.current_device()], + *get_all_rng_states().values()] + original_rng_states = [state.get_state() for state in generators] + else: + original_rng_states = torch.cuda.get_rng_state() + + graphed_callables = _make_graphed_callables( + forward_funcs, sample_args, num_warmup_iters=num_warmup_iters, + allow_unused_input=allow_unused_input, + fp8_weight_caching=fp8_weight_caching, _order=_order) + + # Ensures warmup does not affect numerics for ops such as dropout. + if graph_safe_rng_available(): + for gen, state in zip(generators, original_rng_states): + gen.set_state(state) + else: + torch.cuda.set_rng_state(original_rng_states) + + # Reset FP8 gradients. + for module in modules: + for p in module.parameters(): + p.grad = None + + # Restore FP8 state. + restore_fp8_tensors(modules, saved_fp8_tensors) + + set_capture_end() + return graphed_callables diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 56dd3c8fc4..7e0cf5c106 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -8,8 +8,7 @@ import pickle import warnings from abc import ABC, abstractmethod -from typing import Generator, Union, Optional, Tuple, Dict, Any, List -from functools import partial +from typing import Generator, Union, Optional, Tuple, List from contextlib import contextmanager import torch @@ -22,13 +21,11 @@ get_default_fp8_recipe, get_fp8_te_dtype, FP8GlobalStateManager, - amax_and_scale_update, ) from ..distributed import ( gather_along_first_dim, is_fp8_activation_recompute_enabled, in_fp8_activation_recompute_phase, - get_distributed_world_size, ) from ..cpp_extensions import ( fp8_cast_transpose_fused, @@ -44,7 +41,6 @@ _cublas_workspace = None _ub_communicators = None _NUM_MAX_UB_STREAMS = 3 -_amax_reduce_handle_bwd = None layers_atomic_ring_exchange = [] @@ -64,49 +60,6 @@ def get_workspace() -> torch.Tensor: ) return _cublas_workspace -@contextmanager -def _prepare_backward( - fp8: bool, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - name: str = "" -) -> Generator[None, None, None]: - """Checks and prep for BWD.""" - if fp8: - global _amax_reduce_handle_bwd - if _amax_reduce_handle_bwd is not None: - _amax_reduce_handle_bwd.wait() - _amax_reduce_handle_bwd = None - - # Update amax and scale; Skip all setup for global amax reduction - if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1: - # From previous iteration - FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False) - amax_and_scale_update(fp8_meta, False) - FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False) - - # Get new backward key. 
- fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0) - - FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False) - else: - amax_and_scale_update(fp8_meta, False) - - with torch.cuda.nvtx.range(name + " backward"): - yield - - if (fp8 and fp8_meta["recipe"].reduce_amax - and get_distributed_world_size(fp8_meta["fp8_group"]) > 1): - if fp8_meta["first_module"]: - _amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction( - fp8_meta, - tp_group, - tp_size, - forward=False - ) - FP8GlobalStateManager.delete_key_from_amax_buffer(forward=False) - def initialize_ub( shape: list, @@ -300,31 +253,54 @@ def __init__(self) -> None: self.tp_size = 1 self.sequence_parallel = False self.fp8_weight_shapes = [] - self.fp8_meta["autocast_id_fwd_stack"] = [] - self.fp8_meta["async_amax_reduction"] = bool( - int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")) - ) self.param_init_meta = {} self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() + def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: + """Increase or decrease size of amax history based on given `length`. + + .. warning:: + This changes the underlying amax memory location. + """ + if fwd is None: + fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd") + else: + fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",) + + for meta_key in fp8_meta_tensor_keys: + curr_len = self.fp8_meta[meta_key].amax_history.shape[0] + if length == curr_len: + continue + if length < curr_len: + self.fp8_meta[meta_key].amax_history = ( + self.fp8_meta[meta_key].amax_history[: length].clone()) + elif length > curr_len: + extra_rows = length - curr_len + self.fp8_meta[meta_key].amax_history = F.pad( + self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows) + ) + + # Update the global buffers with new amax and history pointers. + if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: + fwd_pos, fwd_key, bwd_pos, bwd_key = ( + self.fp8_meta[FP8GlobalStateManager.get_buffer_info()]) + for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): + if buffer_key in FP8GlobalStateManager.global_amax_buffer: + assert ( + buffer_key in FP8GlobalStateManager.global_amax_history_buffer + ), "TE internal error during amax history change." + FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = ( + self.fp8_meta[meta_key].amax_history[0]) + FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = ( + self.fp8_meta[meta_key].amax_history) + def set_meta_tensor(self, fwd: bool) -> None: """Init scales and amaxes for fwd | bwd.""" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" if self.fp8_meta_tensors_initialized: # Handle changed amax history size. - curr_len = self.fp8_meta[fp8_meta_tensor_key].amax_history.shape[0] - need_len = self.fp8_meta["recipe"].amax_history_len - if need_len < curr_len: - self.fp8_meta[fp8_meta_tensor_key].amax_history = ( - self.fp8_meta[fp8_meta_tensor_key] - .amax_history[: self.fp8_meta["recipe"].amax_history_len].clone() - ) - elif need_len > curr_len: - extra_rows = need_len - curr_len - self.fp8_meta[fp8_meta_tensor_key].amax_history = F.pad( - self.fp8_meta[fp8_meta_tensor_key].amax_history, pad=(0, 0, 0, extra_rows) - ) + self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd) return # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and @@ -347,25 +323,45 @@ def set_meta_tensor(self, fwd: bool) -> None: device="cuda", ) - # Needed for calculation of scale inverses to - # preserve scale_inv when caching FP8 weights - if fwd: - # [True, False, True]: -> [input, weight, output] - self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor( - [True, False, True] * self.fp8_meta["num_gemms"] - ).cuda() - else: - # [True, True]: -> [grad_output, grad_input] - self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor( - [True, True] * self.fp8_meta["num_gemms"] - ).cuda() - def init_fp8_meta_tensors(self) -> None: """Init scales and amaxes.""" self.set_meta_tensor(True) self.set_meta_tensor(False) self.fp8_meta_tensors_initialized = True + def get_fp8_meta_tensors(self) -> None: + """Get scales and amaxes.""" + fwd_key, bwd_key = "scaling_fwd", "scaling_bwd" + if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta: + return None + + fp8_meta_tensors = {fwd_key: [], bwd_key: []} + with torch.no_grad(): + for key in (fwd_key, bwd_key): + fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone()) + fp8_meta_tensors[key].append(self.fp8_meta[key].scale_inv.clone()) + fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone()) + return fp8_meta_tensors + + def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None: + """Reset scales and amaxes.""" + def reset(key): + if key in self.fp8_meta: + if fp8_meta_tensors is None: + self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale)) + self.fp8_meta[key].scale_inv.copy_( + torch.ones_like(self.fp8_meta[key].scale_inv)) + self.fp8_meta[key].amax_history.copy_( + torch.zeros_like(self.fp8_meta[key].amax_history)) + else: + assert key in fp8_meta_tensors, "Cannot reset fp8 tensors." + self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0]) + self.fp8_meta[key].scale_inv.copy_(fp8_meta_tensors[key][1]) + self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][2]) + with torch.no_grad(): + reset("scaling_fwd") + reset("scaling_bwd") + def get_extra_state(self) -> torch.Tensor: """Save before checkpointing.""" state = None @@ -380,13 +376,11 @@ def get_extra_state(self) -> torch.Tensor: state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history - state["global_fp8_buffer"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint() - state["global_fp8_state"] = FP8GlobalStateManager.get_global_fp8_state_checkpoint() # Store other pickelable values. extra = {} for k, v in self.fp8_meta.items(): - if isinstance(v, (bool, int, float, str, list)): + if isinstance(v, (bool, int, float, str, tuple, list)): extra[k] = v state["extra_fp8_variables"] = extra @@ -414,11 +408,6 @@ def set_extra_state(self, state: torch.Tensor) -> None: if state is None: return - # Restore global FP8 amax buffer. - FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"]) - # Restore global FP8 state. - FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"]) - # Load extra items. 
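`set_meta_tensor` above now delegates history-size changes to `adjust_amax_history_length`; the actual resize is plain row padding or truncation, as in this small sketch (shapes are illustrative):

import torch
import torch.nn.functional as F

amax_history = torch.zeros(4, 6)                       # history length 4, six FP8 tensors
longer = F.pad(amax_history, pad=(0, 0, 0, 16 - 4))    # grow to length 16 with zero rows
shorter = longer[:2].clone()                           # shrink to length 2
assert longer.shape == (16, 6) and shorter.shape == (2, 6)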
self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0] @@ -527,6 +516,16 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N self.tp_group = tp_group self.tp_group_initialized = True + def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: + """returns the FP8 weights.""" + fp8_params = [] + for param in self.parameters(): + if isinstance(param, Float8Tensor) and param.requires_grad: + fp8_params.append(param) + if len(fp8_params) == 0: + return None + return fp8_params + # This routine is shared across FP8 and FP8_calibration paths so should not actually # assume FP8 execution. def init_fp8_metadata(self, num_gemms: int = 1) -> None: @@ -576,7 +575,6 @@ def prepare_forward( to setup the forward aggregated amax reduction for every module just in case. The autocast exit will pick up the most recent one. """ - # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) @@ -594,49 +592,14 @@ def prepare_forward( if is_first_microbatch is not None and not self.primary_weights_in_fp8: self.set_fp8_weights() - update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch if self.fp8 and self.sequence_parallel: assert self.fp8_meta["recipe"].reduce_amax, \ "Amax reduction across tensor parallel group is " \ "necessary when using sequence parallelism with FP8." - # Previous iteration was grad_enabled - if self.fp8_meta.get("update_amax_and_scale_fwd", False): - if (self.fp8_meta["recipe"].reduce_amax - and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): - FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True) - amax_and_scale_update( - self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv - ) - FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True) - else: - amax_and_scale_update( - self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv - ) - - if self.fp8 and self.training: - # Setup for amax reduction - if (self.fp8_meta["recipe"].reduce_amax - and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): - self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module() - if self.fp8_meta["first_module"]: - # Wait for the prior AMAX reduction to finish - amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd() - if amax_reduce_handle_fwd is not None: - amax_reduce_handle_fwd.wait() - self.fp8_meta["autocast_id_fwd"] = ( - FP8GlobalStateManager.new_fp8_context_id()) - FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"]) - else: - self.fp8_meta["autocast_id_fwd"] = ( - FP8GlobalStateManager.get_fp8_context_id()) - self.fp8_meta["autocast_id_fwd_stack"].append( - self.fp8_meta["autocast_id_fwd"] - ) - FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True) - self.fp8_meta["update_amax_and_scale_fwd"] = True - else: - self.fp8_meta["update_amax_and_scale_fwd"] = False + if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + self.fp8_meta, fp8_weights=self._get_fp8_params()) # Activation recomputation is used and this is the first forward phase. 
if ( @@ -653,18 +616,6 @@ def prepare_forward( FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) return - if (self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax - and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): - FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"]) - reduce_func = partial( - FP8GlobalStateManager.global_amax_reduction, - self.fp8_meta, - self.tp_group, - self.tp_size, - forward=True - ) - FP8GlobalStateManager.setup_amax_forward_global_reduce_func(reduce_func) - def set_nccl_overlap_warning_if_tp(self) -> None: """When using TP, the NCCL communication needs to be scheduled before the GEMM for there to be a guaranteed overlap. From the diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 985d587e54..8fdd5d1356 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -14,7 +14,6 @@ from .base import ( get_workspace, - _prepare_backward, get_ub, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -65,6 +64,7 @@ def forward( use_bias: bool, eps: float, is_first_microbatch: Union[bool, None], + skip_fp8_weight_update: Union[torch.Tensor, None], fp8: bool, fp8_calibration: bool, fp8_meta: Dict[str, Any], @@ -89,6 +89,7 @@ def forward( ub_overlap_rs_dgrad: bool, ub_overlap_ag: bool, ub_name: str, + dummy_tensor: torch.Tensor, # pylint: disable=unused-argument ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -98,7 +99,11 @@ def forward( assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(weight) - update_fp8_weights = is_first_microbatch is None or is_first_microbatch + update_fp8_weights = ( + is_first_microbatch is None + or is_first_microbatch + or skip_fp8_weight_update is not None + ) # Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) @@ -196,7 +201,6 @@ def forward( # Weight is already in FP8 weight.reset_fp8_meta_scale_inv() weight_fp8 = weight - weight_t_fp8 = None elif update_fp8_weights: # Need to cast weights to FP8 weight_fp8 = Float8Tensor( @@ -214,6 +218,7 @@ def forward( fp8_dtype_forward, cast_out=weight_fp8._data, transpose_out=weight_t_fp8._data, + noop_flag=skip_fp8_weight_update, ) else: tex.cast_to_fp8( @@ -295,6 +300,7 @@ def forward( weight_t_fp8, ln_out if weight.requires_grad else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None, ) ctx.activation_dtype = activation_dtype @@ -321,6 +327,7 @@ def forward( ctx.ub_name = ub_name ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization + ctx.primary_weights_in_fp8 = primary_weights_in_fp8 # Row Parallel Linear if parallel_mode == "row" and sequence_parallel: @@ -344,9 +351,7 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward( - ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear" - ): + with torch.cuda.nvtx.range("_LayerNormLinear_backward"): ( inputmat, ln_weight, @@ -357,6 +362,7 @@ def backward( weight_t_fp8, ln_out, fwd_scale_inverses, + skip_fp8_weight_update, ) = ctx.saved_tensors if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: @@ -364,10 +370,13 @@ def backward( weight.main_grad = main_grad # Primary weights are in FP8. 
- if ctx.fp8 and weight_t_fp8 is None: - weight_t_fp8 = weight.transpose( - update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy", + if ctx.primary_weights_in_fp8: + weight_t_fp8 = weight.transpose_2d( + cache=ctx.is_first_microbatch is not None, + noop_flag=skip_fp8_weight_update, ) + elif ctx.fp8: + weight_t_fp8 = weight_t_fp8._data if ctx.ub_overlap_rs_dgrad: ctx.ub_bulk_dgrad = False @@ -472,7 +481,7 @@ def backward( # DGRAD: Evaluated unconditionally to feed into Linear backward _ = tex.fp8_gemm( - weight_t_fp8._data, + weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -686,6 +695,8 @@ def backward( None, None, None, + None, + None, ) @@ -970,7 +981,6 @@ def __init__( if self.primary_weights_in_fp8: self.init_fp8_metadata() - self.fp8_meta["update_amax_and_scale_fwd"] = True self.reset_parameters(defer_init=(device == 'meta')) @@ -990,6 +1000,10 @@ def __init__( self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. + self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) + FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1084,6 +1098,10 @@ def forward( produced) """ + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if skip_fp8_weight_update is not None: + is_first_microbatch = False + with self.prepare_forward(inp, is_first_microbatch) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." 
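The skip_fp8_weight_update tensor and the noop_flag plumbed into the cast/transpose kernels in this file move the "refresh the FP8 weight copy or not" decision onto the device, so the host-side launch sequence (and any captured CUDA graph) stays identical in both cases; that is also why update_fp8_weights is forced true whenever the flag tensor exists. A toy emulation in plain PyTorch; cast_with_noop and the 0/1 flag convention are illustrative and not Transformer Engine's actual kernel:

    import torch

    def cast_with_noop(src: torch.Tensor, dst: torch.Tensor, noop_flag: torch.Tensor) -> None:
        # The branch is taken on the device via torch.where, so the Python code
        # does not change between "update" and "skip" iterations.
        dst.copy_(torch.where(noop_flag.bool(), dst, src))

    weight = torch.randn(4, 4)
    weight_copy = torch.zeros_like(weight)
    skip = torch.zeros(1)                    # 0 -> perform the update, 1 -> keep the cached copy

    cast_with_noop(weight, weight_copy, skip)         # refreshes weight_copy
    skip.fill_(1.0)
    cast_with_noop(weight + 1.0, weight_copy, skip)   # leaves weight_copy untouched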
@@ -1132,6 +1150,7 @@ def forward( self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, + skip_fp8_weight_update, self.fp8, self.fp8_calibration, self.fp8_meta, @@ -1156,6 +1175,7 @@ def forward( self.ub_overlap_rs_dgrad, self.ub_overlap_ag, self.ub_name, + self.dummy_tensor, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ad66e01e07..43103f06e1 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -13,7 +13,6 @@ from .base import ( get_workspace, - _prepare_backward, get_ub, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -94,6 +93,7 @@ def forward( use_fc2_bias: bool, eps: float, is_first_microbatch: Union[bool, None], + skip_fp8_weight_update: Union[torch.Tensor, None], fp8: bool, fp8_calibration: bool, fp8_meta: Dict[str, Any], @@ -121,6 +121,7 @@ def forward( ub_overlap_rs: bool, ub_overlap_ag: bool, gemm_gelu_fusion: bool, + dummy_tensor: torch.Tensor, # pylint: disable=unused-argument, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -131,7 +132,11 @@ def forward( assert_dim_for_fp8_exec(fc1_weight) assert_dim_for_fp8_exec(fc2_weight) - update_fp8_weights = is_first_microbatch is None or is_first_microbatch + update_fp8_weights = ( + is_first_microbatch is None + or is_first_microbatch + or skip_fp8_weight_update is not None + ) activation_func = _act_func(activation)[0] @@ -225,8 +230,6 @@ def forward( fc2_weight.reset_fp8_meta_scale_inv() fc1_weight_fp8 = fc1_weight fc2_weight_fp8 = fc2_weight - fc1_weight_t_fp8 = None - fc2_weight_t_fp8 = None elif update_fp8_weights: # Need to cast weights to FP8 fc1_weight_fp8 = Float8Tensor( @@ -250,6 +253,7 @@ def forward( fp8_dtype_forward, cast_out=fc1_weight_fp8._data, transpose_out=fc1_weight_t_fp8._data, + noop_flag=skip_fp8_weight_update, ) tex.fp8_cast_transpose_fused( fc2_weight, @@ -258,6 +262,7 @@ def forward( fp8_dtype_forward, cast_out=fc2_weight_fp8._data, transpose_out=fc2_weight_t_fp8._data, + noop_flag=skip_fp8_weight_update, ) else: tex.cast_to_fp8( @@ -510,6 +515,7 @@ def forward( fc2_weight_t_fp8, fc1_bias, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None, ) ctx.activation_dtype = activation_dtype ctx.activation = activation @@ -538,6 +544,7 @@ def forward( ctx.ub_overlap_ag = ub_overlap_ag ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization + ctx.primary_weights_in_fp8 = primary_weights_in_fp8 # Row Parallel Linear if ub_overlap_rs: @@ -563,9 +570,7 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward( - ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP" - ): + with torch.cuda.nvtx.range("_LayerNormMLP_backward"): ( inputmat, ln_weight, @@ -582,6 +587,7 @@ def backward( fc2_weight_t_fp8, fc1_bias, fwd_scale_inverses, + skip_fp8_weight_update, ) = ctx.saved_tensors if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: @@ -592,11 +598,18 @@ def backward( fc2_weight.main_grad = fc2_weight_main_grad # Primary weights are in FP8. 
- update_transpose_cache = "reuse_only" if ctx.is_first_microbatch is None else "lazy" - if ctx.fp8 and fc1_weight_t_fp8 is None: - fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=update_transpose_cache) - if ctx.fp8 and fc2_weight_t_fp8 is None: - fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=update_transpose_cache) + if ctx.primary_weights_in_fp8: + fc1_weight_t_fp8 = fc1_weight.transpose_2d( + cache=ctx.is_first_microbatch is not None, + noop_flag=skip_fp8_weight_update, + ) + fc2_weight_t_fp8 = fc2_weight.transpose_2d( + cache=ctx.is_first_microbatch is not None, + noop_flag=skip_fp8_weight_update, + ) + elif ctx.fp8: + fc1_weight_t_fp8 = fc1_weight_t_fp8._data + fc2_weight_t_fp8 = fc2_weight_t_fp8._data activation_func = _act_func(ctx.activation)[1] @@ -673,7 +686,7 @@ def backward( # FC2 DGRAD; Unconditional fc2_dgrad, _ = tex.fp8_gemm( - fc2_weight_t_fp8._data, + fc2_weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, @@ -826,7 +839,7 @@ def backward( ub_obj = None # FC1 DGRAD: Unconditional _ = tex.fp8_gemm( - fc1_weight_t_fp8._data, + fc1_weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -1151,6 +1164,8 @@ def backward( None, None, None, + None, + None, ) @@ -1389,7 +1404,6 @@ def __init__( if self.primary_weights_in_fp8: self.init_fp8_metadata(num_gemms=2) - self.fp8_meta["update_amax_and_scale_fwd"] = True self.reset_parameters(defer_init=(device == 'meta')) @@ -1414,6 +1428,10 @@ def __init__( self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. + self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) + FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1473,7 +1491,9 @@ def get_fp8_weights_scratchpad( @no_torch_dynamo() def forward( - self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None + self, + inp: torch.Tensor, + is_first_microbatch: Optional[bool] = None ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a feedforward network (MLP Block). @@ -1497,6 +1517,10 @@ def forward( produced) """ + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if skip_fp8_weight_update is not None: + is_first_microbatch = False + with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." 
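For context on why these backward passes need the transposed FP8 weight at all (and why caching it via transpose_2d across microbatches is worthwhile): with y = x @ W^T in forward, the input gradient is dx = dy @ W, which the FP8 dgrad GEMM evaluates from the transposed weight data. A small plain-PyTorch check of that identity, ignoring FP8 quantization and scaling factors:

    import torch

    x = torch.randn(8, 16, requires_grad=True)
    w = torch.randn(32, 16)
    dy = torch.randn(8, 32)

    y = x @ w.t()          # forward of a bias-free linear layer
    y.backward(dy)
    dx_manual = dy @ w     # what the dgrad GEMM computes, up to quantization
    torch.testing.assert_close(x.grad, dx_manual)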
@@ -1535,6 +1559,7 @@ def forward( self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, + skip_fp8_weight_update, self.fp8, self.fp8_calibration, self.fp8_meta, @@ -1562,6 +1587,7 @@ def forward( self.ub_overlap_rs, self.ub_overlap_ag, self.gemm_gelu_fusion, + self.dummy_tensor, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1f7898a592..4baf2d5965 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -11,7 +11,6 @@ from .base import ( get_workspace, - _prepare_backward, get_ub, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -65,6 +64,7 @@ def forward( bias: torch.Tensor, use_bias: bool, is_first_microbatch: Union[bool, None], + skip_fp8_weight_update: Union[torch.Tensor, None], fp8: bool, fp8_calibration: bool, fp8_meta: Dict[str, Any], @@ -80,7 +80,8 @@ def forward( primary_weights_in_fp8: bool, ub_overlap_rs: bool, ub_overlap_ag: bool, - ub_name: str + ub_name: str, + dummy_tensor: torch.Tensor, # pylint: disable=unused-argument ) -> torch.Tensor: # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -90,7 +91,12 @@ def forward( assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(weight) - update_fp8_weights = is_first_microbatch is None or is_first_microbatch + update_fp8_weights = ( + is_first_microbatch is None + or is_first_microbatch + or skip_fp8_weight_update is not None + ) + tp_world_size = get_distributed_world_size(tp_group) ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs @@ -140,7 +146,6 @@ def forward( # Weight is already in FP8 weight.reset_fp8_meta_scale_inv() weight_fp8 = weight - weight_t_fp8 = None elif update_fp8_weights: # Need to cast weights to FP8 weight_fp8 = Float8Tensor( @@ -158,6 +163,7 @@ def forward( fp8_dtype_forward, cast_out=weight_fp8._data, transpose_out=weight_t_fp8._data, + noop_flag=skip_fp8_weight_update, ) else: cast_to_fp8( @@ -296,6 +302,7 @@ def forward( weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, weight_t_fp8 if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None, ) ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 @@ -313,6 +320,7 @@ def forward( ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad + ctx.primary_weights_in_fp8 = primary_weights_in_fp8 # Row Parallel Linear if ub_overlap_rs: @@ -330,9 +338,7 @@ def forward( def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward( - ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear" - ): + with torch.cuda.nvtx.range("_Linear_backward"): ( inputmat, inputmat_t, @@ -340,6 +346,7 @@ def backward( main_grad, weight_t_fp8, fwd_scale_inverses, + skip_fp8_weight_update, ) = ctx.saved_tensors if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: @@ -347,10 +354,14 @@ def backward( weight.main_grad = main_grad # Primary weights are in FP8. 
- if ctx.fp8 and weight_t_fp8 is None: - weight_t_fp8 = weight.transpose( - update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy", + if ctx.primary_weights_in_fp8: + weight_t_fp8 = weight.transpose_2d( + cache=ctx.is_first_microbatch is not None, + noop_flag=skip_fp8_weight_update, ) + elif ctx.fp8: + weight_t_fp8 = weight_t_fp8._data + tp_world_size = get_distributed_world_size(ctx.tp_group) ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag if ctx.ub_overlap_ag: @@ -361,6 +372,7 @@ def backward( ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P else: ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ( grad_output, grad_output_c, @@ -401,7 +413,7 @@ def backward( if ctx.requires_dgrad: if ctx.fp8: dgrad, _ = fp8_gemm( - weight_t_fp8._data, + weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -542,6 +554,8 @@ def backward( None, None, None, + None, + None, ) @@ -772,7 +786,6 @@ def __init__( if self.primary_weights_in_fp8: self.init_fp8_metadata() - self.fp8_meta["update_amax_and_scale_fwd"] = True self.reset_parameters(defer_init=(device == 'meta')) @@ -785,6 +798,10 @@ def __init__( else: self.gemm_bias_unfused_add = False + # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. + self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) + FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) + def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) @@ -858,6 +875,10 @@ def forward( produced) """ + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if skip_fp8_weight_update is not None: + is_first_microbatch = False + with self.prepare_forward(inp, is_first_microbatch) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." @@ -903,6 +924,7 @@ def forward( bias_tensor, self.apply_bias and not self.gemm_bias_unfused_add, is_first_microbatch, + skip_fp8_weight_update, self.fp8, self.fp8_calibration, self.fp8_meta, @@ -919,6 +941,7 @@ def forward( self.ub_overlap_rs, self.ub_overlap_ag, self.ub_name, + self.dummy_tensor, ) out = linear_fn(*args) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 2e00333fa0..5b6fc1e5c3 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -473,6 +473,15 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N if hasattr(child, "set_tensor_parallel_group"): child.set_tensor_parallel_group(tp_group) + def reset_fp8_meta_tensors(self) -> None: + """Set TP group""" + # Deep iterate but skip self to avoid infinite recursion. + for index, child in enumerate(self.modules()): + if index == 0: + continue + if hasattr(child, "reset_fp8_meta_tensors"): + child.reset_fp8_meta_tensors() + def set_context_parallel_group( self, cp_group: Union[dist_group_type, None], @@ -665,7 +674,8 @@ def forward( # MLP. 
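The reset_fp8_meta_tensors hook added to TransformerLayer here walks every child module, skipping the layer itself, and forwards the call, so the FP8 state of the whole layer can be cleared in one line. A usage sketch; Transformer Engine and a GPU are assumed, the sizes are arbitrary, and the device/params_dtype arguments mirror the constructors used elsewhere in this series:

    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16,
        params_dtype=torch.bfloat16, device="cuda",
    )

    # ... run some warm-up / calibration iterations ...

    # Resets scales to 1 and amax histories to 0 in every TE submodule
    # (attention projections, LayerNormLinear, LayerNormMLP, ...) without
    # touching the weights themselves.
    layer.reset_fp8_meta_tensors()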
mlp_outputs = self.layernorm_mlp( - hidden_states, is_first_microbatch=is_first_microbatch + hidden_states, + is_first_microbatch=is_first_microbatch, ) if self.apply_residual_connection_post_layernorm: mlp_output, mlp_bias, residual = mlp_outputs From a2f7c72dc1c7c0db5e5f676aabdcddd4d38fe576 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 12 Apr 2024 16:35:40 -0700 Subject: [PATCH 015/244] [PyTorch] Fix kernel_bulk launch config (#775) Fix 0 grid size Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- transformer_engine/common/recipe/delayed_scaling.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 6e07b1ce9f..38e71b74de 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -56,11 +56,11 @@ struct OtherParams { }; #if CUDART_VERSION >= 12010 -constexpr size_t max_constant_memory_per_kernel = 32000; +constexpr size_t max_constant_memory_per_kernel = 32768; constexpr size_t AMAX_PARAMS_LIMIT = ( max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); #else -constexpr size_t max_constant_memory_per_kernel = 4000; +constexpr size_t max_constant_memory_per_kernel = 4096; constexpr size_t AMAX_PARAMS_LIMIT = ( max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); #endif @@ -389,6 +389,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, // Number of tensors in the bulk const size_t num_tensors = amax_histories.size(); + size_t num_remaining_tensors = num_tensors; const int num_kernels = (num_tensors+AMAX_PARAMS_LIMIT-1)/AMAX_PARAMS_LIMIT; size_t amax_history_length = 0; if (num_tensors > 0) { @@ -400,8 +401,8 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, AmaxParams p; for (int iter = 0; iter < num_kernels; iter++) { size_t kernel_num_scales = 0; - size_t kernel_num_tensors = (iter == (num_kernels -1)) - ? num_tensors % AMAX_PARAMS_LIMIT: AMAX_PARAMS_LIMIT; + size_t kernel_num_tensors = (iter == (num_kernels - 1)) + ? 
num_remaining_tensors: AMAX_PARAMS_LIMIT; for (size_t pi = 0; pi < kernel_num_tensors; pi++) { size_t i = iter * AMAX_PARAMS_LIMIT + pi; @@ -446,6 +447,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, if (amax_buffer != nullptr) { amax_buffer += kernel_num_scales; } + num_remaining_tensors -= AMAX_PARAMS_LIMIT; } } From 87cc8037c6aaad4e12902c3e54d3227304a3df5c Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Fri, 12 Apr 2024 16:36:01 -0700 Subject: [PATCH 016/244] Add SM margin to LayerNorm in inference (#772) * Add LN margin to inference Signed-off-by: Sangkug Lym * cleanup Signed-off-by: Sangkug Lym * Fix symbolic func registration Signed-off-by: Kirthi Shankar Sivamani * Fix grads Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Sangkug Lym Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_onnx_export.py | 3 +++ .../paddle/layer/layernorm_linear.py | 3 ++- .../paddle/layer/layernorm_mlp.py | 3 ++- .../pytorch/cpp_extensions/normalization.py | 8 ++++++++ transformer_engine/pytorch/csrc/extensions.h | 4 ++++ .../pytorch/csrc/extensions/normalization.cu | 13 ++++++++---- transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 8 ++++++++ transformer_engine/pytorch/module/_common.py | 3 ++- .../pytorch/module/layernorm.py | 7 +++++-- .../pytorch/module/layernorm_linear.py | 3 ++- .../pytorch/module/layernorm_mlp.py | 3 ++- transformer_engine/pytorch/module/rmsnorm.py | 6 +++++- .../pytorch/te_onnx_extensions.py | 20 +++++++++---------- 13 files changed, 62 insertions(+), 22 deletions(-) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 7707264c7f..2c34867f2b 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -660,6 +660,7 @@ def forward(self, inp): self.meta, self.fp8_tensor, self.fp8_type, + 0, zero_centered_gamma) ret = cast_from_fp8( @@ -748,6 +749,7 @@ def forward(self, inp): self.meta, self.fp8_tensor, self.fp8_type, + 0, zero_centered_gamma) ret = cast_from_fp8( @@ -1279,6 +1281,7 @@ def forward(self, inp, weight): self.meta, self.fp8_tensor, self.fp8_type, + 0, zero_centered_gamma) x = cast_from_fp8( diff --git a/transformer_engine/paddle/layer/layernorm_linear.py b/transformer_engine/paddle/layer/layernorm_linear.py index 838b62188a..5645e5ee0e 100644 --- a/transformer_engine/paddle/layer/layernorm_linear.py +++ b/transformer_engine/paddle/layer/layernorm_linear.py @@ -565,6 +565,7 @@ def __init__( # communication overlap with LN. self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) def _te_forward( self, @@ -600,7 +601,7 @@ def _te_forward( self.activation_dtype, self.return_layernorm_output, paddle.is_grad_enabled(), - self.fwd_ln_sm_margin, + self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, self.normalization, diff --git a/transformer_engine/paddle/layer/layernorm_mlp.py b/transformer_engine/paddle/layer/layernorm_mlp.py index 5242280d55..81e77fb1c1 100644 --- a/transformer_engine/paddle/layer/layernorm_mlp.py +++ b/transformer_engine/paddle/layer/layernorm_mlp.py @@ -824,6 +824,7 @@ def __init__( # communication overlap with LN. 
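Back to the kernel_bulk launch fix in delayed_scaling.cu (patch 015 above): the old code sized the last launch with num_tensors % AMAX_PARAMS_LIMIT, which is zero whenever num_tensors is an exact multiple of the per-kernel limit, giving the zero grid size mentioned in the commit message; counting down the remaining tensors avoids that. A small arithmetic sketch in Python, with a made-up limit (the real value is derived from the available constant memory):

    AMAX_PARAMS_LIMIT = 818                      # made-up value for illustration
    num_tensors = 2 * AMAX_PARAMS_LIMIT          # an exact multiple of the per-kernel limit
    num_kernels = (num_tensors + AMAX_PARAMS_LIMIT - 1) // AMAX_PARAMS_LIMIT   # -> 2

    num_remaining_tensors = num_tensors
    for it in range(num_kernels):
        # Old sizing: num_tensors % AMAX_PARAMS_LIMIT == 0 here -> empty last launch.
        kernel_num_tensors = num_remaining_tensors if it == num_kernels - 1 else AMAX_PARAMS_LIMIT
        assert kernel_num_tensors > 0
        num_remaining_tensors -= AMAX_PARAMS_LIMIT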
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) def _te_forward( self, @@ -865,7 +866,7 @@ def _te_forward( self.activation_dtype, self.return_layernorm_output, paddle.is_grad_enabled(), - self.fwd_ln_sm_margin, + self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, self.normalization, diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py index 1f80f2b604..1d15fe618b 100644 --- a/transformer_engine/pytorch/cpp_extensions/normalization.py +++ b/transformer_engine/pytorch/cpp_extensions/normalization.py @@ -66,6 +66,7 @@ def layernorm_fwd_fp8_inf( fp8_meta_tensor: tex.FP8TensorMeta, fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], otype: tex.DType, + sm_margin: int, zero_centered_gamma, ) -> torch.Tensor: """LayerNorm with FP8 output. @@ -83,6 +84,7 @@ def layernorm_fwd_fp8_inf( fp8_meta_tensor.scale_inv, fp8_tensor, otype, + sm_margin, zero_centered_gamma) return ret @@ -92,6 +94,7 @@ def layernorm_fwd_inf( weight: torch.Tensor, bias: torch.Tensor, eps: float, + sm_margin: int, zero_centered_gamma: bool, ) -> torch.Tensor: """LayerNorm with FP8 output""" @@ -100,6 +103,7 @@ def layernorm_fwd_inf( weight, bias, eps, + sm_margin, zero_centered_gamma, ) @@ -149,6 +153,7 @@ def rmsnorm_fwd_fp8_inf( fp8_meta_tensor: tex.FP8TensorMeta, fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], otype: tex.DType, + sm_margin: int, zero_centered_gamma, ) -> torch.Tensor: """RMSNorm with FP8 output. @@ -165,6 +170,7 @@ def rmsnorm_fwd_fp8_inf( fp8_meta_tensor.scale_inv, fp8_tensor, otype, + sm_margin, zero_centered_gamma) return ret @@ -173,6 +179,7 @@ def rmsnorm_fwd_inf( inp: torch.Tensor, weight: torch.Tensor, eps: float, + sm_margin: int, zero_centered_gamma: bool, ) -> torch.Tensor: """RMSNorm with FP8 output""" @@ -180,5 +187,6 @@ def rmsnorm_fwd_inf( inp, weight, eps, + sm_margin, zero_centered_gamma, ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 0887054665..bf0bb576ec 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -408,6 +408,7 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, + const int sm_margin, const bool zero_centered_gamma ); @@ -432,6 +433,7 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps, + const int sm_margin, const bool zero_centered_gamma ); @@ -478,6 +480,7 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, + const int sm_margin, const bool zero_centered_gamma ); @@ -499,6 +502,7 @@ std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps, + const int sm_margin, const bool zero_centered_gamma ); diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cu b/transformer_engine/pytorch/csrc/extensions/normalization.cu index c7cc37198e..ef0facee28 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cu +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cu @@ 
-154,12 +154,13 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, + const int sm_margin, const bool zero_centered_gamma ) { // This is a specialized version of layernorm_fwd_fp8, optimized for inference, // which only returns the normalized output. std::vector out = layernorm_fwd_fp8( - input, weight, bias, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma); + input, weight, bias, eps, scale, amax, scale_inv, otype, sm_margin, zero_centered_gamma); return out[0]; } @@ -203,11 +204,13 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps, + const int sm_margin, const bool zero_centered_gamma ) { // This is a specialized version of layernorm_fwd, optimized for inference, // which only returns the normalized output. - std::vector out = layernorm_fwd(input, weight, bias, eps, 0, zero_centered_gamma); + std::vector out = layernorm_fwd(input, weight, bias, eps, sm_margin, + zero_centered_gamma); return out[0]; } @@ -345,12 +348,13 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, + const int sm_margin, const bool zero_centered_gamma ) { // This is a specialized version of rmsnorm_fwd_fp8, optimized for inference, // which only returns the normalized output. std::vector out = rmsnorm_fwd_fp8( - input, weight, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma); + input, weight, eps, scale, amax, scale_inv, otype, sm_margin, zero_centered_gamma); return out[0]; } @@ -391,10 +395,11 @@ std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps, + const int sm_margin, const bool zero_centered_gamma ) { // This is a specialized version of rmsnorm_fwd, optimized for inference, // which only returns the normalized output. 
- std::vector out = rmsnorm_fwd(input, weight, eps, 0, zero_centered_gamma); + std::vector out = rmsnorm_fwd(input, weight, eps, sm_margin, zero_centered_gamma); return out[0]; } diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index a7217d4570..ac9c7351a8 100755 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -365,6 +365,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype, + const int8_t sm_margin, const bool zero_centered_gamma) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); float eps_float = static_cast(eps); @@ -377,6 +378,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, amax, scale_inv, otype_arg, + sm_margin, zero_centered_gamma); return output; @@ -387,6 +389,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, double eps, + const int8_t sm_margin, const bool zero_centered_gamma) { float eps_float = static_cast(eps); @@ -394,6 +397,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, weight, bias, eps_float, + sm_margin, zero_centered_gamma); return output; @@ -408,6 +412,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype, + const int8_t sm_margin, const bool zero_centered_gamma) { transformer_engine::DType otype_arg = reverse_map_dtype(otype); float eps_float = static_cast(eps); @@ -419,6 +424,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, amax, scale_inv, otype_arg, + sm_margin, zero_centered_gamma); return output; @@ -428,12 +434,14 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input, const at::Tensor &weight, double eps, + const int8_t sm_margin, const bool zero_centered_gamma) { float eps_float = static_cast(eps); at::Tensor output = rmsnorm_fwd_inf(input, weight, eps_float, + sm_margin, zero_centered_gamma); return output; diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index d2ab776288..79798d2ff0 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -78,6 +78,7 @@ def _apply_normalization(inputmat:torch.Tensor, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, + fwd_ln_sm_margin, zero_centered_gamma, ), None, None else: @@ -88,7 +89,7 @@ def _apply_normalization(inputmat:torch.Tensor, ) else: return normalization_func( - *inputs, eps, zero_centered_gamma + *inputs, eps, fwd_ln_sm_margin, zero_centered_gamma ), None, None if normalization == "RMSNorm": output = (ln_out, None, output[1]) diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 6178199be6..ef441888dc 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -34,6 +34,7 @@ def forward( eps: float, fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, + inf_ln_sm_margin: int, zero_centered_gamma: bool, is_grad_enabled: bool, activation_dtype: torch.dtype, @@ -58,7 +59,7 @@ def forward( ctx.zero_centered_gamma = zero_centered_gamma else: ln_out, mu, rsigma = layernorm_fwd_inf(inputmat, ln_weight, - ln_bias, eps, zero_centered_gamma), None, None + ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma), None, None return ln_out.view_as(inp) 
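This patch threads a third, inference-only SM margin through every normalization path so that communication kernels can still overlap with LayerNorm/RMSNorm when no gradient is being recorded. A sketch of the selection the modules perform, using only standard PyTorch; the NVTE_* variable names come from the hunks in this patch:

    import os
    import torch

    fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
    inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

    def forward_sm_margin() -> int:
        # Reserve the training margin only while autograd is recording,
        # otherwise fall back to the inference margin.
        return fwd_ln_sm_margin if torch.is_grad_enabled() else inf_ln_sm_margin

    with torch.no_grad():
        assert forward_sm_margin() == inf_ln_sm_margin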
@staticmethod @@ -72,7 +73,7 @@ def backward( d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ) - return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None + return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None, None class LayerNorm(torch.nn.Module): @@ -148,6 +149,7 @@ def __init__( # communication overlap with LN. self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) def reset_layer_norm_parameters(self) -> None: """Init LN params""" @@ -198,6 +200,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: self.eps, self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, + self.inf_ln_sm_margin, self.zero_centered_gamma, torch.is_grad_enabled(), self.activation_dtype, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 8fdd5d1356..ffa14bc157 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -999,6 +999,7 @@ def __init__( # communication overlap with LN. self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) @@ -1165,7 +1166,7 @@ def forward( self.return_layernorm_output, self.return_layernorm_output_gathered, torch.is_grad_enabled(), - self.fwd_ln_sm_margin, + self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, self.normalization, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 43103f06e1..e143cf6659 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1427,6 +1427,7 @@ def __init__( # communication overlap with LN. self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. 
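Because the forward of these autograd functions gained an extra non-tensor argument, the backward above has to return one additional None: autograd expects exactly one gradient slot per forward argument, with None for the non-differentiable ones. A toy Function showing the rule; the names are made up:

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp: torch.Tensor, factor: float, sm_margin: int):
            # sm_margin would only tune the kernel launch; it carries no gradient.
            ctx.factor = factor
            return inp * factor

        @staticmethod
        def backward(ctx, grad_out: torch.Tensor):
            # One entry per forward argument: a tensor gradient for inp, then
            # None for each non-differentiable argument (factor, sm_margin).
            return grad_out * ctx.factor, None, None

    x = torch.randn(3, requires_grad=True)
    Scale.apply(x, 2.0, 0).sum().backward()
    assert torch.allclose(x.grad, torch.full((3,), 2.0))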
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) @@ -1575,7 +1576,7 @@ def forward( self.bias_gelu_nvfusion, self.set_parallel_mode, torch.is_grad_enabled(), - self.fwd_ln_sm_margin, + self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, self.activation, diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index c32012d8e0..e1d2ac2551 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -31,6 +31,7 @@ def forward( eps: float, fwd_rmsnorm_sm_margin: int, bwd_rmsnorm_sm_margin: int, + inf_rmsnorm_sm_margin: int, zero_centered_gamma: bool, is_grad_enabled: bool, activation_dtype: torch.dtype, @@ -55,7 +56,7 @@ def forward( ctx.zero_centered_gamma = zero_centered_gamma else: rmsnorm_out = tex.rmsnorm_fwd_inf(inputmat, rmsnorm_weight, - eps, + eps, inf_rmsnorm_sm_margin, zero_centered_gamma) return rmsnorm_out.view_as(inp) @@ -79,6 +80,7 @@ def backward( None, None, None, + None, ) @@ -151,6 +153,7 @@ def __init__( # communication overlap with RMSNorm. self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + self.inf_rmsnorm_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) def reset_rms_norm_parameters(self) -> None: """Init RMSNorm params""" @@ -195,6 +198,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: self.eps, self.fwd_rmsnorm_sm_margin, self.bwd_rmsnorm_sm_margin, + self.inf_rmsnorm_sm_margin, self.zero_centered_gamma, torch.is_grad_enabled(), self.activation_dtype, diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py index 67ff4ce161..33ca1ed594 100755 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ b/transformer_engine/pytorch/te_onnx_extensions.py @@ -304,9 +304,9 @@ def _ones_like(g, inp, dtype): return one -@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "b") +@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "i", "b") def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, - scale_inv, fp8_tensor, otype, zero_centered_gamma): + scale_inv, fp8_tensor, otype, sm_margin, zero_centered_gamma): """ONNX graph for layernorm_fwd_fp8""" # pylint: disable=unused-argument inp_dtype = get_TensorProtoDataType(inputs) @@ -316,13 +316,13 @@ def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, if inp_dtype != get_TensorProtoDataType(bias): bias = g.op("Cast", bias, to_i=inp_dtype) - ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma) + ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma) fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) return fp8_ln -@symbolic_helper.parse_args("v", "v", "v", "f", "b") -def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma): +@symbolic_helper.parse_args("v", "v", "v", "f", "i", "b") +def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma): """ONNX graph for layernorm_fwd""" # pylint: disable=unused-argument @@ -352,9 +352,9 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma): ) return ln -@symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "b") +@symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "i", "b") def 
onnx_rmsnorm_fwd_fp8(g, inputs, weight, eps, scale, amax, - scale_inv, fp8_tensor, otype, zero_centered_gamma): + scale_inv, fp8_tensor, otype, sm_margin, zero_centered_gamma): """ONNX graph for rmsnorm_fwd_fp8""" # pylint: disable=unused-argument inp_dtype = get_TensorProtoDataType(inputs) @@ -362,13 +362,13 @@ def onnx_rmsnorm_fwd_fp8(g, inputs, weight, eps, scale, amax, if inp_dtype != get_TensorProtoDataType(weight): weight = g.op("Cast", weight, to_i=inp_dtype) - ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma) + ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma) fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) return fp8_ln -@symbolic_helper.parse_args("v", "v", "f", "b") -def onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma): +@symbolic_helper.parse_args("v", "v", "f", "i", "b") +def onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma): """ONNX graph for rmsnorm_fwd""" # pylint: disable=unused-argument From d0e02cfdba637bb36815dedf3dab0e4400a2223d Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 15 Apr 2024 15:08:02 -0700 Subject: [PATCH 017/244] [PyTorch] Don't use autograd hook for bwd reduction (#781) Don't use autograd hook for bwd reduction Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/fp8.py | 21 ------------------- .../pytorch/module/layernorm_linear.py | 12 +++++------ .../pytorch/module/layernorm_mlp.py | 13 +++++------- transformer_engine/pytorch/module/linear.py | 13 +++++------- 4 files changed, 15 insertions(+), 44 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index e821bfe11d..d06443efb6 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -81,8 +81,6 @@ class FP8GlobalStateManager: fp8_tensors_recompute_buffer = [] fp8_available = None reason_for_no_fp8 = "" - multi_grad_hook_tensors = [] - bwd_amax_update_hook_registered = False autocast_arguments = {} autocast_to_fp8_params = {} fp8_param_to_autocast = {} @@ -106,8 +104,6 @@ def reset(cls) -> None: cls.fp8_tensors_recompute_buffer = [] cls.fp8_available = None cls.reason_for_no_fp8 = "" - cls.multi_grad_hook_tensors = [] - cls.bwd_amax_update_hook_registered = False cls.autocast_arguments = {} cls.autocast_to_fp8_params = {} cls.fp8_param_to_autocast = {} @@ -370,16 +366,6 @@ def reduce_and_update_fp8_tensors( _amax_and_scale_update( amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe) - @classmethod - def add_tensor_for_bwd_reduction_multi_grad_hook(cls, tensor): - """Add tensor to list for multi grad hook.""" - cls.multi_grad_hook_tensors.append(tensor) - - @classmethod - def hook_for_bwd_amax_reduction(cls, grads: Tuple[torch.Tensor]) -> None: # pylint: disable=unused-argument - """Executes at the end of backward pass.""" - cls.reduce_and_update_fp8_tensors(forward=False) - @classmethod def get_unique_autocast_key( cls, @@ -407,13 +393,6 @@ def fp8_autocast_enter( autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) - if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): - if not cls.bwd_amax_update_hook_registered and len(cls.multi_grad_hook_tensors) > 0: - # This hook does not fire for graphed modules. 
- torch.autograd.graph.register_multi_grad_hook( - tuple(cls.multi_grad_hook_tensors), cls.hook_for_bwd_amax_reduction) - cls.bwd_amax_update_hook_registered = True - cls.FP8_ENABLED = enabled cls.FP8_CALIBRATION = calibrating cls.FP8_RECIPE = fp8_recipe diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ffa14bc157..5df4950276 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -40,6 +40,7 @@ ) from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo +from ..graph import is_graph_capturing from ._common import _apply_normalization, _noop_cat from ..float8_tensor import Float8Tensor @@ -89,7 +90,6 @@ def forward( ub_overlap_rs_dgrad: bool, ub_overlap_ag: bool, ub_name: str, - dummy_tensor: torch.Tensor, # pylint: disable=unused-argument ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -328,6 +328,7 @@ def forward( ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization ctx.primary_weights_in_fp8 = primary_weights_in_fp8 + ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() # Row Parallel Linear if parallel_mode == "row" and sequence_parallel: @@ -660,6 +661,9 @@ def backward( else: wgrad = None + if ctx.is_first_module and not is_graph_capturing(): + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + return ( dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgamma, @@ -696,7 +700,6 @@ def backward( None, None, None, - None, ) @@ -1001,10 +1004,6 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) - # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. 
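The replacement for the multi-grad hook relies on execution order: the FP8 module that runs first in forward runs last in backward, so flagging it at forward time (ctx.is_first_module above) gives a natural once-per-iteration point for the backward amax reduction. A toy reproduction of that ordering; _is_first_module and the print are stand-ins for FP8GlobalStateManager.is_first_fp8_module() and reduce_and_update_fp8_tensors(forward=False):

    import torch

    _seen_module = False        # reset at the start of each iteration in the real manager

    def _is_first_module() -> bool:
        global _seen_module
        first = not _seen_module
        _seen_module = True
        return first

    class ToyFP8Op(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            ctx.is_first_module = _is_first_module()
            return inp * 1.0

        @staticmethod
        def backward(ctx, grad):
            if ctx.is_first_module:
                # Fires exactly once, at the very end of the backward pass.
                print("reduce_and_update_fp8_tensors(forward=False)")
            return grad

    x = torch.randn(2, requires_grad=True)
    # The inner op is "first"; its backward runs last and triggers the update.
    ToyFP8Op.apply(ToyFP8Op.apply(x)).sum().backward()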
- self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) - FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) - def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1176,7 +1175,6 @@ def forward( self.ub_overlap_rs_dgrad, self.ub_overlap_ag, self.ub_name, - self.dummy_tensor, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index e143cf6659..6efb72b8db 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -49,7 +49,7 @@ from ..constants import dist_group_type, TE_DType from ..jit import no_torch_dynamo - +from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor from ._common import _apply_normalization @@ -121,7 +121,6 @@ def forward( ub_overlap_rs: bool, ub_overlap_ag: bool, gemm_gelu_fusion: bool, - dummy_tensor: torch.Tensor, # pylint: disable=unused-argument, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -545,6 +544,7 @@ def forward( ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization ctx.primary_weights_in_fp8 = primary_weights_in_fp8 + ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() # Row Parallel Linear if ub_overlap_rs: @@ -1121,6 +1121,9 @@ def backward( else: fc2_wgrad = None + if ctx.is_first_module and not is_graph_capturing(): + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + return ( dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgamma, @@ -1165,7 +1168,6 @@ def backward( None, None, None, - None, ) @@ -1429,10 +1431,6 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) - # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. 
- self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) - FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) - def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1588,7 +1586,6 @@ def forward( self.ub_overlap_rs, self.ub_overlap_ag, self.gemm_gelu_fusion, - self.dummy_tensor, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 4baf2d5965..3c055270b0 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -43,7 +43,7 @@ ) from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo - +from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor __all__ = ["Linear"] @@ -81,7 +81,6 @@ def forward( ub_overlap_rs: bool, ub_overlap_ag: bool, ub_name: str, - dummy_tensor: torch.Tensor, # pylint: disable=unused-argument ) -> torch.Tensor: # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -321,6 +320,7 @@ def forward( ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad ctx.primary_weights_in_fp8 = primary_weights_in_fp8 + ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() # Row Parallel Linear if ub_overlap_rs: @@ -530,6 +530,9 @@ def backward( else: wgrad = None + if ctx.is_first_module and not is_graph_capturing(): + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + return ( wgrad, None, @@ -555,7 +558,6 @@ def backward( None, None, None, - None, ) @@ -798,10 +800,6 @@ def __init__( else: self.gemm_bias_unfused_add = False - # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. - self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) - FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) - def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) @@ -941,7 +939,6 @@ def forward( self.ub_overlap_rs, self.ub_overlap_ag, self.ub_name, - self.dummy_tensor, ) out = linear_fn(*args) From a25a2fe351c262842f9e8a6e837384e6b031dd7a Mon Sep 17 00:00:00 2001 From: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 15 Apr 2024 21:46:56 -0700 Subject: [PATCH 018/244] [C/PyTorch] Add FP8 DPA and MHA (#768) * WIP: fp8 v1 fprop integration Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: minor fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add debug info Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add more debug info Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fprop working for h1; w/ debug info Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: add bprop Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * cleanup; bprop running but has mismatches Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add gitlab frontend as submodule Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up and add back v0.9.2 FE support; fprop/bprop passing with 5e-2 tols Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix after merge; add bias_b/h to caching descriptor Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * distinguish fwd/bwd tensor types for bprop Signed-off-by: Charlene Yang 
<8636796+cyanguwa@users.noreply.github.com> * minor fix for F16 cases; include added dqkv_type and d_scale_dp Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * adjust out shape for bwd in test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add casting from/to FP8 to DPA module Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: bshd_bshd_bshd layout Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: support all sbhd/bshd layouts Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add qkvpacked and kvpacked support in both FusedAttnFunc and C levels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove qkvpacked/kvpacked calls in DPA module (used for testing) Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove tp setup; add allow_non_contiguous; update FE; revert to sbh3d in tests; clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add NVTE_FP8_DPA_BWD to control whether to use FP8 bwd or F16 bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix MQA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix MQA/GQA in FP8 v1 API Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE to 705d8e3, with API change Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * test causal mask Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * restrict mha_fill for THD format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fused attn with CP and comment out is_alibi code Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up FE0.9 vs FE1.0 FP8 implementations, and related unit tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * change NVTE_FP8_DPA_BWD default to 1, and fix its use in qkvpacked/kvpacked APIs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint and self.tp_size/group in FusedAttention() Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE to 6902c94 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add FP8 MHA support Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update to FE v1.3.0 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes for FP8 MHA with different configs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * emit stats regardless of is_training Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix linear when input is not Float8Tensor Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix d_out type when f16 bprop Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix user buffer for layernorm_linear/linear and revert two FP8 casts in MHA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add docstring for fp8_dpa/mha in recipe Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fixes Signed-off-by: Kirthi Shankar Sivamani * fix backend selection to avoid FA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace transpose with transpose_2d 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * use RMSE for FP8 unit tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace two more transpose with transpose_2d Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add FP8 initialization to FusedAttention Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * rm docs Signed-off-by: Kirthi Shankar Sivamani * Revert "add FP8 initialization to FusedAttention" This reverts commit 15fffd825d6f23f31ea709b16ba01dfd61efabf8. Signed-off-by: Kirthi Shankar Sivamani * Change order of ctxs Signed-off-by: Kirthi Shankar Sivamani * Fixes Signed-off-by: Kirthi Shankar Sivamani * minor fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add back docs and mark as beta Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes for tests and docs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- 3rdparty/cudnn-frontend | 2 +- qa/L0_pytorch_unittest/test.sh | 2 +- tests/pytorch/fused_attn/test_fused_attn.py | 594 ++++++-- tests/pytorch/test_numerics.py | 2 +- .../common/fused_attn/fused_attn.cpp | 71 +- .../fused_attn_f16_arbitrary_seqlen.cu | 23 +- .../fused_attn_f16_arbitrary_seqlen.h | 2 +- .../common/fused_attn/fused_attn_fp8.cu | 1205 ++++++++++++++++- .../common/fused_attn/fused_attn_fp8.h | 56 +- transformer_engine/common/fused_attn/utils.h | 7 +- transformer_engine/common/recipe/__init__.py | 21 +- transformer_engine/pytorch/attention.py | 853 ++++++++++-- .../pytorch/cpp_extensions/fused_attn.py | 78 +- .../pytorch/csrc/comm_gemm_overlap.h | 4 +- transformer_engine/pytorch/csrc/extensions.h | 9 + .../pytorch/csrc/extensions/attention.cu | 171 ++- transformer_engine/pytorch/float8_tensor.py | 89 +- transformer_engine/pytorch/fp8.py | 6 +- transformer_engine/pytorch/module/base.py | 74 +- .../pytorch/module/layernorm_linear.py | 46 +- transformer_engine/pytorch/module/linear.py | 150 +- transformer_engine/pytorch/utils.py | 9 +- 22 files changed, 3003 insertions(+), 471 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index a86ad708db..1b0b5eac54 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit a86ad708db725e4d29919bb6fadf8e6cdfa5dc06 +Subproject commit 1b0b5eac540b7f8fd19b18f1e6b8427c95503348 diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 50f54cd714..ded45dd377 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -6,7 +6,7 @@ set -e : ${TE_PATH:=/opt/transformerengine} -pip install pytest==6.2.5 onnxruntime==1.13.1 +pip install pytest==7.2 onnxruntime==1.13.1 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index b2c8f69ef3..40cfdd34b7 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. 
+import math import functools from importlib.metadata import version import os @@ -12,9 +13,10 @@ import torch from transformer_engine.common import recipe -from transformer_engine.pytorch import TransformerLayer, fp8_autocast +from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init from transformer_engine.pytorch.attention import ( DotProductAttention, + MultiheadAttention, RotaryPositionEmbedding, ) from transformer_engine.pytorch.constants import TE_DType @@ -939,52 +941,415 @@ def _run_transformer_layer( return out, inp.grad -model_configs_fp8 = { +model_configs_fp8_vs_f16 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_1": ModelConfig(1, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), - "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), + "fp8_9 ": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), + "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), } -param_types_fp8 = [torch.float16] +param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] +qkv_layout_fp8_vs_f16 = ['sbh3d', 'bshd_bshd_bshd', 'sbhd_sbhd_sbhd'] +qkv_format_fp8_vs_f16 = ['bshd', 'sbhd'] + +def _rmse(a, b): + return math.sqrt((torch.pow((a-b), 2)/a.numel()).sum()) @pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.") -@pytest.mark.parametrize("dtype", param_types_fp8) -@pytest.mark.parametrize("model", model_configs_fp8.keys()) -def test_dpa_fp8(dtype, model): - """Test FP8 dot product attention +@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) +@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) +@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16) +@pytest.mark.parametrize("input_layernorm", [True, False]) +@pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) +def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd): + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" + config = model_configs_fp8_vs_f16[model] + + os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" + if _NVTE_DEBUG: + print() + print("[test_mha_fp8_vs_f16]: run with fp8_mha = True") + fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( + dtype, config, True, qkv_format, input_layernorm) + if _NVTE_DEBUG: + print() + print("[test_mha_fp8_vs_f16]: run with fp8_mha = False") + fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( + dtype, config, False, qkv_format, input_layernorm) + + tols = dict(atol=5e-1, rtol=5e-1) + rmse_tol = 0.1 + fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16) + fwd_range = max(fused_attn_fwd_fp8.max().item(), + fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(), + fused_attn_fwd_f16.min().item()) + if _NVTE_DEBUG: + print() + print('========== {:^25s} =========='.format('forward output')) + print('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format( + fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item())) + print('fused_attn_fwd_f16 min {:.6f} max 
{:.6f}'.format( + fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item())) + print('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse)) + try: + torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols) + except Exception as e: + print(e) + print() + assert(fwd_rmse < rmse_tol * fwd_range + ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range) + for i in range(len(param_names[:1])): + bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i]) + bwd_range = max(fused_attn_bwd_fp8[i].max().item(), + fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(), + fused_attn_bwd_f16[i].min().item()) + if _NVTE_DEBUG: + print() + print('========== {:^25s} =========='.format(param_names[i])) + print('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i, + fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item())) + print('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i, + fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item())) + print('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse)) + try: + torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols) + except Exception as e: + print(e) + print() + assert(bwd_rmse < rmse_tol * bwd_range + ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range) + +def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm): + reset_rng_states() + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: + """Get cuda rng tracker.""" + return _DUMMY_CUDA_RNG_STATE_TRACKER - FusedAttention uses fused_attn_fwd/bwd_qkvpacked from cpp_extensions, - and UnfusedDotProductAttention uses plain PyTorch operations in FP16 - and converts inputs/outputs from/to FP8. 
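# A condensed, self-contained sketch (not part of the patch) of the RMSE-vs-range
# check the FP8-vs-F16 tests above rely on: rather than elementwise closeness, the
# RMSE between the two outputs is compared against a tolerance scaled by their
# combined value range. `fp8_out` and `f16_out` are placeholder tensors.
import math
import torch

def _rmse(a, b):
    # Root-mean-square error over all elements.
    return math.sqrt((torch.pow(a - b, 2) / a.numel()).sum())

def assert_rmse_within_range(fp8_out, f16_out, rmse_tol=0.1):
    rmse = _rmse(fp8_out, f16_out)
    lo = min(fp8_out.min().item(), f16_out.min().item())
    hi = max(fp8_out.max().item(), f16_out.max().item())
    # The allowed error grows with the dynamic range of the outputs being compared.
    assert rmse < rmse_tol * (hi - lo), (
        f"RMSE {rmse:.5f} is over tolerance {rmse_tol * (hi - lo):.5f}")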
+ fp8_recipe = recipe.DelayedScaling( + margin=0, + interval=1, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=fp8_mha, + fp8_mha=fp8_mha, + ) - """ + with fp8_model_init(enabled=fp8_mha): + mha = (MultiheadAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_heads, + kv_channels=config.head_dim, + num_gqa_groups=config.num_gqa_groups, + attention_dropout=config.dropout_p, + layer_number=1, + bias=True, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + params_dtype=dtype, + input_layernorm=input_layernorm, + fuse_qkv_params=True, + attention_type="self", + qkv_weight_interleaved=True, + qkv_format=qkv_format, + ).to(dtype=dtype, device="cuda") + ) - config = model_configs_fp8[model] + seqlens_q = torch.full([config.batch_size], config.max_seqlen_q, + dtype=torch.int32, device="cuda") + seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv, + dtype=torch.int32, device="cuda") + cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0) - # Skip if not supported - fused_attn_supported, fused_attn_backend = _is_fused_attention_supported( - config, dtype) - if not fused_attn_supported: - pytest.skip("FusedAttention does not support this model config") + dim_to_num = { + 'b' : config.batch_size, + 'sq' : config.max_seqlen_q, + 'skv': config.max_seqlen_kv, + 'h' : config.num_heads, + 'hg' : config.num_gqa_groups, + 'd' : config.head_dim, + 't' : cu_seqlens_q[-1], + 'tg' : cu_seqlens_kv[-1], + '3' : 3, + '2' : 2, + '1' : 1, + } + layout = '_'.join(qkv_format) + layout = layout.replace('s', 'sq') + tensor_shape = [dim_to_num[j] for j in layout.split('_')] + tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda") + hidden_states = tensor.view(*tensor.shape[:-2], -1) + hidden_states.requires_grad = True + tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda") + out_grad = tensor.view(*tensor.shape[:-2], -1) + + with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe): + out = mha(hidden_states, + attn_mask_type=config.attn_mask_type, + checkpoint_core_attention=False, + core_attention_bias_type=config.attn_bias_type, + is_first_microbatch=None, + ) + out.backward(out_grad) - # Run dot-product attention with different backends - fused_attn_fwd, fused_attn_bwd = _run_dpa_fp8( - dtype, config, "FusedAttention") - unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref( - dtype, config, "UnfusedDotProductAttention") + param_names = [] + param_names.append('hidden_states.grad') + params = [] + params.append(hidden_states) + for name, param in mha.named_parameters(): + if param.requires_grad: + param_names.append(name+'.grad') + params.append(param) - tols = dict(atol=2.5e-2, rtol=2.5e-2) - torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) - torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols) + return out, param_names, tuple(x.grad for x in params) -def _run_dpa_fp8(dtype, config, backend): - """Run FusedAttention FP8 backend, i.e. 
- fused_attn_fwd/bwd_qkvpacked from cpp_extensions""" +@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.") +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.") +@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) +@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) +@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) +@pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) +def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd): + config = model_configs_fp8_vs_f16[model] + + if (config.num_heads != config.num_gqa_groups and '3' in qkv_layout): + pytest.skip("qkv_layout not applicable for MQA/GQA"); + + os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" + if _NVTE_DEBUG: + print() + print("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") + fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + dtype, config, True, qkv_layout) + if _NVTE_DEBUG: + print("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") + fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( + dtype, config, False, qkv_layout) + + tols = dict(atol=5e-1, rtol=5e-2) + if _NVTE_DEBUG: + print('[test_dpa_fp8_vs_f16]: ', tols) + print('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format( + fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item())) + print('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format( + fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item())) + print('fused_attn_fwd RMSE: {:.6f}'.format( + _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16))) + torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols) + for i,_ in enumerate(fused_attn_bwd_f16): + if _NVTE_DEBUG: + print('fused_attn_bwd_fp8 min {:.6f} max {:.6f}'.format( + fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item())) + print('fused_attn_bwd_f16 min {:.6f} max {:.6f}'.format( + fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item())) + print('fused_attn_bwd RMSE: {:.6f}'.format( + _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i]))) + torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols) + + +def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout): + reset_rng_states() + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: + """Get cuda rng tracker.""" + return _DUMMY_CUDA_RNG_STATE_TRACKER + + fp8_recipe = recipe.DelayedScaling( + margin=0, + interval=1, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=fp8_dpa, + ) + + qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) + with fp8_model_init(enabled=fp8_dpa): + dpa = ( + DotProductAttention( + config.num_heads, + config.head_dim, + num_gqa_groups=config.num_gqa_groups, + attention_dropout=config.dropout_p, + sequence_parallel=False, + tp_size=1, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + tp_group=None, + layer_number=1, + attention_type="self", + qkv_format=qkv_format, + ).to(dtype=dtype, device="cuda") + ) + + seqlens_q = torch.full([config.batch_size], config.max_seqlen_q, + dtype=torch.int32, device="cuda") + seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv, + dtype=torch.int32, device="cuda") + cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") 
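# Condensed sketch (not part of the patch) of the enablement flow the runners above
# follow: the DelayedScaling recipe carries the new fp8_dpa/fp8_mha flags,
# fp8_model_init() controls whether module weights are created in FP8, and
# fp8_autocast() applies the recipe to the forward pass. Sizes mirror the "fp8_9"
# config above and are illustrative only; FP8 attention additionally requires
# Hopper (sm90) and a recent cuDNN, as the skip conditions above check.
import torch
from transformer_engine.common import recipe
from transformer_engine.pytorch import fp8_autocast, fp8_model_init
from transformer_engine.pytorch.attention import MultiheadAttention

fp8_recipe = recipe.DelayedScaling(
    margin=0,
    fp8_format=recipe.Format.HYBRID,
    amax_history_len=1,
    amax_compute_algo="most_recent",
    fp8_dpa=True,   # run the dot-product attention itself in FP8
    fp8_mha=True,   # keep the surrounding QKV/projection GEMMs in FP8 as well
)

with fp8_model_init(enabled=True):
    mha = MultiheadAttention(
        hidden_size=3072,
        num_attention_heads=24,
        fuse_qkv_params=True,
        params_dtype=torch.float16,
    ).to(dtype=torch.float16, device="cuda")

x = torch.randn(2048, 2, 3072, dtype=torch.float16, device="cuda")  # sbhd layout
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = mha(x)
y.backward(torch.randn_like(y))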
+ cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0) + + dim_to_num = { + 'b' : config.batch_size, + 'sq' : config.max_seqlen_q, + 'skv': config.max_seqlen_kv, + 'h' : config.num_heads, + 'hg' : config.num_gqa_groups, + 'd' : config.head_dim, + 't' : cu_seqlens_q[-1], + 'tg' : cu_seqlens_kv[-1], + '3' : 3, + '2' : 2, + '1' : 1, + } + inp = [] + for i,layout in enumerate(qkv_layout.split('_')): + layout = '_'.join(layout) + if i == 0: + layout = layout.replace('s', 'sq') + else: + layout = layout.replace('s', 'skv') + layout = layout.replace('h', 'hg') + layout = layout.replace('t', 'tg') + tensor_shape = [dim_to_num[j] for j in layout.split('_')] + tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda") + tensor_count = 1 + split_dim = 0 + for dim, l in enumerate(layout.split('_')): + if l.isdigit(): + tensor_count = int(l) + split_dim = dim + break + tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor] + for j in range(tensor_count): + if split_dim != 0: + inp.append(tensors[j].squeeze(split_dim)) + else: + inp.append(tensors[j]) + for i in range(3): + inp[i].requires_grad = True + + qkv_format_kv = '_'.join(qkv_format) + qkv_format_kv = qkv_format_kv.replace('s', 'sq') + out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')] + out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] + out_grad = 0.1 * torch.randn(out_grad_shape_new, dtype=dtype, device="cuda") + + with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe): + out = dpa(inp[0], inp[1], inp[2], + qkv_format=qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=config.max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, + attn_mask_type=config.attn_mask_type, + checkpoint_core_attention=False, + core_attention_bias_type=config.attn_bias_type, + is_first_microbatch=True, + ) + out.backward(out_grad) + + return out, (inp[0].grad, inp[1].grad, inp[2].grad) + + +model_configs_fp8 = { + # test: b, h, hg, d, sq, skv, p, mask, bias + "fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"), + "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), + "fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"), + "fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"), + "fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), +} +param_types_fp8 = [torch.float16, torch.bfloat16] +cudnn_frontend_version = int(os.getenv('NVTE_FUSED_ATTN_FE_VER','1')) +models_v0 = ['fp8_1', 'fp8_2', 'fp8_5', 'fp8_6'] +models_v1 = ['fp8_3', 'fp8_4', 'fp8_7', 'fp8_8'] + + +@pytest.mark.skipif(_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.") +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.") +@pytest.mark.parametrize("dtype", param_types_fp8) +@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0) +def test_custom_mha_fp8_vs_f16(dtype, model): + """Test FP8 dot product attention implementations based on cuDNN frontend + v0.9 and v1.0+. 
Each test compares results from a custom implementation of + an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 MHA + implementation, i.e. transformer_engine.pytorch.attention.MultiHeadAttention. + Both paths take F16 input and output. QKV layout is t3hd or bs3hd""" + + config = model_configs_fp8[model] + + fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8( + dtype, config, "FusedAttention") + unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16( + dtype, config, "UnfusedAttention") + + tols = dict(atol=5e-1, rtol=5e-1) + rmse_tol = 0.1 + fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16) + fwd_range = max(fused_attn_fwd_fp8.max().item(), + unfused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(), + unfused_attn_fwd_f16.min().item()) + bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16) + bwd_range = max(fused_attn_bwd_fp8.max().item(), + unfused_attn_bwd_f16.max().item()) - min(fused_attn_bwd_fp8.min().item(), + unfused_attn_bwd_f16.min().item()) + if _NVTE_DEBUG: + print('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format( + fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item())) + print('unfused_attn_fwd_f16 min {:.6f} max {:.6f}'.format( + unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item())) + print('fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}'.format( + fwd_rmse)) + try: + torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols) + except Exception as e: + print(e) + print() + print('fused_attn_bwd_fp8 min {:.6f} max {:.6f}'.format( + fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item())) + print('unfused_attn_bwd_f16 min {:.6f} max {:.6f}'.format( + unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item())) + print('fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}'.format( + bwd_rmse)) + try: + torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols) + except Exception as e: + print(e) + print() + assert(fwd_rmse < rmse_tol * fwd_range + ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range) + assert(bwd_rmse < rmse_tol * bwd_range + ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range) + + +def _run_custom_mha_fp8(dtype, config, backend): + """Run Custom_MHA_FP8 with FP8 FusedAttention backend. Both input and output + are in F16. 
QKV GEMM, DPA, and projection GEMM are calculated in FP8.""" reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" @@ -993,13 +1358,14 @@ def _run_dpa_fp8(dtype, config, backend): if backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - inp = 0.01 * torch.randn( - config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim, + inp = 0.0001 * torch.randint(0, 100, + (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim), dtype=dtype, device="cuda", requires_grad=True) seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda") cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32) cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) + out_grad = 0.01 * torch.randn( config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim, dtype=dtype, device="cuda") @@ -1013,22 +1379,21 @@ def _run_dpa_fp8(dtype, config, backend): amax_compute_algo="most_recent", ) - dpa = DPA_FP8(config).to(dtype=torch.float16, device="cuda") + mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda") with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - out = dpa(inp, cu_seqlens, config.max_seqlen_q) + out = mha(inp, cu_seqlens, config.max_seqlen_q) out.backward(out_grad) - context = torch.load("ctx.pt") + out = torch.load("out.pt") dqkv = torch.load('dqkv.pt') - return (context.view(config.batch_size, config.max_seqlen_q, -1).transpose(0,1), + return (out.view(config.batch_size, config.max_seqlen_q, -1), dqkv.view(config.batch_size, config.max_seqlen_q, 3, - config.num_heads, config.head_dim).transpose(0,1).contiguous()) + config.num_heads, config.head_dim).contiguous()) -def _run_dpa_fp8_ref(dtype, config, backend): - """Run UnfusedDotProductAttention as a reference, i.e. - plain PyTorch implementation in FP16 and inputs/outputs - are converted from/to FP8""" +def _run_ref_mha_f16(dtype, config, backend): + """Run reference F16 FusedAttention. Both input and output + are in F16. 
QKV GEMM, DPA, and projection GEMM are also in F16.""" os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" @@ -1043,7 +1408,7 @@ def _run_dpa_fp8_ref(dtype, config, backend): cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32) cu_seqlens[1:] = torch.cumsum(seqlens, dim=0) out_grad = torch.load('out_grad.pt').to(device="cuda").view( - config.batch_size, config.max_seqlen_q, -1).transpose(0,1) + config.batch_size, config.max_seqlen_q, -1) _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -1069,13 +1434,14 @@ def get_dummy_cuda_rng_tracker(): get_rng_state_tracker=get_dummy_cuda_rng_tracker, tp_group=None, layer_number=1, - attention_type="self" + attention_type="self", + qkv_format="bshd", ).to(dtype=dtype, device="cuda") ) - q = inp[:, :,0,:,:] - k = inp[:, :,1,:,:] - v = inp[:, :,2,:,:] + q = inp[:,:,0,:,:] + k = inp[:,:,1,:,:] + v = inp[:,:,2,:,:] out = block(q, k, v, attn_mask_type=config.attn_mask_type) out.backward(out_grad) @@ -1088,14 +1454,14 @@ def get_dummy_cuda_rng_tracker(): _2X_ACC_WGRAD = False META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT +META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 META_O = tex.FP8FwdTensors.GEMM2_INPUT META_DO = tex.FP8BwdTensors.GRAD_INPUT2 -META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 +META_S = tex.FP8FwdTensors.GEMM3_OUTPUT +META_DP = tex.FP8BwdTensors.GRAD_INPUT3 -META_S = tex.FP8FwdTensors.GEMM3_WEIGHT -META_DS = tex.FP8BwdTensors.GRAD_INPUT3 -class _dpa_fp8(torch.autograd.Function): +class _custom_mha_fp8(torch.autograd.Function): @staticmethod def forward( ctx, @@ -1110,6 +1476,7 @@ def forward( fp8_meta: Dict[str, Any], workspace: torch.Tensor, is_training: bool, + mask_type: str, ) -> torch.Tensor: assert inp.dim() == 2 @@ -1117,14 +1484,10 @@ def forward( h = num_heads d = in_features // h b = cu_seqlens.numel() - 1 - is_nl = False - if b < 4 and b > 1: - max_s = 512 - is_nl = True fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - inputmat, inputmat_t = ext.fp8_cast_transpose_fused( + inp_fp8, inp_t_fp8 = ext.fp8_cast_transpose_fused( inp, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, @@ -1142,12 +1505,12 @@ def forward( ZInv = None philox_unpacked = None - qkv_out, _ = ext.fp8_gemm( + qkv, _ = ext.fp8_gemm( qkv_weight_fp8, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, - inputmat, + inp_fp8, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, @@ -1160,26 +1523,29 @@ def forward( use_split_accumulator=_2X_ACC_FPROP, D_dtype=fp8_dtype_forward, ) - qkv_out = qkv_out.view(-1, 3, h, d) - qkv_out_fp16 = ext.cast_from_fp8(qkv_out, fp8_meta["scaling_fwd"], + qkv = qkv.view(-1, 3, h, d) + qkv_fp16 = ext.cast_from_fp8(qkv, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, - tex.DType.kFloat16).view(b, max_s, 3, h, d).transpose(0,1).contiguous() - torch.save(qkv_out_fp16, 'qkv.pt') + tex.DType.kFloat16).view(b, max_s, 3, h, d).contiguous() + torch.save(qkv_fp16, 'qkv.pt') + if cudnn_frontend_version == 1: + qkv = qkv.view(b, max_s, 3, h, d) # bs3hd # FMHA - context_, aux_ctx_tensors, *rest = fused_attn_fwd( + out, aux_ctx_tensors, *rest = fused_attn_fwd( is_training, max_s, max_s, cu_seqlens, cu_seqlens, - qkv_out[:,0,:,:], - qkv_out[:,1,:,:], - qkv_out[:,2,:,:], + qkv[:,:,0,:,:] if cudnn_frontend_version == 1 else qkv[:,0,:,:], + qkv[:,:,1,:,:] if cudnn_frontend_version == 1 else qkv[:,1,:,:], + qkv[:,:,2,:,:] if 
cudnn_frontend_version == 1 else qkv[:,2,:,:], fp8_dtype_forward, FusedAttnBackend["FP8"], None, fp8_meta["scaling_fwd"].scale_inv[META_QKV], + fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale[META_S], fp8_meta["scaling_fwd"].scale[META_O], fp8_meta["scaling_fwd"].amax_history[0][META_S], @@ -1187,20 +1553,17 @@ def forward( attn_scale=None, dropout=p_dropout, fast_zero_fill=fast_zero_fill, - qkv_layout="t3hd", + qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd", attn_bias_type="no_bias", - attn_mask_type="padding", + attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding", rng_gen=None, ) - M, ZInv, philox_unpacked = aux_ctx_tensors - context = context_.view(-1, in_features) - context_t = tex.fp8_transpose(context, fp8_dtype_forward) + M, ZInv, philox_unpacked = aux_ctx_tensors ctx.save_for_backward( - inputmat_t, qkv_weight_t_fp8, workspace, - qkv_out, - context_, context_t, + inp_t_fp8, qkv_weight_t_fp8, workspace, + qkv, out, fp8_meta["scaling_fwd"].scale, fp8_meta["scaling_fwd"].scale_inv, ) @@ -1210,14 +1573,16 @@ def forward( ctx.p_dropout = p_dropout ctx.max_s = max_s ctx.fast_zero_fill = fast_zero_fill - ctx.is_nl = is_nl ctx.hidden_size = in_features ctx.num_heads = num_heads + ctx.mask_type = mask_type + ctx.dtype = inp.dtype - context_fp16 = ext.cast_from_fp8(context, fp8_meta["scaling_fwd"], + out = out.view(-1, in_features) # (bs)(hd) + out_fp16 = ext.cast_from_fp8(out, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward, tex.DType.kFloat16) - torch.save(context_fp16, 'ctx.pt') - return context_fp16 + torch.save(out_fp16, 'out.pt') # (bs)(hd) + return out_fp16 @staticmethod @@ -1226,11 +1591,10 @@ def backward( ) -> Tuple[Union[torch.Tensor, None], ...]: with torch.cuda.nvtx.range("_DPA"): ( - inputmat_t, + inp_t_fp8, qkv_weight_t_fp8, workspace, - qkv_out, - context, context_t, + qkv, out, fwd_scales, fwd_scale_inverses, ) = ctx.saved_tensors @@ -1243,51 +1607,59 @@ def backward( proj_dgrad = ext.cast_to_fp8( grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ) + ) # (bs)(hd) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_s, ctx.max_s, ctx.cu_seqlens, ctx.cu_seqlens, - qkv_out[:,0,:,:], - qkv_out[:,1,:,:], - qkv_out[:,2,:,:], - context, - proj_dgrad.view_as(context), + qkv[:,:,0,:,:] if cudnn_frontend_version == 1 else qkv[:,0,:,:], + qkv[:,:,1,:,:] if cudnn_frontend_version == 1 else qkv[:,1,:,:], + qkv[:,:,2,:,:] if cudnn_frontend_version == 1 else qkv[:,2,:,:], + out, + proj_dgrad.view_as(out), fp8_dtype_forward, + fp8_dtype_backward, ctx.aux_ctx_tensors, FusedAttnBackend["FP8"], fwd_scale_inverses[META_QKV], # d_scale_qkv, fwd_scale_inverses[META_S], # d_scale_s, fwd_scale_inverses[META_O], # d_scale_o, ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp fwd_scales[META_S], # q_scale_s - ctx.fp8_meta['scaling_bwd'].scale[META_DS], # q_scale_ds + ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DS], # amax_ds + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv - None, - ctx.p_dropout, - ctx.fast_zero_fill, - "t3hd", - "no_bias", - "padding", + attn_scale=None, + dropout=ctx.p_dropout, + fast_zero_fill=ctx.fast_zero_fill, + qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd", + attn_bias_type="no_bias", + 
attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding", ) - dqkv = torch.cat([dq.unsqueeze(1), dk.unsqueeze(1), dv.unsqueeze(1)], dim=1) - - dqkv_grad_output_c = dqkv.view(-1, 3*ctx.hidden_size) - dqkv_grad_output_c_fp16 = ext.cast_from_fp8(dqkv_grad_output_c, + dim = 2 if cudnn_frontend_version == 1 else 1 + dqkv = torch.Tensor().to(device=dq.device, dtype=dq.dtype) + dqkv_shape = list(dq.shape) + dqkv_shape.insert(dim, 3) + dqkv_stride = list(dq.stride()) + dqkv_stride.insert(dim, int(dqkv_stride[-3]/3)) + dqkv.set_(dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride) # bs3hd + + dqkv_c = dqkv.view(-1, 3*ctx.hidden_size) + dqkv_c_fp16 = ext.cast_from_fp8(dqkv_c, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward, tex.DType.kFloat16) - torch.save(dqkv_grad_output_c_fp16, 'dqkv.pt') + torch.save(dqkv_c_fp16, 'dqkv.pt') - qkv_bgrad, dqkv_grad_output_t = ext.fp8_transpose_bgrad_fused( - dqkv_grad_output_c, + qkv_bgrad, dqkv_t = ext.fp8_transpose_bgrad_fused( + dqkv_c, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward, - torch.float16, + ctx.dtype, ) # QKV DGRAD @@ -1296,25 +1668,25 @@ def backward( fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, - dqkv_grad_output_c, + dqkv_c, ctx.fp8_meta["scaling_bwd"].scale_inv, META_DQKV, fp8_dtype_backward, - torch.float16, + ctx.dtype, workspace, use_split_accumulator=_2X_ACC_DGRAD, ) # QKV WGRAD qkv_wgrad, _ = ext.fp8_gemm( - inputmat_t, + inp_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, - dqkv_grad_output_t, + dqkv_t, ctx.fp8_meta["scaling_bwd"].scale_inv, META_DQKV, fp8_dtype_backward, - torch.float16, + ctx.dtype, workspace, use_split_accumulator=_2X_ACC_WGRAD, ) @@ -1334,7 +1706,7 @@ def backward( None) -class DPA_FP8(TransformerEngineBaseModule): +class Custom_MHA_FP8(TransformerEngineBaseModule): def __init__( self, config, @@ -1345,6 +1717,7 @@ def __init__( self.hidden_size = config.hidden_size self.head_dim = config.head_dim self.fast_zero_fill = True + self.mask_type = config.attn_mask_type self.qkv_weight = torch.nn.Parameter( torch.empty( @@ -1374,7 +1747,7 @@ def forward( cu_seqlens, max_s, ) -> torch.Tensor: with self.prepare_forward(inp, None, num_gemms=3) as inp: - out = _dpa_fp8.apply( + out = _custom_mha_fp8.apply( inp, self.qkv_weight, self.qkv_bias, @@ -1385,7 +1758,8 @@ def forward( self.fast_zero_fill, self.fp8_meta, self.workspace, - self.training) + self.training, + self.mask_type) return out def get_fp8_weights_scratchpad( diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index ddb3ecf49f..0cda82e0c4 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1091,7 +1091,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config) # Check output. 
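# Sketch (not part of the patch) of the dqkv re-interleaving trick used in
# _custom_mha_fp8.backward above: dq/dk/dv returned by fused_attn_bwd are assumed to
# be strided views into a single interleaved (b, s, 3, h, d) buffer, so the combined
# gradient can be rebuilt by re-inserting the "3" axis into dq's shape and stride.
import torch

b, s, h, d = 2, 4, 6, 8
dqkv_buf = torch.randn(b, s, 3, h, d)              # interleaved bs3hd buffer
dq = dqkv_buf[:, :, 0, :, :]                       # views sharing dqkv_buf's storage
dk = dqkv_buf[:, :, 1, :, :]
dv = dqkv_buf[:, :, 2, :, :]

dim = 2                                            # position of the "3" axis (bs3hd)
dqkv_shape = list(dq.shape)
dqkv_shape.insert(dim, 3)
dqkv_stride = list(dq.stride())
dqkv_stride.insert(dim, dqkv_stride[-3] // 3)      # stride of one q/k/v slice
dqkv = torch.empty(0, dtype=dq.dtype).set_(
    dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride)

assert torch.equal(dqkv, dqkv_buf)                 # same storage, original layout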
- atol = {torch.float32 : 2e-4, + atol = {torch.float32 : 2.5e-4, torch.half : 2e-3, torch.bfloat16: 2e-2, } diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 43e7d17350..2d9759898f 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -85,15 +85,25 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); auto cudnn_runtime_version = cudnnGetVersion(); - if ((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2) - && (sm_arch_ >= 90) - && (max_seqlen_q == max_seqlen_kv) - && (num_attn_heads == num_gqa_groups) - && (max_seqlen_q <= 512) - && (head_dim == 64) - && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) - && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) - && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD)) { + if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) + || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) + && (sm_arch_ >= 90) + && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) + && ( + ((cudnn_runtime_version >= 8900) + && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) + && (max_seqlen_q == max_seqlen_kv) + && (max_seqlen_q <= 512) + && (head_dim == 64) + && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) + || ((cudnn_runtime_version >= 90100) + && (max_seqlen_q % 128 == 0) + && (max_seqlen_kv % 128 == 0) + && (head_dim == 128) + && ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) + || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) + && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) + || (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else { @@ -269,7 +279,7 @@ void nvte_fused_attn_fwd_qkvpacked( #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd_qkvpacked( b, h, max_seqlen, d, - is_training, attn_scale, dropout, qkv_layout, + is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, @@ -379,7 +389,7 @@ void nvte_fused_attn_bwd_qkvpacked( const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); fused_attn_fp8_bwd_qkvpacked( b, h, max_seqlen, d, - attn_scale, dropout, qkv_layout, + attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, @@ -476,7 +486,18 @@ void nvte_fused_attn_fwd_kvpacked( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { - NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n"); +#if (CUDNN_VERSION >= 8900) + fused_attn_fp8_fwd_kvpacked( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, + is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + input_Q, input_KV, input_output_S, output_O, + Aux_CTX_Tensors, + input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); +#endif } else { NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. 
\n"); } @@ -580,7 +601,23 @@ void nvte_fused_attn_bwd_kvpacked( NVTE_ERROR(err_msg); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { - NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n"); +#if (CUDNN_VERSION >= 8900) + const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + fused_attn_fp8_bwd_kvpacked( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, + attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + input_Q, input_KV, input_O, input_dO, + input_M, input_ZInv, + input_S, input_output_dP, + output_dQ, output_dKV, + input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, + wkspace, stream, handle); +#else + NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); +#endif } else { NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } @@ -662,8 +699,8 @@ void nvte_fused_attn_fwd( } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) fused_attn_fp8_fwd( - b, h_q, max_seqlen_q, max_seqlen_kv, d, - is_training, attn_scale, dropout, qkv_layout, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, + is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, @@ -775,8 +812,8 @@ void nvte_fused_attn_bwd( const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); fused_attn_fp8_bwd( - b, h_q, max_seqlen_q, max_seqlen_kv, d, - attn_scale, dropout, qkv_layout, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, + attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 8ffd8608b6..180759f327 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -76,7 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( scaling_factor, is_training, dropout_probability, layout, bias_type, mask_type, - tensorType}; + tensorType, tensorType}; namespace fe = cudnn_frontend; using graph_and_tensors = std::tuple, @@ -147,7 +147,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( fe::graph::SDPA_attributes sdpa_options; sdpa_options = fe::graph::SDPA_attributes() .set_name("flash_attention") - .set_is_inference(!is_training) + .set_is_inference(false) .set_causal_mask(is_causal) .set_attn_scale(attn_scale); @@ -199,11 +199,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( layout, NVTE_QKV_Matrix::NVTE_O_Matrix); O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); - if (is_training) { - Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}); - } + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}); std::tuple, // Q std::shared_ptr, // K @@ -211,7 +209,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // attn_scale std::shared_ptr > // O key_tensors_tuple 
= std::make_tuple(Q, K, V, attn_scale, O); - auto Stats_tuple = is_training ? std::make_tuple(Stats) : std::make_tuple(nullptr); + auto Stats_tuple = std::make_tuple(Stats); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); @@ -258,11 +256,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( {K, devPtrK}, {V, devPtrV}, {attn_scale, &scaling_factor}, - {O, devPtrO}}; - - if (is_training) { - variant_pack[Stats] = devPtrSoftmaxStats; - } + {O, devPtrO}, + {Stats, devPtrSoftmaxStats}}; if (is_bias) { variant_pack[bias] = devPtrBias; @@ -321,7 +316,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( scaling_factor, true, dropout_probability, layout, bias_type, mask_type, - tensorType}; + tensorType, tensorType}; namespace fe = cudnn_frontend; using graph_and_tensors = std::tuple, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 55a5638b26..a8866908ce 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -19,7 +19,7 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_size, bool is_training, float attn_scale, + size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 76c1a44b0d..66185c0c41 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -8,6 +8,7 @@ #include "../common.h" #include "utils.h" +#include "../util/system.h" #include "fused_attn_fp8.h" namespace transformer_engine { @@ -984,7 +985,7 @@ static cudnn_frontend::Tensor createdSQBMM( return After_dSTranspose_Q; } -// fused attention FWD FP8 +// fused attention FWD FP8 with FE 0.9 void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, bool isTraining, float attnScale, float dropoutProbability, NVTE_QKV_Layout layout, @@ -1295,7 +1296,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in } } -// fused attention BWD FP8 +// fused attention BWD FP8 with FE 0.9 void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, float attnScale, float dropoutProbability, NVTE_QKV_Layout layout, void* devPtrQ, void* devPtrK, void* devPtrV, @@ -1846,6 +1847,707 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in } } +// fused attention FWD FP8 with FE 1.0+ +void fused_attn_fp8_fwd_impl_v1(int64_t b, int64_t h, int64_t hg, + int64_t s_q, int64_t s_kv, int64_t d, + bool is_training, float scaling_factor, + float dropout_probability, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrM, void* devPtrZInv, + void* devPtrO, + void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, + void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, + void* devPtrAmaxO, void* devPtrAmaxS, + void* devPtrcuSeqlensQ, void* 
devPtrcuSeqlensKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnn_frontend::DataType_t fwd_tensor_type, + void* workspace, + size_t* workspace_size, + cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; + bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); + bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) + || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) + || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + bool is_dropout = (is_training && dropout_probability != 0.0f); + auto bias_b = b; + auto bias_h = h; + NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); + NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); + NVTE_CHECK(~is_padding, + "FP8 fused attention does not support padding/padding_causal mask yet!"); + NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!"); + + try { + FADescriptor_v1 descriptor{b, h, + hg, s_q, + s_kv, d, + bias_b, bias_h, + scaling_factor, is_training, + dropout_probability, layout, + bias_type, mask_type, + fwd_tensor_type, fwd_tensor_type}; + + namespace fe = cudnn_frontend; + using graph_and_tensors = std::tuple, + std::shared_ptr, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // descale_q + std::shared_ptr, // descale_k + std::shared_ptr, // descale_v + std::shared_ptr, // descale_s + std::shared_ptr, // scale_s + std::shared_ptr, // scale_o + std::shared_ptr, // attn_scale + std::shared_ptr, // O + std::shared_ptr, // amax_s + std::shared_ptr, // amax_o + std::shared_ptr, // Stats + std::shared_ptr, // bias + std::shared_ptr, // seq_q + std::shared_ptr, // seq_kv + std::shared_ptr, // dropout_seed + std::shared_ptr >; // dropout_offset + + using CacheType = std::map; + static thread_local CacheType sdpa_fp8_fprop_cache; + + // Get plan from cache if cache is available, otherwise create one + auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor) + -> graph_and_tensors { + // if hit, return + auto it = cache.find(descriptor); + if (it != cache.end()) { + auto graph = it->second; + return graph; + } + + // otherwise, build the op_graph and the plan. 
Then update cache + auto mha_graph = std::make_shared(); + mha_graph->set_io_data_type(fwd_tensor_type) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + std::shared_ptr Q, K, V, attn_scale; + std::shared_ptr descale_q, descale_k, descale_v; + std::shared_ptr descale_s, scale_s, scale_o; + std::shared_ptr bias, seq_q, seq_kv; + std::shared_ptr dropout_seed, dropout_offset; + + std::vector q_stride(4); + std::vector k_stride(4); + std::vector v_stride(4); + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d}) + .set_stride(q_stride)); + K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride)); + V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride)); + + attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); + scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); + + fe::graph::SDPA_fp8_attributes sdpa_options; + sdpa_options = fe::graph::SDPA_fp8_attributes() + .set_name("sdpa_fp8") + .set_is_inference(false) + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale); + + // sdpa_options.set_alibi_mask(is_alibi); + // if (is_bias) { + // bias = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("bias") + // .set_dim({bias_b, bias_h, s_q, s_kv}) + // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // sdpa_options.set_bias(bias); + // } + + // if (is_padding) { + // seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("seq_q") + // .set_dim({b, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT32)); + // seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("seq_kv") + // .set_dim({b, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT32)); + // sdpa_options.set_padding_mask(is_padding) + // .set_seq_len_q(seq_q) + // .set_seq_len_kv(seq_kv); + // } + + // if (is_dropout) { + // dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("Seed") + // .set_dim({1, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT64)); + // dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("Offset") + // .set_dim({1, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT64)); + // sdpa_options.set_dropout( + // dropout_probability, dropout_seed, dropout_offset); + // } + + auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8( + Q, K, V, descale_q, descale_k, 
descale_v, descale_s, + scale_s, scale_o, sdpa_options); + + std::vector o_stride(4); + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); + amax_o->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_s->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}); + + std::tuple, // Q + std::shared_ptr, // K + std::shared_ptr, // V + std::shared_ptr, // descale_q + std::shared_ptr, // descale_k + std::shared_ptr, // descale_v + std::shared_ptr, // descale_s + std::shared_ptr, // scale_s + std::shared_ptr, // scale_o + std::shared_ptr, // attn_scale + std::shared_ptr, // O + std::shared_ptr, // amax_s + std::shared_ptr > // amax_o + key_tensors_tuple = std::make_tuple(Q, K, V, descale_q, descale_k, descale_v, + descale_s, scale_s, scale_o, attn_scale, O, amax_s, amax_o); + auto Stats_tuple = std::make_tuple(Stats); + auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); + auto padding_tuple = is_padding ? + std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); + auto dropout_tuple = is_dropout ? + std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); + + NVTE_CHECK_CUDNN_FE(mha_graph->validate()); + NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); + NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); + NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); + NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); + + auto return_tuple = std::tuple_cat( + std::make_tuple(mha_graph), key_tensors_tuple, + Stats_tuple, bias_tuple, padding_tuple, dropout_tuple); + cache.insert({descriptor, return_tuple}); + + return return_tuple; + }; + + auto [mha_graph, Q, K, V, descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, attn_scale, O, amax_s, amax_o, Stats, + bias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph( + sdpa_fp8_fprop_cache, descriptor); + + auto plan_workspace_size = mha_graph->get_workspace_size(); + + // Exit to request upper level API to allocate memory if needed + size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); + if (workspace == nullptr) { + *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; + return; + } + + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. 
+ NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + + // Build variant pack + std::unordered_map, void*> variant_pack = { + {Q, devPtrQ}, + {K, devPtrK}, + {V, devPtrV}, + {descale_q, devPtrDescaleQ}, + {descale_k, devPtrDescaleK}, + {descale_v, devPtrDescaleV}, + {descale_s, devPtrDescaleS}, + {scale_s, devPtrScaleS}, + {scale_o, devPtrScaleO}, + {attn_scale, &scaling_factor}, + {O, devPtrO}, + {amax_s, devPtrAmaxS}, + {amax_o, devPtrAmaxO}, + {Stats, devPtrM}}; + + // if (is_bias) { + // variant_pack[bias] = devPtrBias; + // } + + // if (is_padding) { + // constexpr size_t nthreads_per_block = 128; + // const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + // void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + // void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + // + b * sizeof(int32_t); + // cu_seqlens_to_actual_seqlens<<>>( + // b, static_cast(devPtrCuSeqlensQ), + // static_cast(devPtrCuSeqlensKV), + // static_cast(devActualSeqlenQ), + // static_cast(devActualSeqlenKV)); + // variant_pack[seq_q] = devActualSeqlenQ; + // variant_pack[seq_kv] = devActualSeqlenKV; + // } + + // if (is_dropout) { + // variant_pack[dropout_seed] = devPtrDropoutSeed; + // variant_pack[dropout_offset] = devPtrDropoutOffset; + // } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); + } catch (cudnn_frontend::cudnnException &e) { + NVTE_ERROR(e.what()); + } +} + +// fused attention BWD FP8 with FE 1.0+ +void fused_attn_fp8_bwd_impl_v1(int64_t b, int64_t h, int64_t hg, + int64_t s_q, int64_t s_kv, int64_t d, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + void* devPtrQ, void* devPtrK, void* devPtrV, + void* devPtrM, void* devPtrZInv, + void* devPtrO, void* devPtrdO, + void* devPtrdQ, void* devPtrdK, void* devPtrdV, + void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, + void* devPtrDescaleO, void* devPtrDescaledO, + void* devPtrDescaleS, void* devPtrDescaledP, + void* devPtrScaleS, void* devPtrScaledP, + void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, + void* devPtrAmaxdP, + void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, + void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnn_frontend::DataType_t fwd_tensor_type, + cudnn_frontend::DataType_t bwd_tensor_type, + void* workspace, + size_t* workspace_size, + cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; + bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); + bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) + || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) + || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + bool is_dropout = (dropout_probability != 0.0f); + auto bias_b = b; + auto bias_h = h; + NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); + NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); + NVTE_CHECK(~is_padding, + "FP8 fused attention does not support padding/padding_causal mask yet!"); + NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!"); + + try { + FADescriptor_v1 descriptor{b, h, + hg, s_q, + s_kv, d, + bias_b, bias_h, + scaling_factor, true, + dropout_probability, layout, + bias_type, 
mask_type, + fwd_tensor_type, bwd_tensor_type}; + + namespace fe = cudnn_frontend; + using graph_and_tensors = std::tuple, + std::shared_ptr, // q + std::shared_ptr, // k + std::shared_ptr, // v + std::shared_ptr, // o + std::shared_ptr, // stats + std::shared_ptr, // dO + std::shared_ptr, // attn_scale + std::shared_ptr, // descale_q + std::shared_ptr, // descale_k + std::shared_ptr, // descale_v + std::shared_ptr, // descale_o + std::shared_ptr, // descale_dO + std::shared_ptr, // descale_s + std::shared_ptr, // descale_dP + std::shared_ptr, // scale_dQ + std::shared_ptr, // scale_dK + std::shared_ptr, // scale_dV + std::shared_ptr, // scale_s + std::shared_ptr, // scale_dP + std::shared_ptr, // dQ + std::shared_ptr, // dK + std::shared_ptr, // dV + std::shared_ptr, // amax_dQ + std::shared_ptr, // amax_dK + std::shared_ptr, // amax_dV + std::shared_ptr, // amax_dP + std::shared_ptr, // bias + std::shared_ptr, // dBias + std::shared_ptr, // seq_q + std::shared_ptr, // seq_kv + std::shared_ptr, // dropout_seed + std::shared_ptr >; // dropout_offset + + using CacheType = std::map; + static thread_local CacheType sdpa_fp8_bprop_cache; + + // Get plan from cache if cache is available, otherwise create one + auto get_graph = [&](CacheType &cache, const FADescriptor_v1 &descriptor) + -> graph_and_tensors { + // if hit, return + auto it = cache.find(descriptor); + if (it != cache.end()) { + auto graph = it->second; + return graph; + } + + // otherwise, build the op_graph and the plan. Then update cache + auto mha_graph = std::make_shared(); + + mha_graph->set_io_data_type(fwd_tensor_type) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + std::shared_ptr q, k, v, o, dO, stats, attn_scale; + std::shared_ptr descale_q, descale_k, descale_v; + std::shared_ptr descale_s, descale_o; + std::shared_ptr descale_dP, descale_dO; + std::shared_ptr scale_s, scale_dP; + std::shared_ptr scale_dQ, scale_dK, scale_dV; + std::shared_ptr bias, dBias, seq_q, seq_kv; + std::shared_ptr dropout_seed, dropout_offset; + + std::vector q_stride(4); + std::vector k_stride(4); + std::vector v_stride(4); + std::vector o_stride(4); + generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_O_Matrix); + q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d}) + .set_stride(q_stride)); + k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride)); + v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride)); + o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim({b, h, s_q, d}) + .set_stride(o_stride)); + dO = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO") + .set_dim({b, h, s_q, d}) + .set_stride(o_stride)); + stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("stats") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + + attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("attn_scale") + .set_dim({1, 1, 1, 1}) + 
.set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + descale_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Descale_q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); + descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); + descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); + descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); + descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); + descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); + scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); + scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); + scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); + scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); + scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + + fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options; + sdpa_backward_options = fe::graph::SDPA_fp8_backward_attributes() + .set_name("sdpa_fp8_backward") + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale); + + // sdpa_backward_options.set_alibi_mask(is_alibi); + + // if (is_bias) { + // bias = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("bias") + // .set_dim({bias_b, bias_h, s_q, s_kv}) + // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // dBias = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("dBias") + // .set_dim({bias_b, bias_h, s_q, s_kv}) + // .set_stride({bias_h * s_q * s_kv, s_q * s_kv, s_kv, 1})); + // sdpa_backward_options.set_bias(bias); + // // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] + // // are not supported for dbias calculation but they are + // // supported for forward bias calculation + // if ((bias_b == 1) && (bias_h == h)) { + // sdpa_backward_options.set_dbias(dBias); + // } + // } + + // if (is_padding) { + // seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("seq_q") + // .set_dim({b, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT32)); + // seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("seq_kv") + // .set_dim({b, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT32)); + // sdpa_backward_options.set_padding_mask(is_padding) + // .set_seq_len_q(seq_q) + // .set_seq_len_kv(seq_kv); + // } + + // if (is_dropout) { + // dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("Seed") + // .set_dim({1, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT64)); + // dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("Offset") + // .set_dim({1, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT64)); + // sdpa_backward_options.set_dropout( + // dropout_probability, dropout_seed, dropout_offset); + // } + + auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( + q, k, v, o, dO, stats, + descale_q, descale_k, descale_v, + descale_o, descale_dO, descale_s, descale_dP, + scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, + sdpa_backward_options); + + dQ->set_output(true) + .set_dim({b, h, s_q, d}) + .set_stride(q_stride); + dK->set_output(true) + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride); + dV->set_output(true) + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride); + amax_dQ->set_output(true) + .set_dim({1, 
1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dK->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dV->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dP->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + + dO->set_data_type(bwd_tensor_type); + dQ->set_data_type(bwd_tensor_type); + dK->set_data_type(bwd_tensor_type); + dV->set_data_type(bwd_tensor_type); + + std::tuple, // q + std::shared_ptr, // k + std::shared_ptr, // v + std::shared_ptr, // o + std::shared_ptr, // stats + std::shared_ptr, // dO + std::shared_ptr, // attn_scale + std::shared_ptr, // descale_q + std::shared_ptr, // descale_k + std::shared_ptr, // descale_v + std::shared_ptr, // descale_o + std::shared_ptr, // descale_dO + std::shared_ptr, // descale_s + std::shared_ptr, // descale_dP + std::shared_ptr, // scale_dQ + std::shared_ptr, // scale_dK + std::shared_ptr, // scale_dV + std::shared_ptr, // scale_s + std::shared_ptr, // scale_dP + std::shared_ptr, // dQ + std::shared_ptr, // dK + std::shared_ptr, // dV + std::shared_ptr, // amax_dQ + std::shared_ptr, // amax_dK + std::shared_ptr, // amax_dV + std::shared_ptr > // amax_dP + key_tensors_tuple = std::make_tuple( + q, k, v, o, stats, dO, attn_scale, + descale_q, descale_k, descale_v, + descale_o, descale_dO, descale_s, descale_dP, + scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, + dQ, dK, dV, + amax_dQ, amax_dK, amax_dV, amax_dP); + auto bias_tuple = is_bias ? + std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); + auto padding_tuple = is_padding ? + std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); + auto dropout_tuple = is_dropout ? + std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); + + NVTE_CHECK_CUDNN_FE(mha_graph->validate()); + NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); + NVTE_CHECK_CUDNN_FE(mha_graph->create_execution_plans({fe::HeurMode_t::A})); + NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); + NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); + + auto return_tuple = std::tuple_cat( + std::make_tuple(mha_graph), key_tensors_tuple, + bias_tuple, padding_tuple, dropout_tuple); + cache.insert({descriptor, return_tuple}); + + return return_tuple; + }; + + auto [mha_graph, q, k, v, o, stats, dO, attn_scale, + descale_q, descale_k, descale_v, + descale_o, descale_dO, descale_s, descale_dP, + scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, + dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, + bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph( + sdpa_fp8_bprop_cache, descriptor); + + auto plan_workspace_size = mha_graph->get_workspace_size(); + + // Exit to request upper level API to allocate memory if needed + size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); + if (workspace == nullptr) { + *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; + return; + } + + // cuDNN stream check needs to be moved here to support dummy kernel calls with + // null streams for sizing the cuDNN workspace. 
+ NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); + + // build variant pack + std::unordered_map, void*> variant_pack = { + {q, devPtrQ}, + {k, devPtrK}, + {v, devPtrV}, + {o, devPtrO}, + {stats, devPtrM}, + {dO, devPtrdO}, + {attn_scale, &scaling_factor}, + {descale_q, devPtrDescaleQ}, + {descale_k, devPtrDescaleK}, + {descale_v, devPtrDescaleV}, + {descale_o, devPtrDescaleO}, + {descale_dO, devPtrDescaledO}, + {descale_s, devPtrDescaleS}, + {descale_dP, devPtrDescaledP}, + {scale_s, devPtrScaleS}, + {scale_dQ, devPtrScaledQ}, + {scale_dK, devPtrScaledK}, + {scale_dV, devPtrScaledV}, + {scale_dP, devPtrScaledP}, + {dQ, devPtrdQ}, + {dK, devPtrdK}, + {dV, devPtrdV}, + {amax_dQ, devPtrAmaxdQ}, + {amax_dK, devPtrAmaxdK}, + {amax_dV, devPtrAmaxdV}, + {amax_dP, devPtrAmaxdP}, + }; + + // if (is_bias) { + // variant_pack[bias] = devPtrBias; + // if ((bias_b == 1) && (bias_h == h)) { + // variant_pack[dBias] = devPtrdBias; + // } else { + // variant_pack[dBias] = nullptr; + // } + // } + + // if (is_padding) { + // constexpr size_t nthreads_per_block = 128; + // const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + // void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + // void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + // + b * sizeof(int32_t); + // cu_seqlens_to_actual_seqlens<<>>( + // b, static_cast(devPtrCuSeqlensQ), + // static_cast(devPtrCuSeqlensKV), + // static_cast(devActualSeqlenQ), + // static_cast(devActualSeqlenKV)); + // variant_pack[seq_q] = devActualSeqlenQ; + // variant_pack[seq_kv] = devActualSeqlenKV; + // } + + // if (is_dropout) { + // variant_pack[dropout_seed] = devPtrDropoutSeed; + // variant_pack[dropout_offset] = devPtrDropoutOffset; + // } + + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); + } catch (cudnn_frontend::cudnnException &e) { + NVTE_ERROR(e.what()); + } +} + #endif } // namespace fused_attn @@ -1853,9 +2555,10 @@ void fused_attn_fp8_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with packed QKV void fused_attn_fp8_fwd_qkvpacked( - size_t b, size_t h, size_t max_seqlen, size_t d, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O, @@ -1866,11 +2569,18 @@ void fused_attn_fp8_fwd_qkvpacked( cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - // QKV shape is [total_seqs, 3, h, d] + const DType QKV_type = input_QKV->data.dtype; void* devPtrQKV = input_QKV->data.dptr; - void* devPtrQ = reinterpret_cast(devPtrQKV); - void* devPtrK = reinterpret_cast(reinterpret_cast(devPtrQKV) + h * d); - void* devPtrV = reinterpret_cast(reinterpret_cast(devPtrQKV) + 2 * h * d); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + stride = typeToSize(QKV_type) * num_attn_heads * head_dim; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + stride = typeToSize(QKV_type) * head_dim; + } + void *devPtrQ = static_cast(devPtrQKV); + void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); + void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); void* devPtrDescaleQ = input_QKV->scale_inv.dptr; void* devPtrDescaleK = input_QKV->scale_inv.dptr; void* 
devPtrDescaleV = input_QKV->scale_inv.dptr; @@ -1882,21 +2592,19 @@ void fused_attn_fp8_fwd_qkvpacked( void* devPtrM = nullptr; void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { - if (is_training) { - Aux_CTX_Tensors->size = 3; - Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - output_M->data.dptr = nullptr; - output_M->data.shape = {b, h, max_seqlen, 1}; - output_M->data.dtype = DType::kFloat32; - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {b, h, max_seqlen, 1}; - output_ZInv->data.dtype = DType::kFloat32; - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } + Aux_CTX_Tensors->size = 3; + Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + output_M->data.dptr = nullptr; + output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + output_M->data.dtype = DType::kFloat32; + output_ZInv->data.dptr = nullptr; + output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + output_ZInv->data.dtype = DType::kFloat32; + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; } else if (Aux_CTX_Tensors->size == 3) { Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); @@ -1919,11 +2627,27 @@ void fused_attn_fp8_fwd_qkvpacked( void* devPtrDropoutOffset = reinterpret_cast( reinterpret_cast(rng_state->data.dptr) + 1); - const DType QKV_type = input_QKV->data.dtype; size_t workspace_size = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) + || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + fused_attn::fused_attn_fp8_fwd_impl_v1( + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, + is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleS, devPtrScaleS, devPtrScaleO, + devPtrAmaxO, devPtrAmaxS, + devPtrcuSeqlens, devPtrcuSeqlens, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( - b, h, max_seqlen, max_seqlen, d, + batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, @@ -1935,6 +2659,9 @@ void fused_attn_fp8_fwd_qkvpacked( devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + } else { + NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. 
\n"); + } if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1950,8 +2677,9 @@ void fused_attn_fp8_fwd_qkvpacked( } // fused attention BWD FP8 with packed QKV void fused_attn_fp8_bwd_qkvpacked( - size_t b, size_t h, size_t max_seqlen, size_t d, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, @@ -1966,11 +2694,19 @@ void fused_attn_fp8_bwd_qkvpacked( cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - // QKV shape is [total_seqs, 3, h, d] + const DType QKV_type = input_QKV->data.dtype; + const DType dQKV_type = output_dQKV->data.dtype; void* devPtrQKV = input_QKV->data.dptr; - void* devPtrQ = reinterpret_cast(devPtrQKV); - void* devPtrK = reinterpret_cast(reinterpret_cast(devPtrQKV) + h * d); - void* devPtrV = reinterpret_cast(reinterpret_cast(devPtrQKV) + 2 * h * d); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + stride = typeToSize(QKV_type) * num_attn_heads * head_dim; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + stride = typeToSize(QKV_type) * head_dim; + } + void *devPtrQ = devPtrQKV; + void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); + void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); void* devPtrDescaleQ = input_QKV->scale_inv.dptr; void* devPtrDescaleK = input_QKV->scale_inv.dptr; void* devPtrDescaleV = input_QKV->scale_inv.dptr; @@ -1985,15 +2721,14 @@ void fused_attn_fp8_bwd_qkvpacked( void* devPtrScaleS = input_S->scale.dptr; void* devPtrDescaleS = input_S->scale_inv.dptr; - void* devPtrAmaxdS = input_output_dP->amax.dptr; - void* devPtrScaledS = input_output_dP->scale.dptr; - void* devPtrDescaledS = input_output_dP->scale_inv.dptr; - - // dQKV shape is [total_seqs, 3, h, d] - void* devPtrdQKV = output_dQKV->data.dptr; - void* devPtrdQ = reinterpret_cast(devPtrdQKV); - void* devPtrdK = reinterpret_cast(reinterpret_cast(devPtrdQKV) + h * d); - void* devPtrdV = reinterpret_cast(reinterpret_cast(devPtrdQKV) + 2 * h * d); + void* devPtrAmaxdP = input_output_dP->amax.dptr; + void* devPtrScaledP = input_output_dP->scale.dptr; + void* devPtrDescaledP = input_output_dP->scale_inv.dptr; + + void *devPtrdQKV = output_dQKV->data.dptr; + void *devPtrdQ = devPtrdQKV; + void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); + void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); void* devPtrAmaxdQ = output_dQKV->amax.dptr; void* devPtrAmaxdK = output_dQKV->amax.dptr; void* devPtrAmaxdV = output_dQKV->amax.dptr; @@ -2008,11 +2743,33 @@ void fused_attn_fp8_bwd_qkvpacked( void* devPtrDropoutOffset = reinterpret_cast( reinterpret_cast(rng_state->data.dptr) + 1); - const DType QKV_type = input_QKV->data.dtype; size_t workspace_size = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) + || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + fused_attn::fused_attn_fp8_bwd_impl_v1( + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, devPtrdO, + devPtrdQ, devPtrdK, devPtrdV, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleO, devPtrDescaledO, + 
devPtrDescaleS, devPtrDescaledP, + devPtrScaleS, devPtrScaledP, + devPtrScaledQ, devPtrScaledK, devPtrScaledV, + devPtrAmaxdP, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, + devPtrcuSeqlens, devPtrcuSeqlens, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( - b, h, max_seqlen, max_seqlen, d, + batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, @@ -2020,15 +2777,278 @@ void fused_attn_fp8_bwd_qkvpacked( devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, - devPtrDescaleS, devPtrDescaledS, - devPtrScaleS, devPtrScaledS, + devPtrDescaleS, devPtrDescaledP, + devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, - devPtrAmaxdS, + devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + } else { + NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); + } + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = { workspace_size }; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = { 1 }; + workspace->data.dtype = DType::kByte; + return; + } +} +// fused attention FWD FP8 with packed KV +void fused_attn_fp8_fwd_kvpacked( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, + const Tensor *input_KV, + Tensor *input_output_S, + Tensor *output_O, + NVTETensorPack* Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; + const DType QKV_type = input_Q->data.dtype; + void* devPtrQ = input_Q->data.dptr; + void *devPtrKV = input_KV->data.dptr; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + stride = typeToSize(QKV_type) * head_dim; + } + void *devPtrK = devPtrKV; + void *devPtrV = static_cast(static_cast(devPtrKV) + stride); + void* devPtrDescaleQ = input_Q->scale_inv.dptr; + void* devPtrDescaleK = input_KV->scale_inv.dptr; + void* devPtrDescaleV = input_KV->scale_inv.dptr; + + void* devPtrO = output_O->data.dptr; + void* devPtrAmaxO = output_O->amax.dptr; + void* devPtrScaleO = output_O->scale.dptr; + + void* devPtrM = nullptr; + void* devPtrZInv = nullptr; + if (Aux_CTX_Tensors->size == 0) { + Aux_CTX_Tensors->size = 3; + Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + output_M->data.dptr = nullptr; + output_M->data.shape = {batch, num_attn_heads, 
max_seqlen_q, 1}; + output_M->data.dtype = DType::kFloat32; + output_ZInv->data.dptr = nullptr; + output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + output_ZInv->data.dtype = DType::kFloat32; + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + } else if (Aux_CTX_Tensors->size == 3) { + Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + devPtrM = output_M->data.dptr; + devPtrZInv = output_ZInv->data.dptr; + output_rng_state->data.dptr = rng_state->data.dptr; + } else { + NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); + } + + void* devPtrAmaxS = input_output_S->amax.dptr; + void* devPtrScaleS = input_output_S->scale.dptr; + void* devPtrDescaleS = input_output_S->scale_inv.dptr; + + void* devPtrcuSeqlensQ = reinterpret_cast( + reinterpret_cast(cu_seqlens_q->data.dptr)); + void* devPtrcuSeqlensKV = reinterpret_cast( + reinterpret_cast(cu_seqlens_kv->data.dptr)); + void* devPtrDropoutSeed = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr) + 1); + + size_t workspace_size = 0; + + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) + || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + fused_attn::fused_attn_fp8_fwd_impl_v1( + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, + is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleS, devPtrScaleS, devPtrScaleO, + devPtrAmaxO, devPtrAmaxS, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + fused_attn::fused_attn_fp8_fwd_impl( + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, + is_training, attn_scale, p_dropout, qkv_layout, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleS, devPtrScaleS, devPtrScaleO, + devPtrAmaxO, devPtrAmaxS, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + } else { + NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. 
\n"); + } + + if (workspace_size > 0) { + if (workspace->data.dptr == nullptr) { + workspace->data.shape = { workspace_size }; + workspace->data.dtype = DType::kByte; + return; + } + } else if (workspace_size == 0) { + workspace->data.shape = { 1 }; + workspace->data.dtype = DType::kByte; + return; + } +} +// fused attention BWD FP8 with packed KV +void fused_attn_fp8_bwd_kvpacked( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, + const Tensor *input_KV, + const Tensor *input_O, + const Tensor *input_dO, + const Tensor *input_M, + const Tensor *input_ZInv, + const Tensor *input_S, + Tensor *input_output_dP, + const Tensor *output_dQ, + const Tensor *output_dKV, + const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle) { + using namespace transformer_engine; + const DType QKV_type = input_Q->data.dtype; + const DType dQKV_type = output_dQ->data.dtype; + void *devPtrQ = input_Q->data.dptr; + void *devPtrKV = input_KV->data.dptr; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + stride = typeToSize(QKV_type) * head_dim; + } + void *devPtrK = devPtrKV; + void *devPtrV = static_cast(static_cast(devPtrKV) + stride); + void* devPtrDescaleQ = input_Q->scale_inv.dptr; + void* devPtrDescaleK = input_KV->scale_inv.dptr; + void* devPtrDescaleV = input_KV->scale_inv.dptr; + + void* devPtrO = input_O->data.dptr; + void* devPtrDescaleO = input_O->scale_inv.dptr; + void* devPtrdO = input_dO->data.dptr; + void* devPtrDescaledO = input_dO->scale_inv.dptr; + + void* devPtrM = input_M->data.dptr; + void* devPtrZInv = input_ZInv->data.dptr; + + void* devPtrScaleS = input_S->scale.dptr; + void* devPtrDescaleS = input_S->scale_inv.dptr; + void* devPtrAmaxdP = input_output_dP->amax.dptr; + void* devPtrScaledP = input_output_dP->scale.dptr; + void* devPtrDescaledP = input_output_dP->scale_inv.dptr; + + void *devPtrdQ = output_dQ->data.dptr; + void *devPtrdKV = output_dKV->data.dptr; + void *devPtrdK = devPtrdKV; + void *devPtrdV = static_cast(static_cast(devPtrdKV) + stride); + void* devPtrAmaxdQ = output_dQ->amax.dptr; + void* devPtrAmaxdK = output_dKV->amax.dptr; + void* devPtrAmaxdV = output_dKV->amax.dptr; + void* devPtrScaledQ = output_dQ->scale.dptr; + void* devPtrScaledK = output_dKV->scale.dptr; + void* devPtrScaledV = output_dKV->scale.dptr; + + void* devPtrcuSeqlensQ = reinterpret_cast( + reinterpret_cast(cu_seqlens_q->data.dptr)); + void* devPtrcuSeqlensKV = reinterpret_cast( + reinterpret_cast(cu_seqlens_kv->data.dptr)); + void* devPtrDropoutSeed = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr)); + void* devPtrDropoutOffset = reinterpret_cast( + reinterpret_cast(rng_state->data.dptr) + 1); + + size_t workspace_size = 0; + + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) + || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + fused_attn::fused_attn_fp8_bwd_impl_v1( + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, + attn_scale, p_dropout, qkv_layout, 
bias_type, mask_type, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, devPtrdO, + devPtrdQ, devPtrdK, devPtrdV, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleO, devPtrDescaledO, + devPtrDescaleS, devPtrDescaledP, + devPtrScaleS, devPtrScaledP, + devPtrScaledQ, devPtrScaledK, devPtrScaledV, + devPtrAmaxdP, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { + fused_attn::fused_attn_fp8_bwd_impl( + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, + attn_scale, p_dropout, qkv_layout, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, devPtrdO, + devPtrdQ, devPtrdK, devPtrdV, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleO, devPtrDescaledO, + devPtrDescaleS, devPtrDescaledP, + devPtrScaleS, devPtrScaledP, + devPtrScaledQ, devPtrScaledK, devPtrScaledV, + devPtrAmaxdP, + devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + } else { + NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); + } if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -2044,9 +3064,11 @@ void fused_attn_fp8_bwd_qkvpacked( } // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd( - size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, @@ -2074,21 +3096,19 @@ void fused_attn_fp8_fwd( void* devPtrM = nullptr; void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { - if (is_training) { - Aux_CTX_Tensors->size = 3; - Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); - Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - output_M->data.dptr = nullptr; - output_M->data.shape = {b, h, max_seqlen_q, 1}; - output_M->data.dtype = DType::kFloat32; - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {b, h, max_seqlen_q, 1}; - output_ZInv->data.dtype = DType::kFloat32; - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } + Aux_CTX_Tensors->size = 3; + Tensor *output_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); + Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); + output_M->data.dptr = nullptr; + output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + output_M->data.dtype = DType::kFloat32; + output_ZInv->data.dptr = nullptr; + output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + output_ZInv->data.dtype = DType::kFloat32; + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; } else if (Aux_CTX_Tensors->size == 3) { Tensor *output_M = 
reinterpret_cast(Aux_CTX_Tensors->tensors[0]); Tensor *output_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); @@ -2116,8 +3136,25 @@ void fused_attn_fp8_fwd( const DType QKV_type = input_Q->data.dtype; size_t workspace_size = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) + || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + fused_attn::fused_attn_fp8_fwd_impl_v1( + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, + is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleS, devPtrScaleS, devPtrScaleO, + devPtrAmaxO, devPtrAmaxS, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( - b, h, max_seqlen_q, max_seqlen_kv, d, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, @@ -2129,6 +3166,9 @@ void fused_attn_fp8_fwd( devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + } else { + NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); + } if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -2144,8 +3184,10 @@ void fused_attn_fp8_fwd( } // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd( - size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, @@ -2182,9 +3224,9 @@ void fused_attn_fp8_bwd( void* devPtrScaleS = input_S->scale.dptr; void* devPtrDescaleS = input_S->scale_inv.dptr; - void* devPtrAmaxdS = input_output_dP->amax.dptr; - void* devPtrScaledS = input_output_dP->scale.dptr; - void* devPtrDescaledS = input_output_dP->scale_inv.dptr; + void* devPtrAmaxdP = input_output_dP->amax.dptr; + void* devPtrScaledP = input_output_dP->scale.dptr; + void* devPtrDescaledP = input_output_dP->scale_inv.dptr; void* devPtrdQ = output_dQ->data.dptr; void* devPtrdK = output_dK->data.dptr; @@ -2206,10 +3248,34 @@ void fused_attn_fp8_bwd( reinterpret_cast(rng_state->data.dptr) + 1); const DType QKV_type = input_Q->data.dtype; + const DType dQKV_type = output_dQ->data.dtype; size_t workspace_size = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) + || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + fused_attn::fused_attn_fp8_bwd_impl_v1( + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + devPtrQ, devPtrK, devPtrV, + devPtrM, devPtrZInv, + devPtrO, devPtrdO, + devPtrdQ, devPtrdK, devPtrdV, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleO, devPtrDescaledO, + devPtrDescaleS, devPtrDescaledP, + devPtrScaleS, devPtrScaledP, + devPtrScaledQ, devPtrScaledK, devPtrScaledV, + devPtrAmaxdP, + devPtrAmaxdQ, devPtrAmaxdK, 
devPtrAmaxdV, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrDropoutSeed, devPtrDropoutOffset, + get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( - b, h, max_seqlen_q, max_seqlen_kv, d, + batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, @@ -2217,15 +3283,18 @@ void fused_attn_fp8_bwd( devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, - devPtrDescaleS, devPtrDescaledS, - devPtrScaleS, devPtrScaledS, + devPtrDescaleS, devPtrDescaledP, + devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, - devPtrAmaxdS, + devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + } else { + NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); + } if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 3373e0cb3b..3b0ea6c2c2 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -14,9 +14,10 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) // fused attention FWD FP8 with packed QKV void fused_attn_fp8_fwd_qkvpacked( - size_t b, size_t h, size_t max_seqlen, size_t d, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O, @@ -29,8 +30,9 @@ void fused_attn_fp8_fwd_qkvpacked( // fused attention BWD FP8 with packed QKV void fused_attn_fp8_bwd_qkvpacked( - size_t b, size_t h, size_t max_seqlen, size_t d, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, @@ -45,11 +47,55 @@ void fused_attn_fp8_bwd_qkvpacked( cudaStream_t stream, cudnnHandle_t handle); +// fused attention FWD FP8 with packed KV +void fused_attn_fp8_fwd_kvpacked( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, + const Tensor *input_KV, + Tensor *input_output_S, + Tensor *output_O, + NVTETensorPack* Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle); + +// fused attention BWD FP8 with packed KV +void fused_attn_fp8_bwd_kvpacked( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + const Tensor *input_Q, + const Tensor 
*input_KV, + const Tensor *input_O, + const Tensor *input_dO, + const Tensor *input_M, + const Tensor *input_ZInv, + const Tensor *input_S, + Tensor *input_output_dP, + const Tensor *output_dQ, + const Tensor *output_dKV, + const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, + const Tensor *rng_state, + Tensor *workspace, + cudaStream_t stream, + cudnnHandle_t handle); + // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd( - size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, Tensor *input_output_S, Tensor *output_O, @@ -63,8 +109,10 @@ void fused_attn_fp8_fwd( // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd( - size_t b, size_t h, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d, + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 49d056ff1c..11da5cf56c 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -111,19 +111,20 @@ struct FADescriptor_v1 { NVTE_QKV_Layout layout; NVTE_Bias_Type bias_type; NVTE_Mask_Type mask_type; - cudnn_frontend::DataType_t tensor_type; + cudnn_frontend::DataType_t fwd_tensor_type; + cudnn_frontend::DataType_t bwd_tensor_type; bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, attnScale, isTraining, dropoutProbability, - layout, mask_type, bias_type, tensor_type) + layout, mask_type, bias_type, fwd_tensor_type, bwd_tensor_type) < std::tie( rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.bias_type, - rhs.tensor_type); + rhs.fwd_tensor_type, rhs.bwd_tensor_type); } }; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 9abbb69cbe..989dd03d62 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -96,7 +96,7 @@ def scaling_factor_compute(amax: Tensor, where `Tensor` is a framework tensor type. override_linear_precision: Tuple(bool, bool, bool), default=(False, False, False) - Whether or not the execute the `fprop`, `dgrad`, and `wgrad` + Whether or not to execute the `fprop`, `dgrad`, and `wgrad` GEMMs (respectively) in higher precision when using FP8. reduce_amax: bool, default = `True` By default, if `torch.distributed` is initialized, the `amax` value for FP8 @@ -106,6 +106,20 @@ def scaling_factor_compute(amax: Tensor, GPU maintains local amaxes and scaling factors. To ensure results are numerically identical across checkpointing boundaries in this case, all ranks must checkpoint in order to store the local tensors. + fp8_dpa: bool, default = `False` + Whether to enable FP8 dot product attention (DPA). 
When the model is placed in an + `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the + inputs from higher precision to FP8, performs attention in FP8, and casts tensors + back to higher precision as outputs. FP8 DPA currently is only supported in the + `FusedAttention` backend. + fp8_mha: bool, default = `False` + Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting + operations mentioned above at the DPA boundaries. Currently only standard MHA modules + i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When + `fp8_mha = False, fp8_dpa = True`, a typical MHA module works as + `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. + When `fp8_mha = True, fp8_dpa = True`, it becomes + `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. Notes ----- @@ -116,6 +130,9 @@ def scaling_factor_compute(amax: Tensor, FP8_MAX = maximum_representable_value(fp8_format) new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin) + + * `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are + subject to change in future Transformer Engine releases. """ margin: int = 0 @@ -126,6 +143,8 @@ def scaling_factor_compute(amax: Tensor, override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision() scaling_factor_compute_algo: Optional[Callable] = None reduce_amax: bool = True + fp8_dpa: bool = False + fp8_mha: bool = False def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index f57b58d736..90da9e06b6 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -19,6 +19,10 @@ import torch.nn.functional as F import transformer_engine_extensions as tex +from transformer_engine.pytorch.cpp_extensions import ( + cast_to_fp8, + cast_from_fp8, +) from transformer_engine.pytorch.cpp_extensions.fused_attn import ( fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked, @@ -31,7 +35,10 @@ AttnMaskType, FusedAttnBackend, ) +from transformer_engine.pytorch.fp8 import get_fp8_te_dtype +from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module import LayerNormLinear, Linear +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.utils import ( divide, attention_mask_func, @@ -74,6 +81,12 @@ from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module +META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT +META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 +META_O = tex.FP8FwdTensors.GEMM2_INPUT +META_DO = tex.FP8BwdTensors.GRAD_INPUT2 +META_S = tex.FP8FwdTensors.GEMM3_OUTPUT +META_DP = tex.FP8BwdTensors.GRAD_INPUT3 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) _alibi_cache = { @@ -810,7 +823,7 @@ def backward(ctx, dout): dq_, dk_, dv_, _ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_k, cu_seqlens_q, cu_seqlens_k, - q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], + q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype], [softmax_lse, ctx.rng_states[cp_size-i-1]], tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=ctx.softmax_scale, @@ -850,7 +863,7 
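To make the new recipe flags concrete, a minimal usage sketch follows. It assumes the DelayedScaling recipe defined in this file together with a stock te.TransformerLayer; the layer sizes, dtypes, and shapes are illustrative only.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Illustrative sketch only: enable FP8 attention through the new recipe flags.
fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,  # E4M3 forward, E5M2 backward
    fp8_dpa=True,              # dot product attention itself runs in FP8
    fp8_mha=False,             # keep the casts at the DPA boundaries
)
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    params_dtype=torch.bfloat16,
).cuda()
x = torch.randn(128, 2, 1024, device="cuda", dtype=torch.bfloat16)  # [seq, batch, hidden]
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)
y.sum().backward()

With fp8_mha=True as well, the surrounding LayerNormLinear/Linear modules would exchange Float8Tensors with the attention block directly, as described in the docstring above.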
@@ def backward(ctx, dout): dq_, dk_, dv_, _ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_k//2, cu_seqlens_q, cu_seqlens_k//2, - q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], + q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype], [softmax_lse, ctx.rng_states[cp_size-i-1]], tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=ctx.softmax_scale, @@ -890,7 +903,7 @@ def backward(ctx, dout): dq_, dk_, dv_, _ = fused_attn_bwd( ctx.max_seqlen_q//2, ctx.max_seqlen_k, cu_seqlens_q//2, cu_seqlens_k, - q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], + q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype], [softmax_lse_, ctx.rng_states[cp_size-i-1]], tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=ctx.softmax_scale, @@ -923,7 +936,7 @@ def backward(ctx, dout): dq_, dk_, dv_, _ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_k, cu_seqlens_q, cu_seqlens_k, - q, kv[0], kv[1], out, dout, TE_DType[q.dtype], + q, kv[0], kv[1], out, dout, TE_DType[q.dtype], TE_DType[kv.dtype], [softmax_lse, ctx.rng_states[cp_size-i-1]], tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=ctx.softmax_scale, @@ -1246,6 +1259,14 @@ def forward(ctx, ) -> Tuple[torch.Tensor, ...]: ctx.split_dim = split_dim ctx.split_size_or_sections = split_size_or_sections + if isinstance(mixed_x_layer, Float8Tensor): + return tuple(Float8Tensor.make_like( + mixed_x_layer, + data=x, + ) for x in torch.split( + mixed_x_layer._data, + split_size_or_sections=split_size_or_sections, + dim=split_dim)) return torch.split(mixed_x_layer, split_size_or_sections, dim = split_dim) @staticmethod @@ -1262,6 +1283,37 @@ def backward(ctx, dims = len(grad_outputs[0].shape) split_dim = (ctx.split_dim + dims) % dims + if isinstance(grad_outputs[0], Float8Tensor): + noop_ok = True + strides = grad_outputs[0].stride() + data_ptr = grad_outputs[0]._data.untyped_storage().data_ptr() + shape = list(grad_outputs[0].shape) + for i, tensor in enumerate(grad_outputs): + shape_i = shape + shape_i[split_dim] = split_sizes[i] + offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim+1:]) + if (tensor.stride() != strides or + list(tensor.shape) != shape_i or + tensor._data.untyped_storage().data_ptr() != data_ptr or + tensor.storage_offset() != offset_size): + noop_ok = False + break + if noop_ok: + ret = torch.Tensor().to(device=grad_outputs[0].device, + dtype=grad_outputs[0]._data.dtype) + new_shape = list(shape) + new_shape[split_dim] = sum(split_sizes) + ret.set_(grad_outputs[0]._data.untyped_storage(), + grad_outputs[0]._data.storage_offset(), + new_shape, + strides + ) + return Float8Tensor.make_like(grad_outputs[0], data=ret), None, None + + grad_outputs_data = [x._data for x in grad_outputs] + return Float8Tensor.make_like( + grad_outputs[0], + data=torch.cat(grad_outputs_data, dim = split_dim)), None, None noop_ok = True strides = grad_outputs[0].stride() data_ptr = grad_outputs[0].untyped_storage().data_ptr() @@ -1276,7 +1328,6 @@ def backward(ctx, tensor.storage_offset() != offset_size): noop_ok = False break - if noop_ok: ret = torch.Tensor().to(device=grad_outputs[0].device, dtype=grad_outputs[0].dtype) @@ -1848,6 +1899,35 @@ def forward( return output +def _combine_tensors( + tensors: List[torch.Tensor], + dim: int, + ) -> torch.Tensor: + """Combine tensors along a particular dimension""" + + num_tensors = len(tensors) + new_shape = list(tensors[0].shape) + new_shape.insert(dim, num_tensors) + new_stride = list(tensors[0].stride()) + new_stride.insert(dim, 
int(new_stride[dim-1]/num_tensors)) + if isinstance(tensors[0], Float8Tensor): + combined_tensor = torch.Tensor().to( + device=tensors[0].device, dtype=tensors[0]._data.dtype) + combined_tensor.set_( + tensors[0]._data.untyped_storage(), + tensors[0]._data.storage_offset(), + new_shape, new_stride) + combined_tensor = Float8Tensor.make_like( + tensors[0], data=combined_tensor) + else: + combined_tensor = torch.Tensor().to( + device=tensors[0].device, dtype=tensors[0].dtype) + combined_tensor.set_( + tensors[0].untyped_storage(), + tensors[0].storage_offset(), + new_shape, new_stride) + + return combined_tensor class FusedAttnFunc_qkvpacked(torch.autograd.Function): """Function for FusedAttention with packed QKV input""" @@ -1855,15 +1935,83 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): @staticmethod def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen, fused_attention_backend, use_FAv2_bwd): - out, aux_ctx_tensors = fused_attn_fwd_qkvpacked( - is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, - fused_attention_backend, attn_bias, - None, None, None, None, None, - attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen) - - ctx.save_for_backward(qkv, out, cu_seqlens) + rng_gen, fused_attention_backend, use_FAv2_bwd, + fp8, fp8_meta, tp_size, tp_group): + if fp8: + if _NVTE_DEBUG: + print('[DotProductAttention]: using FP8 forward') + if fp8_meta["recipe"].fp8_mha: + assert (isinstance(qkv, Float8Tensor)), "qkv must be Float8Tensors for FP8 MHA." + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv + fused_attention_backend = FusedAttnBackend["FP8"] + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_group = len(qkv_layout.split('_')) + assert (qkv_group == 1 + ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, \ + but found {qkv_layout}." 
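# Two paths produce the FP8 QKV buffer below: with fp8_mha=True the incoming qkv is
# already a Float8Tensor, so its raw uint8 ._data is used directly (its _scale_inv was
# copied into the META_QKV slot above); with fp8_dpa alone, the high-precision qkv is
# flattened to 2D, quantized via cast_to_fp8 against the META_QKV entry of
# fp8_meta["scaling_fwd"], and viewed back to its original shape before being handed
# to fused_attn_fwd_qkvpacked.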
+ if fp8_meta["recipe"].fp8_mha: + qkv_fp8 = qkv._data + else: + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_fp8 = cast_to_fp8(qkv_c, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(qkv.shape) + out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked( + is_training, max_seqlen, cu_seqlens, + qkv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, + fp8_meta["scaling_fwd"].scale_inv[META_QKV], + fp8_meta["scaling_fwd"].scale_inv[META_S], + fp8_meta["scaling_fwd"].scale[META_S], + fp8_meta["scaling_fwd"].scale[META_O], + fp8_meta["scaling_fwd"].amax_history[0][META_S], + fp8_meta["scaling_fwd"].amax_history[0][META_O], + attn_scale, dropout_p, fast_zero_fill, qkv_layout, + attn_bias_type, attn_mask_type, rng_gen) + if fp8_meta["recipe"].fp8_mha: + out_ret = Float8Tensor(data=out_fp8, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=qkv.dtype, + ) + else: + out_ret = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], META_O, + fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + out_save = out_ret + if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv = cast_from_fp8(qkv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape) + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], META_O, + fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + fp8_tensors = (qkv_fp8, out_fp8, + fp8_meta["scaling_fwd"].scale.clone(), + fp8_meta["scaling_fwd"].scale_inv.clone()) + else: + if _NVTE_DEBUG: + print('[DotProductAttention]: using non-FP8 forward') + out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked( + is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, + fused_attention_backend, attn_bias, + None, None, None, None, None, None, + attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, + rng_gen) + fp8_tensors = (None, None, None, None) + out_save = out_ret + + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) + ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors) + ctx.fp8_meta = fp8_meta + ctx.tp_size = tp_size + ctx.tp_group = tp_group ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen = max_seqlen ctx.qkv_dtype = qkv_dtype @@ -1873,15 +2021,23 @@ def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type - ctx.fused_attention_backend = fused_attention_backend + ctx.fused_attention_backend = \ + fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] ctx.use_FAv2_bwd = use_FAv2_bwd - return out + return out_ret @staticmethod def backward(ctx, d_out): + if ctx.fp8_meta["recipe"].fp8_mha: + assert (isinstance(d_out, Float8Tensor) + ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." 
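# For FP8 MHA the incoming gradient arrives as a Float8Tensor: the wrapper is kept as
# d_out_f8tensor so that its _scale_inv can later be written into the META_DO slot of
# scaling_bwd, while the remainder of the backward pass operates on the raw uint8
# ._data buffer.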
+ d_out_f8tensor = d_out + d_out = d_out._data + d_out = d_out.contiguous() - qkv, out, cu_seqlens = ctx.saved_tensors + (qkv, out, cu_seqlens, + qkv_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors if not ctx.aux_ctx_tensors[0].is_contiguous(): ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: @@ -1898,13 +2054,65 @@ def backward(ctx, d_out): ) dqkv = dqkv[..., :d_out.shape[-1]] else: - dqkv, *rest = fused_attn_bwd_qkvpacked( - ctx.max_seqlen, cu_seqlens, qkv, out, d_out, - ctx.qkv_dtype, ctx.aux_ctx_tensors, - ctx.fused_attention_backend, - None, None, None, None, None, None, None, None, None, - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"): + if ctx.fp8: + if _NVTE_DEBUG: + print('[DotProductAttention]: using FP8 backward') + fp8_dtype_forward = get_fp8_te_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=False) + if ctx.fp8_meta["recipe"].fp8_mha: + d_out_fp8 = d_out + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv + else: + d_out_fp8 = cast_to_fp8( + d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ).view(d_out.shape) + dqkv_fp8, *rest = fused_attn_bwd_qkvpacked( + ctx.max_seqlen, cu_seqlens, + qkv_fp8, out_fp8, d_out_fp8, + fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + fwd_scale_invs[META_QKV], # d_scale_qkv, + fwd_scale_invs[META_S], # d_scale_s, + fwd_scale_invs[META_O], # d_scale_o, + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp + fwd_scales[META_S], # q_scale_s + ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp + ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + if ctx.fp8_meta["recipe"].fp8_mha: + dqkv = Float8Tensor(data=dqkv_fp8, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=d_out_f8tensor.dtype, + ) + else: + dqkv_c_fp8 = dqkv_fp8.view(-1, + dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]) + dqkv = cast_from_fp8(dqkv_c_fp8, + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape) + else: + if _NVTE_DEBUG: + print('[DotProductAttention]: using non-FP8 backward') + if d_out.dtype == torch.uint8: + d_out = d_out_f8tensor.from_float8(qkv.dtype) + dqkv, *rest = fused_attn_bwd_qkvpacked( + ctx.max_seqlen, cu_seqlens, qkv, out, d_out, + ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + None, None, None, None, None, None, None, None, None, None, + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: @@ -1923,16 +2131,90 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): @staticmethod def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, - qkv_layout, attn_bias_type, 
attn_mask_type, - rng_gen, fused_attention_backend, use_FAv2_bwd): - out, aux_ctx_tensors = fused_attn_fwd_kvpacked( - is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, kv, qkv_dtype, fused_attention_backend, attn_bias, - None, None, None, None, None, - attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen) - - ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv) + qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, + use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group): + if fp8: + if _NVTE_DEBUG: + print('[DotProductAttention]: using FP8 forward') + if fp8_meta["recipe"].fp8_mha: + assert (isinstance(q, Float8Tensor) + and isinstance(kv, Float8Tensor)), "q/kv must be Float8Tensors for FP8 MHA." + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + fused_attention_backend = FusedAttnBackend["FP8"] + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + if fp8_meta["recipe"].fp8_mha: + q_fp8, kv_fp8 = q._data, kv._data + else: + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_group = len(qkv_layout.split('_')) + assert (qkv_group == 2 + ), f"qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, \ + but found {qkv_layout}." + q_fp8 = cast_to_fp8(q, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(q.shape) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv_fp8 = cast_to_fp8(kv_c, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(kv.shape) + out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked( + is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q_fp8, kv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, + fp8_meta["scaling_fwd"].scale_inv[META_QKV], + fp8_meta["scaling_fwd"].scale_inv[META_S], + fp8_meta["scaling_fwd"].scale[META_S], + fp8_meta["scaling_fwd"].scale[META_O], + fp8_meta["scaling_fwd"].amax_history[0][META_S], + fp8_meta["scaling_fwd"].amax_history[0][META_O], + attn_scale, dropout_p, fast_zero_fill, qkv_layout, + attn_bias_type, attn_mask_type, rng_gen) + if fp8_meta["recipe"].fp8_mha: + out_ret = Float8Tensor(data=out_fp8, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=q.dtype, + ) + else: + out_ret = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], META_O, + fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + out_save = out_ret + if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q = cast_from_fp8(q._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv = cast_from_fp8(kv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape) + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], META_O, + fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + fp8_tensors = (q_fp8, kv_fp8, out_fp8, + fp8_meta["scaling_fwd"].scale.clone(), + fp8_meta["scaling_fwd"].scale_inv.clone()) + else: + if _NVTE_DEBUG: + print('[DotProductAttention]: using non-FP8 forward') + out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked( + is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, kv, qkv_dtype, fused_attention_backend, attn_bias, + None, None, None, None, None, None, + attn_scale, dropout_p, fast_zero_fill, 
qkv_layout, attn_bias_type, attn_mask_type, + rng_gen) + out_save = out_ret + fp8_tensors = (None, None, None, None, None) + + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) + ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors) + ctx.fp8_meta = fp8_meta + ctx.tp_size = tp_size + ctx.tp_group = tp_group ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv @@ -1943,15 +2225,23 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type - ctx.fused_attention_backend = fused_attention_backend + ctx.fused_attention_backend = \ + fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] ctx.use_FAv2_bwd = use_FAv2_bwd - return out + return out_ret @staticmethod def backward(ctx, d_out): + if ctx.fp8_meta["recipe"].fp8_mha: + assert (isinstance(d_out, Float8Tensor) + ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." + d_out_f8tensor = d_out + d_out = d_out._data + d_out = d_out.contiguous() - q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors + (q, kv, out, cu_seqlens_q, cu_seqlens_kv, + q_fp8, kv_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors if not ctx.aux_ctx_tensors[0].is_contiguous(): ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: @@ -1970,14 +2260,77 @@ def backward(ctx, d_out): dq = dq[..., :d_out.shape[-1]] dkv = dkv[..., :d_out.shape[-1]] else: - dq, dkv, *rest = fused_attn_bwd_kvpacked( - ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, kv, out, d_out, - ctx.qkv_dtype, ctx.aux_ctx_tensors, - ctx.fused_attention_backend, - None, None, None, None, None, None, None, None, None, - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + with torch.cuda.nvtx.range("_FusedAttn_kvpacked"): + if ctx.fp8: + if _NVTE_DEBUG: + print('[DotProductAttention]: using FP8 backward') + fp8_dtype_forward = get_fp8_te_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=False) + if ctx.fp8_meta["recipe"].fp8_mha: + d_out_fp8 = d_out + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv + else: + d_out_fp8 = cast_to_fp8( + d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ).view(d_out.shape) + dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked( + ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q_fp8, kv_fp8, out_fp8, d_out_fp8, + fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + fwd_scale_invs[META_QKV], # d_scale_qkv, + fwd_scale_invs[META_S], # d_scale_s, + fwd_scale_invs[META_O], # d_scale_o, + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp + fwd_scales[META_S], # q_scale_s + ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp + ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + if 
ctx.fp8_meta["recipe"].fp8_mha: + dq = Float8Tensor(data=dq_fp8, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=d_out_f8tensor.dtype, + ) + dkv = Float8Tensor(data=dkv_fp8, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=d_out_f8tensor.dtype, + ) + else: + dq = cast_from_fp8( + dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape) + dkv_c_fp8 = dkv_fp8.view(-1, + dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]) + dkv = cast_from_fp8(dkv_c_fp8, + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape) + else: + if _NVTE_DEBUG: + print('[DotProductAttention]: using non-FP8 backward') + if d_out.dtype == torch.uint8: + d_out = d_out_f8tensor.from_float8(q.dtype) + dq, dkv, *rest = fused_attn_bwd_kvpacked( + ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, kv, out, d_out, + ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + None, None, None, None, None, None, None, None, None, None, + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: @@ -1989,32 +2342,153 @@ def backward(ctx, d_out): None, None, None, None, None, None, None, None, None, None, None, None) - class FusedAttnFunc(torch.autograd.Function): """Function for FusedAttention with separate Q, K, V tensors""" @staticmethod def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, - qkv_layout, attn_bias_type, attn_mask_type, - rng_gen, fused_attention_backend, use_FAv2_bwd): - out, aux_ctx_tensors = fused_attn_fwd( - is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, k, v, qkv_dtype, fused_attention_backend, attn_bias, - None, None, None, None, None, - attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen) + qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, + use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group): + if fp8: + if _NVTE_DEBUG: + print('[DotProductAttention]: using FP8 forward') + fused_attention_backend = FusedAttnBackend["FP8"] + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + if fp8_meta["recipe"].fp8_mha: + assert (isinstance(q, Float8Tensor) + and isinstance(k, Float8Tensor) + and isinstance(v, Float8Tensor)), "q/k/v must be Float8Tensors for FP8 MHA." 
+ fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data + else: + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_group = len(qkv_layout.split('_')) + if qkv_group == 1: + dim = qkv_layout.find('3') + qkv = _combine_tensors([q,k,v], dim) + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_fp8 = cast_to_fp8(qkv_c, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(qkv.shape) + q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1,1,1]) + q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]] + if qkv_group == 2: + q_fp8 = cast_to_fp8(q, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(q.shape) + dim = qkv_layout.split('_')[1].find('2') + kv = _combine_tensors([k,v], dim) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv_fp8 = cast_to_fp8(kv_c, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(kv.shape) + k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1,1]) + k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]] + if qkv_group == 3: + q_fp8 = cast_to_fp8(q, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(q.shape) + k_fp8 = cast_to_fp8(k, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(k.shape) + v_fp8 = cast_to_fp8(v, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(v.shape) + out_fp8, aux_ctx_tensors = fused_attn_fwd( + is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q_fp8, k_fp8, v_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, + fp8_meta["scaling_fwd"].scale_inv[META_QKV], + fp8_meta["scaling_fwd"].scale_inv[META_S], + fp8_meta["scaling_fwd"].scale[META_S], + fp8_meta["scaling_fwd"].scale[META_O], + fp8_meta["scaling_fwd"].amax_history[0][META_S], + fp8_meta["scaling_fwd"].amax_history[0][META_O], + attn_scale, dropout_p, fast_zero_fill, qkv_layout, + attn_bias_type, attn_mask_type, rng_gen) + if fp8_meta["recipe"].fp8_mha: + out_ret = Float8Tensor(data=out_fp8, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=q.dtype, + ) + else: + out_ret = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], META_O, + fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + out_save = out_ret + + if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_group = len(qkv_layout.split('_')) + if qkv_group == 1: + dim = qkv_layout.find('3') + qkv = _combine_tensors([q,k,v], dim) + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_no_fp8 = cast_from_fp8(qkv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape) + q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1,1,1]) + q, k, v = [x.squeeze(dim) for x in [q, k, v]] + if qkv_group == 2: + q = cast_from_fp8(q._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape) + dim = qkv_layout.split('_')[1].find('2') + kv = _combine_tensors([k,v], dim) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv_no_fp8 = cast_from_fp8(kv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape) + k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1,1]) + k, v = [x.squeeze(dim) for x in [k, v]] + if qkv_group == 3: + q = cast_from_fp8(q._data, + fp8_meta["scaling_fwd"], + META_QKV, 
fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape) + k = cast_from_fp8(k._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[k.dtype]).view(k.shape) + v = cast_from_fp8(v._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[v.dtype]).view(v.shape) + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], META_O, + fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + + fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8, + fp8_meta["scaling_fwd"].scale.clone(), + fp8_meta["scaling_fwd"].scale_inv.clone()) + else: + if _NVTE_DEBUG: + print('[DotProductAttention]: using non-FP8 forward') + out_ret, aux_ctx_tensors = fused_attn_fwd( + is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, k, v, qkv_dtype, fused_attention_backend, attn_bias, + None, None, None, None, None, None, + attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, + rng_gen) + out_save = out_ret + fp8_tensors = (None, None, None, None, None, None) from .cpu_offload import CPUOffloadEnabled if CPUOffloadEnabled: - tensor_list = [q, k, v, out, cu_seqlens_q, cu_seqlens_kv] + tensor_list = [q, k, v, out_save, cu_seqlens_q, cu_seqlens_kv] qkv_layout = 'sbhd_sbhd_sbhd' for tensor in tensor_list: if tensor is not None: tensor.activation_offloading = True - - ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv) + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) + ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors) + ctx.fp8_meta = fp8_meta + ctx.tp_size = tp_size + ctx.tp_group = tp_group ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv @@ -2025,15 +2499,23 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type - ctx.fused_attention_backend = fused_attention_backend + ctx.fused_attention_backend = \ + fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] ctx.use_FAv2_bwd = use_FAv2_bwd - return out + return out_ret @staticmethod def backward(ctx, d_out): + if ctx.fp8_meta["recipe"].fp8_mha: + assert (isinstance(d_out, Float8Tensor) + ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." 
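A note on the context saved above: ctx.fp8 combines the fp8 flag with the NVTE_FP8_DPA_BWD environment variable, and that single boolean decides whether the FP8 copies (plus the cloned forward scale and scale_inv) or the high-precision activations are stashed for backward. A minimal, framework-free sketch of that selection; the helper name is illustrative rather than part of this patch:

import os

def select_saved_tensors(fp8_forward, high_precision, fp8_copies):
    """Mirror the save_for_backward selection used in these autograd functions.

    fp8_forward    -- True when the forward pass ran through the FP8 kernels
    high_precision -- tuple of original-precision activations (q, kv/k, v, out)
    fp8_copies     -- tuple of FP8 payloads plus cloned forward scale/scale_inv
    """
    fp8_backward = fp8_forward and bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))
    if fp8_backward:
        # FP8 backward: the high-precision activations are not needed.
        return (None,) * len(high_precision) + fp8_copies
    # High-precision backward: the FP8 copies are not needed.
    return high_precision + (None,) * len(fp8_copies)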
+ d_out_f8tensor = d_out + d_out = d_out._data + d_out = d_out.contiguous() - q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors + (q, k, v, out, cu_seqlens_q, cu_seqlens_kv, + q_fp8, k_fp8, v_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors if not ctx.aux_ctx_tensors[0].is_contiguous(): ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: @@ -2054,14 +2536,112 @@ def backward(ctx, d_out): dk = dk[..., :d_out.shape[-1]] dv = dv[..., :d_out.shape[-1]] else: - dq, dk, dv, *rest = fused_attn_bwd( - ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, k, v, out, d_out, - ctx.qkv_dtype, ctx.aux_ctx_tensors, - ctx.fused_attention_backend, - None, None, None, None, None, None, None, None, None, - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + with torch.cuda.nvtx.range("_FusedAttn"): + if ctx.fp8: + if _NVTE_DEBUG: + print('[DotProductAttention]: using FP8 backward') + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=False) + if ctx.fp8_meta["recipe"].fp8_mha: + d_out_fp8 = d_out + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv + else: + d_out_fp8 = cast_to_fp8( + d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ).view(d_out.shape) + dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( + ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q_fp8, k_fp8, v_fp8, out_fp8, d_out_fp8, + fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + fwd_scale_invs[META_QKV], # d_scale_qkv, + fwd_scale_invs[META_S], # d_scale_s, + fwd_scale_invs[META_O], # d_scale_o, + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp + fwd_scales[META_S], # q_scale_s + ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp + ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + if ctx.fp8_meta["recipe"].fp8_mha: + dq = Float8Tensor(data=dq_fp8, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=d_out_f8tensor.dtype, + ) + dk = Float8Tensor(data=dk_fp8, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=d_out_f8tensor.dtype, + ) + dv = Float8Tensor(data=dv_fp8, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=d_out_f8tensor.dtype, + ) + else: + qkv_group = len(ctx.qkv_layout.split('_')) + if qkv_group == 1: + dim = ctx.qkv_layout.find('3') + dqkv_fp8 = _combine_tensors([dq_fp8,dk_fp8,dv_fp8], dim) + dqkv_c_fp8 = dqkv_fp8.view(-1, + dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]) + dqkv = cast_from_fp8(dqkv_c_fp8, + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape) + dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1,1,1]) + dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]] + if qkv_group == 2: + dq = cast_from_fp8( + dq_fp8.view(-1, dq_fp8.shape[-2] * 
dq_fp8.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape) + dim = ctx.qkv_layout.split('_')[1].find('2') + dkv_fp8 = _combine_tensors([dk_fp8,dv_fp8], dim) + dkv_c_fp8 = dkv_fp8.view(-1, + dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]) + dkv = cast_from_fp8(dkv_c_fp8, + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape) + dk, dv = _SplitAlongDim.apply(dkv, dim, [1,1]) + dk, dv = [x.squeeze(dim) for x in [dk, dv]] + if qkv_group == 3: + dq = cast_from_fp8( + dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape) + dk = cast_from_fp8( + dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dk_fp8.shape) + dv = cast_from_fp8( + dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dv_fp8.shape) + else: + if _NVTE_DEBUG: + print('[DotProductAttention]: using non-FP8 backward') + if d_out.dtype == torch.uint8: + d_out = d_out_f8tensor.from_float8(q.dtype) + dq, dk, dv, *rest = fused_attn_bwd( + ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, k, v, out, d_out, + ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + None, None, None, None, None, None, None, None, None, None, + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: @@ -2074,7 +2654,7 @@ def backward(ctx, d_out): None, None, None, None, None, None) -class FusedAttention(torch.nn.Module): +class FusedAttention(TransformerEngineBaseModule): """Dot product attention, with multiple backends: 1. FusedAttnBackend["F16_max512_seqlen"] @@ -2110,6 +2690,8 @@ def __init__( attention_type: str = "self", layer_number: Optional[int] = None, deterministic: bool = False, + tp_size: int = 1, + tp_group: Optional[dist_group_type] = None, ) -> None: super().__init__() @@ -2136,6 +2718,15 @@ def __init__( if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" + self.tp_size = tp_size + self.tp_group = tp_group + + def get_fp8_weights_scratchpad( + self, + is_first_microbatch: Union[bool, None], + ) -> List[Float8Tensor]: + """Needs override.""" + @no_torch_dynamo() def forward( self, @@ -2157,6 +2748,7 @@ def forward( cp_group: Optional[dist_group_type] = None, cp_global_ranks: List[int] = None, cp_stream: torch.cuda.Stream = None, + is_first_microbatch: Optional[bool] = None, ) -> torch.Tensor: """fused attention fprop""" @@ -2164,9 +2756,9 @@ def forward( != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend ), 'No fused attention backend supports this input combination!' assert ( - (query_layer.dtype in [torch.float16, torch.bfloat16]) - and (key_layer.dtype in [torch.float16, torch.bfloat16]) - and (value_layer.dtype in [torch.float16, torch.bfloat16]) + (query_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8]) + and (key_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8]) + and (value_layer.dtype in [torch.float16, torch.bfloat16, torch.uint8]) ), 'FusedAttention only supports FP16 and BF16 data types.' 
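The dtype check directly above now also admits torch.uint8, which is the raw storage dtype of a Float8Tensor's _data payload, although the accompanying error message still mentions only FP16 and BF16. A small sketch of the intended gate; the helper name and message wording are illustrative, not part of this patch:

import torch

_SUPPORTED_INPUT_DTYPES = (
    torch.float16,   # FP16 inputs
    torch.bfloat16,  # BF16 inputs
    torch.uint8,     # raw payload of a Float8Tensor (FP8 MHA path)
)

def check_fused_attn_input_dtypes(*tensors: torch.Tensor) -> None:
    """Raise if any input dtype falls outside the FP16/BF16/FP8-payload set."""
    for t in tensors:
        if t.dtype not in _SUPPORTED_INPUT_DTYPES:
            raise AssertionError(
                f"FusedAttention only supports FP16, BF16 and FP8 inputs, got {t.dtype}"
            )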
assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda @@ -2248,24 +2840,43 @@ def forward( if qkv_format == 'sbhd': output = output.transpose(0,1).contiguous() else: - with self.attention_dropout_ctx(): - output = FusedAttnFunc.apply( - self.training, - max_seqlen_q, max_seqlen_kv, - cu_seqlens_q, cu_seqlens_kv, - query_layer, key_layer, value_layer, - qkv_dtype, - core_attention_bias, - 1.0/self.norm_factor, - self.attention_dropout if self.training else 0.0, - fast_zero_fill, - qkv_layout, - core_attention_bias_type, - attn_mask_type, - None, # rng_gen - fused_attention_backend, - use_FAv2_bwd, - ) + with self.prepare_forward(query_layer, + is_first_microbatch, + num_gemms=3, + allow_non_contiguous=True) as query_layer: + with self.attention_dropout_ctx(): + forced_fp8_dpa = "" + if self.fp8_meta["recipe"].fp8_mha: + if not self.fp8_meta["recipe"].fp8_dpa: + self.fp8_meta["recipe"].fp8_dpa = True + forced_fp8_dpa = " (forced)" + if _NVTE_DEBUG: + print("[DotProductAttention]: " + f"""using fp8_recipe.fp8_mha={self.fp8_meta["recipe"].fp8_mha}, """ + f"""fp8_recipe.fp8_dpa={self.fp8_meta["recipe"].fp8_dpa}""" + f"""{forced_fp8_dpa} and """ + f"""NVTE_FP8_DPA_BWD={int(os.getenv("NVTE_FP8_DPA_BWD", "1"))}""") + output = FusedAttnFunc.apply( + self.training, + max_seqlen_q, max_seqlen_kv, + cu_seqlens_q, cu_seqlens_kv, + query_layer, key_layer, value_layer, + qkv_dtype, + core_attention_bias, + 1.0/self.norm_factor, + self.attention_dropout if self.training else 0.0, + fast_zero_fill, + qkv_layout, + core_attention_bias_type, + attn_mask_type, + None, # rng_gen + fused_attention_backend, + use_FAv2_bwd, + self.fp8 and self.fp8_meta["recipe"].fp8_dpa, + self.fp8_meta, + self.tp_size, + self.tp_group, + ) # ...hd -> ...(hd) return output.view(*output.shape[:-2], -1) @@ -2463,7 +3074,9 @@ def __init__( attention_type=attention_type, layer_number=layer_number, deterministic=self.deterministic, - **attn_kwargs) + **attn_kwargs, + tp_size=self.tp_size, + tp_group=self.tp_group) self.unfused_attention = UnfusedDotProductAttention( norm_factor, **attn_kwargs, layer_number=layer_number) @@ -2532,6 +3145,7 @@ def forward( alibi_slopes: Optional[torch.Tensor] = None, fast_zero_fill: bool = True, inference_params: Optional[InferenceParams] = None, + is_first_microbatch: Optional[bool] = None, ) -> torch.Tensor: """ Dot Product Attention Layer. @@ -2635,6 +3249,19 @@ def forward( Adjustments of the sequence_len_offset should be done after a complete forward pass. If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand. Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient. + is_first_microbatch : {True, False, None}, default = None + During training using either gradient accumulation or + pipeline parallelism a minibatch of data is further split + into microbatches. Between the microbatches of the same minibatch + the model weights are not updated. Setting this parameter indicates + whether the current microbatch is the first in a minibatch or not. 
+ When set, this parameter enables additional optimizations: + + * during FP8 training, it allows caching of the FP8 versions of + the weights + * it also allows skipping gradient accumulation during the + first microbatch (since it is the first gradient being + produced) """ assert ( @@ -2746,8 +3373,14 @@ def forward( ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than the sequence dimention in 'key_layer' and 'value_layer'!""" - qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout( - query_layer, key_layer, value_layer, qkv_format = qkv_format) + if (isinstance(query_layer, Float8Tensor) + and isinstance(key_layer, Float8Tensor) + and isinstance(value_layer, Float8Tensor)): + qkv_layout, query_layer._data, key_layer._data, value_layer._data = _get_qkv_layout( + query_layer._data, key_layer._data, value_layer._data, qkv_format = qkv_format) + else: + qkv_layout, query_layer, key_layer, value_layer = _get_qkv_layout( + query_layer, key_layer, value_layer, qkv_format = qkv_format) # The priority for attention backends (subject to availability and clearing the filters) # is: FlashAttention > FusedAttention (cuDNN) > UnfusedDotProductAttention. @@ -2767,8 +3400,13 @@ def forward( if (query_layer.dtype not in [torch.bfloat16, torch.float16] or key_layer.dtype not in [torch.bfloat16, torch.float16] or value_layer.dtype not in [torch.bfloat16, torch.float16] + or any(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]) ): use_flash_attention = False + if (query_layer.dtype not in [torch.bfloat16, torch.float16] + or key_layer.dtype not in [torch.bfloat16, torch.float16] + or value_layer.dtype not in [torch.bfloat16, torch.float16] + ): use_fused_attention = False # Filter: Device and dimensions. @@ -2865,8 +3503,10 @@ def forward( if use_fused_attention: fused_attention_backend = tex.get_fused_attn_backend( - TE_DType[query_layer.dtype], - TE_DType[key_layer.dtype], + TE_DType[query_layer.dtype] + if not isinstance(query_layer, Float8Tensor) else query_layer._fp8_dtype, + TE_DType[key_layer.dtype] + if not isinstance(key_layer, Float8Tensor) else key_layer._fp8_dtype, QKVLayout[qkv_layout], AttnBiasType[fu_core_attention_bias_type], AttnMaskType[attn_mask_type], @@ -2879,7 +3519,9 @@ def forward( ) # DPA does not support FP8; for FP8, use cpp_extensions modules directly is_backend_avail = (fused_attention_backend in - [FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]]) + [FusedAttnBackend["F16_max512_seqlen"], + FusedAttnBackend["F16_arbitrary_seqlen"], + FusedAttnBackend["FP8"]]) use_fused_attention = ( \ use_fused_attention and is_backend_avail and \ (not context_parallel or \ @@ -2950,6 +3592,8 @@ def forward( qkv_layout=qkv_layout, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, fused_attention_backend=fused_attention_backend, @@ -2959,8 +3603,7 @@ def forward( cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv) + is_first_microbatch=is_first_microbatch) return self.fused_attention( query_layer, key_layer, @@ -2968,6 +3611,8 @@ def forward( qkv_layout=qkv_layout, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, 
fused_attention_backend=fused_attention_backend, @@ -2977,8 +3622,7 @@ def forward( cp_group=self.cp_group, cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv) + is_first_microbatch=is_first_microbatch) assert (not context_parallel), \ "Context parallelism is only implemented with Flash Attention and Fused Attention!" @@ -3552,6 +4196,7 @@ def forward( mixed_x_layer = self.qkv( hidden_states, is_first_microbatch=is_first_microbatch, + is_first_module_in_mha=True, # specific to FP8 MHA ) num_queries_per_key_value = (self.num_attention_heads_per_partition // @@ -3603,6 +4248,7 @@ def forward( mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, + is_first_module_in_mha=True, # specific to FP8 MHA ) if self.qkv_weight_interleaved: @@ -3633,6 +4279,9 @@ def forward( key_layer, value_layer = torch.split( mixed_kv_layer, mixed_kv_layer.shape[split_dim] // 2, dim = split_dim, ) + key_layer, value_layer = (x.reshape( + x.size(0), x.size(1), -1, self.hidden_size_per_attention_head, + ) for x in (key_layer, value_layer)) # Attention head [sq, b, h] --> [sq, b, hp] if self.input_layernorm: @@ -3648,6 +4297,7 @@ def forward( query_layer = self.query_layer( hidden_states, is_first_microbatch=is_first_microbatch, + is_first_module_in_mha=True, # specific to FP8 MHA ) # [sq, b, hp] --> [sq, b, np, hn] @@ -3662,6 +4312,9 @@ def forward( # ====================================================== if rotary_pos_emb is not None: + assert (not isinstance(query_layer, Float8Tensor) + and not isinstance(key_layer, Float8Tensor) + ), "RoPE is not supported for Float8Tensors!" # duplicate the pos_emb for self attention if not isinstance(rotary_pos_emb, tuple): rotary_pos_emb = ((rotary_pos_emb,) * 2) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 0f9a88454f..574627ac5d 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -84,6 +84,7 @@ def fused_attn_fwd_qkvpacked( fused_attention_backend: tex.NVTE_Fused_Attn_Backend, attn_bias: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, + d_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None, q_scale_o: torch.Tensor = None, amax_s: torch.Tensor = None, @@ -119,6 +120,8 @@ def fused_attn_fwd_qkvpacked( shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations + d_scale_s: torch.Tensor, default = None + input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) q_scale_o: torch.Tensor, default = None @@ -206,6 +209,8 @@ def fused_attn_fwd_qkvpacked( assert (d_scale_qkv is not None ), "d_scale_qkv is required as an input for FP8 fused attention." + assert (d_scale_s is not None + ), "q_scale_s is required as an input for FP8 fused attention." assert (q_scale_s is not None ), "q_scale_s is required as an input for FP8 fused attention." 
assert (q_scale_o is not None @@ -220,7 +225,7 @@ def fused_attn_fwd_qkvpacked( max_seqlen, is_training, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], cu_seqlens, qkv, qkv_dtype, - d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, + d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, rng_gen, rng_elts_per_thread, ) @@ -235,12 +240,14 @@ def fused_attn_bwd_qkvpacked( o: torch.Tensor, d_o: torch.Tensor, qkv_dtype: tex.DType, + dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, d_scale_o: torch.Tensor = None, d_scale_do: torch.Tensor = None, + d_scale_dp: torch.Tensor = None, q_scale_s: torch.Tensor = None, q_scale_dp: torch.Tensor = None, q_scale_dqkv: torch.Tensor = None, @@ -272,6 +279,8 @@ def fused_attn_bwd_qkvpacked( same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) qkv_dtype: tex.DType data type of QKV; in tex.DType, not torch.dtype + dqkv_dtype: tex.DType + data type of dQKV; in tex.DType, not torch.dtype aux_ctx_tensors: List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, e.g. aux_ctx_tensors = [M, ZInv, rng_state] @@ -285,6 +294,8 @@ def fused_attn_bwd_qkvpacked( input tensor for the dequantization of O in FP8 computations d_scale_do: torch.Tensor, default = None input tensor for the dequantization of dO in FP8 computations + d_scale_dp: torch.Tensor, default = None + input tensor for the dequantization of dP in FP8 computations q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations q_scale_dp: torch.Tensor, default = None @@ -336,6 +347,7 @@ def fused_attn_bwd_qkvpacked( assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention." assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention." assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention." + assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention." assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention." assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention." assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention." 
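The FP8 paths above and below repeat the same pattern of None-checks over the newly added d_scale_s, d_scale_dp and the other scale/amax arguments. Purely as an illustration (the helper below is not part of this patch), those checks could be collapsed into one validator that names every missing tensor at once:

from typing import Optional

import torch

def require_fp8_scales(**scales: Optional[torch.Tensor]) -> None:
    """Raise a single error listing every missing FP8 scale/amax tensor."""
    missing = [name for name, tensor in scales.items() if tensor is None]
    if missing:
        raise AssertionError(
            ", ".join(missing) + " are required for FP8 fused attention."
        )

# Usage mirroring the FP8 branch of fused_attn_bwd_qkvpacked:
# require_fp8_scales(
#     d_scale_qkv=d_scale_qkv, d_scale_s=d_scale_s, d_scale_o=d_scale_o,
#     d_scale_do=d_scale_do, d_scale_dp=d_scale_dp, q_scale_s=q_scale_s,
#     q_scale_dp=q_scale_dp, q_scale_dqkv=q_scale_dqkv,
#     amax_dp=amax_dp, amax_dqkv=amax_dqkv,
# )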
@@ -348,8 +360,8 @@ def fused_attn_bwd_qkvpacked( output_tensors = tex.fused_attn_bwd_qkvpacked( max_seqlen, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens, qkv, o, d_o, qkv_dtype, aux_ctx_tensors, - d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, + cu_seqlens, qkv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) @@ -368,6 +380,7 @@ def fused_attn_fwd_kvpacked( fused_attention_backend: tex.NVTE_Fused_Attn_Backend, attn_bias: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, + d_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None, q_scale_o: torch.Tensor = None, amax_s: torch.Tensor = None, @@ -410,6 +423,8 @@ def fused_attn_fwd_kvpacked( shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations + d_scale_s: torch.Tensor, default = None + input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) q_scale_o: torch.Tensor, default = None @@ -496,12 +511,25 @@ def fused_attn_fwd_kvpacked( rng_elts_per_thread = (max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA + assert (d_scale_qkv is not None + ), "d_scale_qkv is required as an input for FP8 fused attention." + assert (d_scale_s is not None + ), "q_scale_s is required as an input for FP8 fused attention." + assert (q_scale_s is not None + ), "q_scale_s is required as an input for FP8 fused attention." + assert (q_scale_o is not None + ), "q_scale_o is required as an input for FP8 fused attention." + assert (amax_s is not None + ), "amax_s is required as an input for FP8 fused attention." + assert (amax_o is not None + ), "amax_o is required as an input for FP8 fused attention." + # execute kernel output_tensors = tex.fused_attn_fwd_kvpacked( max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype, - d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, + d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, rng_gen, rng_elts_per_thread, ) @@ -519,12 +547,14 @@ def fused_attn_bwd_kvpacked( o: torch.Tensor, d_o: torch.Tensor, qkv_dtype: tex.DType, + dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, d_scale_o: torch.Tensor = None, d_scale_do: torch.Tensor = None, + d_scale_dp: torch.Tensor = None, q_scale_s: torch.Tensor = None, q_scale_dp: torch.Tensor = None, q_scale_dqkv: torch.Tensor = None, @@ -562,7 +592,9 @@ def fused_attn_bwd_kvpacked( input tensor dO (gradient of O); same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) qkv_dtype: tex.DType - data type of QKV; in tex.DType, not torch.dtype + data type of Q and KV; in tex.DType, not torch.dtype + dqkv_dtype: tex.DType + data type of dQ and dKV; in tex.DType, not torch.dtype aux_ctx_tensors: List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, e.g. 
aux_ctx_tensors = [M, ZInv, rng_state] @@ -576,6 +608,8 @@ def fused_attn_bwd_kvpacked( input tensor for the dequantization of O in FP8 computations d_scale_do: torch.Tensor, default = None input tensor for the dequantization of dO in FP8 computations + d_scale_dp: torch.Tensor, default = None + input tensor for the dequantization of dP in FP8 computations q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations q_scale_dp: torch.Tensor, default = None @@ -631,6 +665,7 @@ def fused_attn_bwd_kvpacked( assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention." assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention." assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention." + assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention." assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention." assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention." assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention." @@ -643,8 +678,8 @@ def fused_attn_bwd_kvpacked( output_tensors = tex.fused_attn_bwd_kvpacked( max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, aux_ctx_tensors, - d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, + cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) @@ -664,6 +699,7 @@ def fused_attn_fwd( fused_attention_backend: tex.NVTE_Fused_Attn_Backend, attn_bias: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, + d_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None, q_scale_o: torch.Tensor = None, amax_s: torch.Tensor = None, @@ -710,6 +746,8 @@ def fused_attn_fwd( shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of Q, K and V in FP8 computations + d_scale_s: torch.Tensor, default = None + input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) q_scale_o: torch.Tensor, default = None @@ -798,12 +836,25 @@ def fused_attn_fwd( rng_elts_per_thread = (max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA + assert (d_scale_qkv is not None + ), "d_scale_qkv is required as an input for FP8 fused attention." + assert (d_scale_s is not None + ), "q_scale_s is required as an input for FP8 fused attention." + assert (q_scale_s is not None + ), "q_scale_s is required as an input for FP8 fused attention." + assert (q_scale_o is not None + ), "q_scale_o is required as an input for FP8 fused attention." + assert (amax_s is not None + ), "amax_s is required as an input for FP8 fused attention." + assert (amax_o is not None + ), "amax_o is required as an input for FP8 fused attention." 
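For orientation, the d_scale_* / q_scale_* / amax_* arguments documented above follow the usual FP8 scaling convention: a tensor is quantized by multiplying with its q_scale, recovered by multiplying with the corresponding d_scale (the inverse scale), and amax records the largest magnitude observed so the recipe can update the scales later. A toy, kernel-free illustration of that relationship (values and helper names are made up):

import torch

E4M3_MAX = 448.0  # largest representable magnitude in FP8 E4M3

def quantize_sim(x: torch.Tensor, q_scale: torch.Tensor):
    """Toy quantization: record amax, scale, and clamp to the E4M3 range."""
    amax = x.abs().max()
    x_q = torch.clamp(x * q_scale, -E4M3_MAX, E4M3_MAX)
    return x_q, amax

def dequantize_sim(x_q: torch.Tensor, d_scale: torch.Tensor) -> torch.Tensor:
    """Toy dequantization with d_scale == 1 / q_scale."""
    return x_q * d_scale

x = torch.randn(4, 4)
q_scale = E4M3_MAX / x.abs().max()
x_q, amax = quantize_sim(x, q_scale)
x_back = dequantize_sim(x_q, 1.0 / q_scale)
assert torch.allclose(x, x_back, atol=1e-5)
assert amax == x.abs().max()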
+ # execute kernel output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype, - d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, + d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, rng_gen, rng_elts_per_thread, ) @@ -822,12 +873,14 @@ def fused_attn_bwd( o: torch.Tensor, d_o: torch.Tensor, qkv_dtype: tex.DType, + dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, d_scale_o: torch.Tensor = None, d_scale_do: torch.Tensor = None, + d_scale_dp: torch.Tensor = None, q_scale_s: torch.Tensor = None, q_scale_dp: torch.Tensor = None, q_scale_dqkv: torch.Tensor = None, @@ -869,6 +922,8 @@ def fused_attn_bwd( same shape as Q qkv_dtype: tex.DType data type of Q, K and V; in tex.DType, not torch.dtype + dqkv_dtype: tex.DType + data type of dQ, dK and dV; in tex.DType, not torch.dtype aux_ctx_tensors: List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, e.g. aux_ctx_tensors = [M, ZInv, rng_state] @@ -882,6 +937,8 @@ def fused_attn_bwd( input tensor for the dequantization of O in FP8 computations d_scale_do: torch.Tensor, default = None input tensor for the dequantization of dO in FP8 computations + d_scale_dp: torch.Tensor, default = None + input tensor for the dequantization of dP in FP8 computations q_scale_s: torch.Tensor, default = None input tensor for the quantization of S in FP8 computations q_scale_dp: torch.Tensor, default = None @@ -941,6 +998,7 @@ def fused_attn_bwd( assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention." assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention." assert (d_scale_do is not None), "d_scale_do is required for FP8 fused attention." + assert (d_scale_dp is not None), "d_scale_dp is required for FP8 fused attention." assert (q_scale_s is not None), "q_scale_s is required for FP8 fused attention." assert (q_scale_dp is not None), "q_scale_dp is required for FP8 fused attention." assert (q_scale_dqkv is not None), "q_scale_dqkv is required for FP8 fused attention." @@ -953,8 +1011,8 @@ def fused_attn_bwd( output_tensors = tex.fused_attn_bwd( max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, aux_ctx_tensors, - d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, + cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 3c039b9a88..dfbcfe3e8a 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -786,9 +786,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Get communication and GEMM output chunk sizes const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const bool do_gelu = pre_gelu_out.numel() > 0; - const int output_chunk_bytes = (do_gelu - ? 
(n_chunk * m) * D.element_size() - : (n_chunk * m) * HALF_BYTES); + const int output_chunk_bytes = (n_chunk * m) * D.element_size(); const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0; // Get output and workspace data pointers diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index bf0bb576ec..abbecb1609 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -32,6 +32,7 @@ std::vector fused_attn_fwd_qkvpacked( const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional scale_S, const c10::optional scale_O, c10::optional amax_S, @@ -51,11 +52,13 @@ std::vector fused_attn_bwd_qkvpacked( const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, + const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, const c10::optional descale_dO, + const c10::optional descale_dP, const c10::optional scale_S, const c10::optional scale_dP, const c10::optional scale_dQKV, @@ -74,6 +77,7 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional scale_S, const c10::optional scale_O, c10::optional amax_S, @@ -95,11 +99,13 @@ std::vector fused_attn_bwd_kvpacked( const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, + const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, const c10::optional descale_dO, + const c10::optional descale_dP, const c10::optional scale_S, const c10::optional scale_dP, const c10::optional scale_dQKV, @@ -119,6 +125,7 @@ std::vector fused_attn_fwd( const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional scale_S, const c10::optional scale_O, c10::optional amax_S, @@ -141,11 +148,13 @@ std::vector fused_attn_bwd( const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, + const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, const c10::optional descale_dO, + const c10::optional descale_dP, const c10::optional scale_S, const c10::optional scale_dP, const c10::optional scale_dQKV, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 0a84ea3089..cc747655c4 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -97,6 +97,7 @@ std::vector fused_attn_fwd_qkvpacked( const at::Tensor QKV, const transformer_engine::DType qkv_type, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional scale_S, const c10::optional scale_O, c10::optional amax_S, @@ -126,22 +127,24 @@ std::vector fused_attn_fwd_qkvpacked( // FP8 auto h = q_shape[q_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0)) { + if (set_zero + && ((h * d) % block_size == 0) + && (nvte_get_qkv_format(qkv_layout) == 
NVTE_QKV_Format::NVTE_THD)) { mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { O.fill_(0); } - if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value()) - || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O"; + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) + || (!scale_S.has_value()) || (!scale_O.has_value()) + || (!amax_S.has_value()) || (!amax_O.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); } te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - at::Tensor descale_S = torch::empty_like(scale_S.value()); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.data_ptr()); + scale_S.value().data_ptr(), descale_S.value().data_ptr()); te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { @@ -261,11 +264,13 @@ std::vector fused_attn_bwd_qkvpacked( const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, + const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, const c10::optional descale_dO, + const c10::optional descale_dP, const c10::optional scale_S, const c10::optional scale_dP, const c10::optional scale_dQKV, @@ -284,26 +289,29 @@ std::vector fused_attn_bwd_qkvpacked( auto h = q_shape[q_shape.size() - 2]; // create output tensor dQKV - at::Tensor dQKV = torch::empty_like(QKV); - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); + at::Tensor dQKV = torch::empty_like(QKV, options); // construct NVTE tensors TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0)) { + if (set_zero + && ((h * d) % block_size == 0) + && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { dQKV.fill_(0); } if ((!descale_QKV.has_value()) || (!descale_S.has_value()) - || (!descale_O.has_value()) || (!descale_dO.has_value()) - || (!scale_S.has_value()) || (!scale_dP.has_value()) - || (!scale_dQKV.has_value()) - || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, "; - err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV"); + || (!descale_O.has_value()) || (!descale_dO.has_value()) + || (!descale_dP.has_value()) || (!scale_S.has_value()) + || (!scale_dP.has_value()) || (!scale_dQKV.has_value()) + || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; + err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); + err_tensors = err_tensors + std::string("amax_dP and 
amax_dQKV "); NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); } te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, @@ -311,14 +319,13 @@ std::vector fused_attn_bwd_qkvpacked( te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); + dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr()); - at::Tensor descale_dP = torch::empty_like(scale_dP.value()); te_dP = makeTransformerEngineTensor(nullptr, {0}, - DType::kFloat32, amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), - descale_dP.data_ptr()); - te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, qkv_type, + DType::kFloat32, amax_dP.value().data_ptr(), + scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); + te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, dqkv_type, amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { // BF16 or FP16 @@ -327,13 +334,13 @@ std::vector fused_attn_bwd_qkvpacked( te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); + dqkv_type, nullptr, nullptr, nullptr); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, - qkv_type, nullptr, nullptr, nullptr); + dqkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); } @@ -433,6 +440,7 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor KV, const transformer_engine::DType qkv_type, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional scale_S, const c10::optional scale_O, c10::optional amax_S, @@ -458,24 +466,26 @@ std::vector fused_attn_fwd_kvpacked( // FP8 auto h = q_shape[q_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0)) { + if (set_zero + && ((h * d) % block_size == 0) + && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { O.fill_(0); } - if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value()) - || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O"; + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) + || (!scale_S.has_value()) || (!scale_O.has_value()) + || (!amax_S.has_value()) || (!amax_O.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - at::Tensor descale_S = torch::empty_like(scale_S.value()); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.data_ptr()); + scale_S.value().data_ptr(), descale_S.value().data_ptr()); te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { @@ -608,11 +618,13 @@ std::vector fused_attn_bwd_kvpacked( const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, + const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, const c10::optional descale_dO, + const c10::optional descale_dP, const c10::optional scale_S, const c10::optional scale_dP, const c10::optional scale_dQKV, @@ -635,15 +647,18 @@ std::vector fused_attn_bwd_kvpacked( auto d = q_shape[q_shape.size() - 1]; // create output tensors dQ and dKV - at::Tensor dQ = torch::empty_like(Q); - at::Tensor dKV = torch::empty_like(KV); - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); + at::Tensor dQ = torch::empty_like(Q, options); + at::Tensor dKV = torch::empty_like(KV, options); // construct NVTE tensors TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 - if (set_zero && ((h_q * d)% block_size == 0) && ((h_kv * d)% block_size == 0)) { + if (set_zero + && ((h_q * d)% block_size == 0) + && ((h_kv * d)% block_size == 0) + && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { @@ -651,12 +666,13 @@ std::vector fused_attn_bwd_kvpacked( dKV.fill_(0); } if ((!descale_QKV.has_value()) || (!descale_S.has_value()) - || (!descale_O.has_value()) || (!descale_dO.has_value()) - || (!scale_S.has_value()) || (!scale_dP.has_value()) - || (!scale_dQKV.has_value()) - || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, "; - err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV"); + || (!descale_O.has_value()) || (!descale_dO.has_value()) + || (!descale_dP.has_value()) || (!scale_S.has_value()) + || (!scale_dP.has_value()) || (!scale_dQKV.has_value()) + || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; + err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); + err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, @@ -666,16 +682,15 @@ std::vector fused_attn_bwd_kvpacked( te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); + dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr()); - at::Tensor descale_dP = torch::empty_like(scale_dP.value()); te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), - descale_dP.data_ptr()); - te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, qkv_type, + descale_dP.value().data_ptr()); + te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); - te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, qkv_type, + te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, dqkv_type, amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { // BF16 or FP16 @@ -686,15 +701,15 @@ std::vector fused_attn_bwd_kvpacked( te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); + dqkv_type, nullptr, nullptr, nullptr); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); + dqkv_type, nullptr, nullptr, nullptr); te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, - qkv_type, nullptr, nullptr, nullptr); + dqkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); } @@ -806,6 +821,7 @@ std::vector fused_attn_fwd( const at::Tensor V, const transformer_engine::DType qkv_type, const c10::optional descale_QKV, + const c10::optional descale_S, const c10::optional scale_S, const c10::optional scale_O, c10::optional amax_S, @@ -832,14 +848,17 @@ std::vector fused_attn_fwd( // FP8 auto h = q_shape[q_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0)) { + if (set_zero + && ((h * d) % block_size == 0) + && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { O.fill_(0); } - if ((!descale_QKV.has_value()) || (!scale_S.has_value()) || (!scale_O.has_value()) - || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, scale_S, scale_O, amax_S and amax_O"; + if ((!descale_QKV.has_value()) || (!descale_S.has_value()) + || (!scale_S.has_value()) || (!scale_O.has_value()) + || (!amax_S.has_value()) || (!amax_O.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. 
\n")); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, @@ -848,10 +867,9 @@ std::vector fused_attn_fwd( qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, descale_QKV.value().data_ptr()); - at::Tensor descale_S = torch::empty_like(scale_S.value()); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(), - scale_S.value().data_ptr(), descale_S.data_ptr()); + scale_S.value().data_ptr(), descale_S.value().data_ptr()); te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(), scale_O.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { @@ -990,11 +1008,13 @@ std::vector fused_attn_bwd( const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, + const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, const c10::optional descale_dO, + const c10::optional descale_dP, const c10::optional scale_S, const c10::optional scale_dP, const c10::optional scale_dQKV, @@ -1011,7 +1031,7 @@ std::vector fused_attn_bwd( auto h_q = q_shape[q_shape.size() - 2]; auto h_kv = k_shape[k_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); + auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); at::Tensor dQ; at::Tensor dK; @@ -1046,7 +1066,7 @@ std::vector fused_attn_bwd( torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - dQ = torch::empty_like(Q); + dQ = torch::empty_like(Q, options); tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); @@ -1058,7 +1078,7 @@ std::vector fused_attn_bwd( torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - dQ = torch::empty_like(Q); + dQ = torch::empty_like(Q, options); tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); @@ -1068,9 +1088,9 @@ std::vector fused_attn_bwd( torch::indexing::Slice(0, torch::indexing::None, 1)}).squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - dQ = torch::empty_like(Q); - dK = torch::empty_like(K); - dV = torch::empty_like(V); + dQ = torch::empty_like(Q, options); + dK = torch::empty_like(K, options); + dV = torch::empty_like(V, options); break; default: NVTE_ERROR("QKV layout not supported!"); @@ -1085,7 +1105,8 @@ std::vector fused_attn_bwd( && ((h_kv * d) % block_size == 0) && dQ.is_contiguous() && dK.is_contiguous() - && dV.is_contiguous()) { + && dV.is_contiguous() + && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); mha_fill(dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); @@ -1095,12 +1116,13 @@ std::vector fused_attn_bwd( dV.fill_(0); } if ((!descale_QKV.has_value()) || 
(!descale_S.has_value()) - || (!descale_O.has_value()) || (!descale_dO.has_value()) - || (!scale_S.has_value()) || (!scale_dP.has_value()) - || (!scale_dQKV.has_value()) - || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, scale_S, scale_dP, "; - err_tensors = err_tensors + std::string("scale_dQKV, amax_dP and amax_dQKV"); + || (!descale_O.has_value()) || (!descale_dO.has_value()) + || (!descale_dP.has_value()) || (!scale_S.has_value()) + || (!scale_dP.has_value()) || (!scale_dQKV.has_value()) + || (!amax_dP.has_value()) || (!amax_dQKV.has_value())) { + std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; + err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); + err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); } te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, @@ -1112,18 +1134,17 @@ std::vector fused_attn_bwd( te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, descale_O.value().data_ptr()); te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); + dqkv_type, nullptr, nullptr, descale_dO.value().data_ptr()); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, scale_S.value().data_ptr(), descale_S.value().data_ptr()); - at::Tensor descale_dP = torch::empty_like(scale_dP.value()); te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), scale_dP.value().data_ptr(), - descale_dP.data_ptr()); - te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, qkv_type, + descale_dP.value().data_ptr()); + te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); - te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape, qkv_type, + te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape, dqkv_type, amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); - te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape, qkv_type, + te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, amax_dQKV.value().data_ptr(), scale_dQKV.value().data_ptr(), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { // BF16 or FP16 @@ -1136,17 +1157,17 @@ std::vector fused_attn_bwd( te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); + dqkv_type, nullptr, nullptr, nullptr); te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); te_dQ = makeTransformerEngineTensor(dQ.data_ptr(), q_shape, - qkv_type, nullptr, nullptr, nullptr); + dqkv_type, nullptr, nullptr, nullptr); te_dK = makeTransformerEngineTensor(dK.data_ptr(), k_shape, - qkv_type, nullptr, nullptr, nullptr); + dqkv_type, nullptr, nullptr, nullptr); te_dV = makeTransformerEngineTensor(dV.data_ptr(), v_shape, - qkv_type, nullptr, nullptr, nullptr); + dqkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. 
\n"); } diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 9923d24a42..d9a5138e27 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -4,7 +4,7 @@ """Tensor class with FP8 data""" from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple, Union import torch from torch.utils._pytree import tree_map @@ -233,6 +233,87 @@ def forward( def backward(ctx, grad): return grad.to(ctx.input_dtype), None +class _ViewFunc(torch.autograd.Function): + """View function + + View the Float8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + shape: Tuple[int] = None, + ) -> torch.Tensor: + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Construct new tensor if shape is provided + if isinstance(tensor, Float8Tensor): + return Float8Tensor.make_like( + tensor, + data=tensor._data.view(*shape), + ) + return tensor.view(*shape) + + @staticmethod + def backward(ctx, + grad: torch.Tensor, + ) -> Tuple[[torch.Tensor, None], ...]: + + if isinstance(grad, Float8Tensor): + dgrad = Float8Tensor.make_like( + grad, + data=grad._data.view(ctx.shape), + ) + return dgrad, None + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the Float8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + shape: Tuple[int] = None, + ) -> torch.Tensor: + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Construct new tensor if shape is provided + if isinstance(tensor, Float8Tensor): + return Float8Tensor.make_like( + tensor, + data=tensor._data.reshape(*shape), + ) + return tensor.reshape(*shape) + + @staticmethod + def backward(ctx, + grad: torch.Tensor, + ) -> Tuple[Union[torch.Tensor, None], ...]: + + if isinstance(grad, Float8Tensor): + dgrad = Float8Tensor.make_like( + grad, + data=grad._data.reshape(ctx.shape), + ) + return dgrad, None + return grad.reshape(ctx.shape), None + class Float8Tensor(torch.Tensor): """Experimental tensor class with FP8 data @@ -453,6 +534,12 @@ def cpu(self) -> torch.Tensor: def clone(self) -> Float8Tensor: return _IdentityFunc.apply(self, {"data": self._data.detach().clone()}) + def view(self, *shape: Tuple[int]) -> Float8Tensor: + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> Float8Tensor: + return _ReshapeFunc.apply(self, shape) + def expand_as(self, other: torch.Tensor): if other is self: # Note: expand_as is hackily used to create dummy autograd nodes diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index d06443efb6..b871169a11 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -202,6 +202,11 @@ def add_fp8_tensors_to_global_buffer( # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights # in an autocasted region and cross reference them in `float8_tensor.py` # to perform the forward amax reduction. + fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) + if fp8_meta_tensor_key not in fp8_meta: + # Handles non-parameter FP8 modules, e.g. DPA. 
+ continue + if forward and fp8_weights is not None: autocast_key = cls.get_unique_autocast_key( fp8_meta["recipe"], fp8_meta["fp8_group"]) @@ -217,7 +222,6 @@ def add_fp8_tensors_to_global_buffer( key = cls.get_key_in_buffer( forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"]) - fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) if key not in cls.global_amax_buffer: cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 7e0cf5c106..00f5c2216d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -268,6 +268,9 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",) for meta_key in fp8_meta_tensor_keys: + if meta_key not in self.fp8_meta: + # Handles non-parameter FP8 modules, e.g. DPA. + continue curr_len = self.fp8_meta[meta_key].amax_history.shape[0] if length == curr_len: continue @@ -568,6 +571,7 @@ def prepare_forward( inp: torch.Tensor, is_first_microbatch: Union[bool, None], num_gemms: int = 1, + allow_non_contiguous: bool = False, ) -> Generator[torch.Tensor, None, None]: """Checks and prep for FWD. The context manager is needed because there isn't a way for a module to know @@ -610,7 +614,10 @@ def prepare_forward( FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"): - yield inp.contiguous() + if not allow_non_contiguous: + yield inp.contiguous() + else: + yield inp if self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) @@ -645,8 +652,11 @@ def grad_output_preprocess( R4: bias gradient on R1. """ - grad_output = grad_output.contiguous() - grad_output_mat = grad_output.view((-1, grad_output.shape[-1])) + if isinstance(grad_output, Float8Tensor): + grad_output._data = grad_output._data.contiguous() + else: + grad_output = grad_output.contiguous() + grad_output_mat = grad_output.view(-1, grad_output.shape[-1]) gather_grad_output = row_parallel_mode and ctx.sequence_parallel # No-FP8 case: bgrad is fused with wgrad for this case. 
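The changes above let FP8 tensors flow into a module without an intermediate high-precision or contiguous copy. A hedged caller-side sketch of the new allow_non_contiguous flag (it mirrors the Linear.forward hunk later in this series; names are taken from the patch):

    # FP8 inputs keep their raw uint8 storage; only plain tensors are forced contiguous.
    with self.prepare_forward(inp,
                              is_first_microbatch,
                              allow_non_contiguous=isinstance(inp, Float8Tensor)) as inp:
        ...  # forward body unchanged
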
@@ -684,16 +694,22 @@ def grad_output_preprocess( grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0) else: grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8) - cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - out=grad_output_c, - ) + if not isinstance(grad_output_mat, Float8Tensor): + cast_to_fp8( + grad_output_mat, + ctx.fp8_meta["scaling_bwd"], + tex.FP8BwdTensors.GRAD_OUTPUT1, + fp8_dtype_backward, + out=grad_output_c, + ) + else: + grad_output_c = grad_ouput_mat # pylint: disable=undefined-variable if not ctx.ub_overlap_ag: grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group) - grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) + if not isinstance(grad_output_c, Float8Tensor): + grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) + else: + grad_output_t = grad_output_c.transpose_2d() else: grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1) grad_output_t = None @@ -702,28 +718,38 @@ def grad_output_preprocess( # FP8 case without gather: cast, transpose, bgrad fused if ctx.use_bias: + grad_output_mat_no_fp8 = grad_output_mat + if isinstance(grad_output_mat, Float8Tensor): + grad_output_mat_no_fp8 = grad_output_mat.from_float8(grad_output_mat.dtype) grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused( - grad_output_mat, + grad_output_mat_no_fp8, ctx.fp8_meta["scaling_bwd"], tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ) else: if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - grad_output_c, grad_output_t = fp8_cast_transpose_fused( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) + if isinstance(grad_output_mat, Float8Tensor): + grad_output_c = grad_output_mat + grad_output_t = grad_output_c.transpose_2d() + else: + grad_output_c, grad_output_t = fp8_cast_transpose_fused( + grad_output_mat, + ctx.fp8_meta["scaling_bwd"], + tex.FP8BwdTensors.GRAD_OUTPUT1, + fp8_dtype_backward, + ) else: grad_output_t = None - grad_output_c = cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) + if not isinstance(grad_output_mat, Float8Tensor): + grad_output_c = cast_to_fp8( + grad_output_mat, + ctx.fp8_meta["scaling_bwd"], + tex.FP8BwdTensors.GRAD_OUTPUT1, + fp8_dtype_backward, + ) + else: + grad_output_c = grad_output_mat grad_bias = None return grad_output_mat, grad_output_c, grad_output_t, grad_bias diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 5df4950276..bc4c29d308 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -43,6 +43,7 @@ from ..graph import is_graph_capturing from ._common import _apply_normalization, _noop_cat from ..float8_tensor import Float8Tensor +_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) __all__ = ["LayerNormLinear"] @@ -190,6 +191,9 @@ def forward( ln_out = ln_out_total if fp8: + if _NVTE_DEBUG: + print('[LayerNormLinear]: using FP8 forward') + bias_dtype = ( torch.bfloat16 if activation_dtype == torch.float32 @@ -230,6 +234,15 @@ def forward( ) weight_t_fp8 = None + if fp8_meta["recipe"].fp8_mha: + out_index, meta_tensor, output_te_dtype, output_dtype = ( + tex.FP8FwdTensors.GEMM1_OUTPUT, + fp8_meta["scaling_fwd"], + fp8_dtype_forward, + torch.uint8) + else: + out_index, meta_tensor, output_te_dtype, 
output_dtype = ( + None, None, None, activation_dtype) out, _ = tex.fp8_gemm( weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, @@ -239,7 +252,7 @@ def forward( fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, - activation_dtype, + output_dtype, get_workspace(), bias=bias, use_bias=use_bias, @@ -247,8 +260,22 @@ def forward( ub_algo=ub_algo if ub_overlap_ag else None, ub=ub_obj_lnout if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None, + out_index=out_index, + fp8_meta_tensor=meta_tensor, + D_dtype=output_te_dtype, ) + if output_dtype == torch.uint8: + out = Float8Tensor(data=out, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, + fp8_dtype=fp8_dtype_forward, + dtype=activation_dtype, + ) else: + if _NVTE_DEBUG: + print('[LayerNormLinear]: using non-FP8 forward') + # Cast for native AMP weight = cast_if_needed(weight, activation_dtype) bias = cast_if_needed(bias, activation_dtype) if use_bias else bias @@ -338,7 +365,6 @@ def forward( # [*, in_features] -> [*, out_features] except first dimension changes for SP out = out.view(-1, *inp.shape[1:-1], out.shape[-1]) - if return_layernorm_output: if return_layernorm_output_gathered: shape = list(inp.shape) @@ -352,6 +378,10 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: + if isinstance(grad_outputs[0], Float8Tensor): + ctx.fp8_meta["scaling_bwd"].scale_inv[ + tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[0]._scale_inv + with torch.cuda.nvtx.range("_LayerNormLinear_backward"): ( inputmat, @@ -465,6 +495,9 @@ def backward( ub_obj = None if ctx.fp8: + if _NVTE_DEBUG: + print('[LayerNormLinear]: using FP8 backward') + fp8_dtype_forward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=True ) @@ -486,7 +519,8 @@ def backward( fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, - grad_output_c, + grad_output_c._data + if isinstance(grad_output_c, Float8Tensor) else grad_output_c, ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, @@ -503,6 +537,9 @@ def backward( ) clear_tensor_data(grad_output_c) else: + if _NVTE_DEBUG: + print('[LayerNormLinear]: using non-FP8 backward') + # DGRAD: Evaluated unconditionally to feed into Linear backward _, _, _ = tex.gemm( weight, @@ -551,7 +588,8 @@ def backward( fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, - grad_output_t, + grad_output_t._data + if isinstance(grad_output_t, Float8Tensor) else grad_output_t, ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3c055270b0..8adaab557f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,6 +3,7 @@ # See LICENSE for license information. 
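Both this file and layernorm_linear.py gate their new path prints behind the same environment variable. A hedged usage sketch (the flag must be set before these modules are imported, since _NVTE_DEBUG is read once at import time):

    import os
    os.environ["NVTE_DEBUG"] = "1"       # read once, at module import time
    import transformer_engine.pytorch as te
    # subsequent forward/backward calls print e.g. "[Linear]: using FP8 forward"
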
"""Linear API""" +import os from typing import Union, Optional, Callable, Tuple, List, Dict, Any import torch @@ -46,6 +47,8 @@ from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor +_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) + __all__ = ["Linear"] @@ -81,11 +84,16 @@ def forward( ub_overlap_rs: bool, ub_overlap_ag: bool, ub_name: str, + is_first_module_in_mha: bool, ) -> torch.Tensor: + is_input_fp8 = isinstance(inp, Float8Tensor) + if is_input_fp8: + fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0] + # Make sure input dimensions are compatible in_features = weight.shape[-1] assert inp.shape[-1] == in_features, "GEMM not possible" - inputmat = inp.view((-1, in_features)) + inputmat = inp.view(-1, in_features) if fp8: assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(weight) @@ -103,29 +111,40 @@ def forward( inputmat = cast_if_needed(inputmat, activation_dtype) inputmat_t = None inputmat_no_fp8 = inputmat + if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and weight.requires_grad - and not sequence_parallel - ): - # FP8 input for forward, FP8 input transpose for backward wgrad - inputmat, inputmat_t = fp8_cast_transpose_fused( - inputmat, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) + if isinstance(inputmat, Float8Tensor): + if ( + not fp8_meta["recipe"].override_linear_precision.wgrad + and is_grad_enabled + and weight.requires_grad + and not sequence_parallel + ): + # FP8 input for forward, FP8 input transpose for backward wgrad + inputmat_t = inputmat.transpose_2d() else: - # FP8 input for forward - inputmat = cast_to_fp8( - inputmat, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) + if ( + not fp8_meta["recipe"].override_linear_precision.wgrad + and is_grad_enabled + and weight.requires_grad + and not sequence_parallel + ): + # FP8 input for forward, FP8 input transpose for backward wgrad + inputmat, inputmat_t = fp8_cast_transpose_fused( + inputmat, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + ) + else: + # FP8 input for forward + inputmat = cast_to_fp8( + inputmat, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + ) # Column Parallel Linear if parallel_mode == "column" and sequence_parallel: @@ -134,6 +153,9 @@ def forward( inputmat_total = inputmat if fp8: + if _NVTE_DEBUG: + print('[Linear]: using FP8 forward') + bias_dtype = ( torch.bfloat16 if activation_dtype == torch.float32 @@ -174,8 +196,16 @@ def forward( ) weight_t_fp8 = None - proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( - None, None, None, activation_dtype) + if is_first_module_in_mha: + proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( + tex.FP8FwdTensors.GEMM1_OUTPUT, + fp8_meta["scaling_fwd"], + fp8_dtype_forward, + torch.uint8) + else: + proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( + None, None, None, activation_dtype) + if ub_overlap_rs: ub_obj_projout = get_ub(ub_name+"_fprop") out = ub_obj_projout.get_ubuf_output(1) @@ -202,14 +232,15 @@ def forward( else: dim_size = list(inputmat_total.size()) dim_size[1] = weight.size(0) - out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device) _ = fp8_gemm( weight_fp8._data, 
fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, - inputmat_total, + inputmat_total._data + if isinstance(inputmat_total, Float8Tensor) else inputmat_total, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, @@ -226,7 +257,18 @@ def forward( fp8_meta_tensor = meta_tensor, D_dtype = proj_out_tetype, ) + if is_first_module_in_mha: + out = Float8Tensor(data=out, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, + fp8_dtype=fp8_dtype_forward, + dtype=activation_dtype, + ) else: + if _NVTE_DEBUG: + print('[Linear]: using non-FP8 forward') + # Cast for native AMP weight = cast_if_needed(weight, activation_dtype) bias = cast_if_needed(bias, activation_dtype) if use_bias else bias @@ -319,6 +361,7 @@ def forward( ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad + ctx.is_input_fp8 = is_input_fp8 ctx.primary_weights_in_fp8 = primary_weights_in_fp8 ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() @@ -338,6 +381,10 @@ def forward( def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: + if isinstance(grad_output[0], Float8Tensor): + ctx.fp8_meta["scaling_bwd"].scale_inv[ + tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_output._scale_inv + with torch.cuda.nvtx.range("_Linear_backward"): ( inputmat, @@ -412,6 +459,18 @@ def backward( if ctx.requires_dgrad: if ctx.fp8: + if _NVTE_DEBUG: + print('[Linear]: using FP8 backward') + + if ctx.is_input_fp8: + out_index, meta_tensor, output_te_dtype, output_dtype = ( + tex.FP8BwdTensors.GRAD_INPUT1, + ctx.fp8_meta["scaling_bwd"], + fp8_dtype_backward, + torch.uint8) + else: + out_index, meta_tensor, output_te_dtype, output_dtype = ( + None, None, None, ctx.activation_dtype) dgrad, _ = fp8_gemm( weight_t_fp8, fwd_scale_inverses, @@ -421,13 +480,27 @@ def backward( ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, - ctx.activation_dtype, + output_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, ub_algo=ub_algo if ctx.ub_overlap_ag else None, ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, + out_index=out_index, + fp8_meta_tensor=meta_tensor, + D_dtype=output_te_dtype, ) + if output_dtype == torch.uint8: + dgrad = Float8Tensor(data=dgrad, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=tex.FP8BwdTensors.GRAD_INPUT1, + fp8_dtype=fp8_dtype_backward, + dtype=ctx.activation_dtype, + ) else: + if _NVTE_DEBUG: + print('[Linear]: using non-FP8 backward') + dgrad, _, _ = gemm( weight, grad_output, @@ -455,11 +528,19 @@ def backward( # WGRAD if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: if ctx.ub_overlap_ag: - grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) + if isinstance(grad_output_c, Float8Tensor): + grad_output_t = grad_output_c.transpose_2d() + else: + grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) if inputmat_t_total is None: - inputmat_t_total = tex.fp8_transpose(inputmat_total, fp8_dtype_backward) + if isinstance(inputmat_total, Float8Tensor): + inputmat_t_total = inputmat_total.transpose_2d() + else: + inputmat_t_total = tex.fp8_transpose( + inputmat_total, fp8_dtype_backward) wgrad, _ = fp8_gemm( - inputmat_t_total, + inputmat_t_total._data + if isinstance(inputmat_t_total, Float8Tensor) else inputmat_t_total, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, @@ -558,6 +639,7 @@ def backward( None, None, None, + None, 
) @@ -850,6 +932,7 @@ def forward( self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, + is_first_module_in_mha: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. @@ -871,16 +954,22 @@ def forward( * it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced) + is_first_module_in_mha: Optional[bool], default = False + Whether to output in FP8. By default, Linear outputs in inp.dtype. """ skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() if skip_fp8_weight_update is not None: is_first_microbatch = False - with self.prepare_forward(inp, is_first_microbatch) as inp: + with self.prepare_forward(inp, + is_first_microbatch, + allow_non_contiguous=isinstance(inp,Float8Tensor)) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." + is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha + # Get concatenated weight and bias tensors if len(self.parameter_split_sizes) == 1: weight_tensor = getattr(self, self.weight_names[0]) @@ -939,6 +1028,7 @@ def forward( self.ub_overlap_rs, self.ub_overlap_ag, self.ub_name, + is_first_module_in_mha, ) out = linear_fn(*args) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 09eb433957..df750ab1ae 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -15,10 +15,15 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: Must be used carefully. """ + from .float8_tensor import Float8Tensor for t in tensors: if t is not None: - t.data = torch.Tensor() - del t + if isinstance(t, Float8Tensor): + t._data.data = torch.Tensor() + del t + else: + t.data = torch.Tensor() + del t def get_device_compute_capability() -> Tuple[int, int]: From 2921464c06d37b058f36f1db8151e080727f6163 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 15 Apr 2024 22:14:25 -0700 Subject: [PATCH 019/244] Changed VERSION to 1.7.0dev Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 65babdef47..2ac2d70206 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.6.0.dev0 +1.7.0.dev0 From ea9f6be9c018bf98e68d9a0469b815f0ef022e53 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 16 Apr 2024 08:58:16 -0700 Subject: [PATCH 020/244] [PyTorch] Use __torch_function__ as a class method (#783) Use torch function as a class method Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/float8_tensor.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index d9a5138e27..f93d6ae5cb 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -263,7 +263,7 @@ def forward( @staticmethod def backward(ctx, grad: torch.Tensor, - ) -> Tuple[[torch.Tensor, None], ...]: + ) -> Tuple[Union[torch.Tensor, None], ...]: if isinstance(grad, Float8Tensor): dgrad = Float8Tensor.make_like( @@ -853,5 +853,8 @@ def _set_data(self, tensor: torch.Tensor) -> None: _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) _scale_inv = 
property(**_make_fp8_attr_property_funcs("scale_inv")) - # Do not force the Float8Tensor type on the returned tensor - __torch_function__ = torch._C._disabled_torch_function_impl + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + return torch._C._disabled_torch_function_impl(func, types, args, kwargs) From 324bafb5150ae9986516d0f08a0f9f2990c58f03 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 16 Apr 2024 13:12:50 -0500 Subject: [PATCH 021/244] [PyTorch] TE checkpoint pass-through logic fix (#782) * changed TE checkpoint passthrough logic to also recursively look for TE submodules Signed-off-by: Alp Dener * simplified search for TE modules in the checkpointed network Signed-off-by: Alp Dener --------- Signed-off-by: Alp Dener Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/distributed.py | 27 +++++++++++------------ 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 8d499d88d6..08da93587d 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -498,15 +498,15 @@ def get_activation_recompute_contexts(): return forward_ctx, recompute_ctx -def _is_te_module(module): +def has_te_modules(network): """ - Check if given module is a Transformer Engine module that requires the TE checkpoint - implementation for activation recompute. + Check if there are any Transformer Engine modules in the network. """ from .module import LayerNorm, RMSNorm from .module.base import TransformerEngineBaseModule from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention from .transformer import TransformerLayer + te_classes_list = [ LayerNorm, RMSNorm, @@ -516,12 +516,13 @@ def _is_te_module(module): MultiheadAttention, TransformerLayer, ] - is_te_module = False - for te_class in te_classes_list: - if isinstance(module, te_class): - is_te_module = True - break - return is_te_module + + if isinstance(network, torch.nn.Module): + for module in network.modules(): + if any(isinstance(module, te_class) for te_class in te_classes_list): + return True + + return False def checkpoint( @@ -584,14 +585,12 @@ def checkpoint( distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3] # pylint: disable=unbalanced-tuple-unpacking args = args[3:] - # Trigger the native PyTorch checkpoint if: - # 1. `function` is a `torch.nn.Module` - # AND - # 2. `function` is NOT a TE module + # Trigger the native PyTorch checkpoint if the function is not or does not contain a + # Transformer Engine module. 
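A hedged usage sketch of the recursive check described in the comment above (container and layer sizes are illustrative): a plain torch.nn wrapper that merely contains a TE layer is now routed through TE's checkpoint implementation instead of falling back to torch.utils.checkpoint.

    import torch
    import transformer_engine.pytorch as te

    block = torch.nn.Sequential(torch.nn.Dropout(0.1), te.Linear(128, 128))
    assert has_te_modules(block)                 # TE submodule found via block.modules()
    assert not has_te_modules(torch.nn.ReLU())   # no TE content -> native PyTorch checkpoint
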
context_fn = kwargs.pop("context_fn", noop_context_fn) determinism_check = kwargs.pop("determinism_check", "default") debug = kwargs.pop("debug", False) - if isinstance(function, torch.nn.Module) and not _is_te_module(function): + if not has_te_modules(function): return torch.utils.checkpoint.checkpoint( function, *args, From f998fee1f304cc7eb7a1abebea1a83c20529e137 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 16 Apr 2024 13:49:15 -0700 Subject: [PATCH 022/244] Add new users to TE CI Signed-off-by: Przemek Tredak Signed-off-by: Pawel Gadzinski --- .github/workflows/trigger-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 59b07429cc..6ab838f461 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -16,7 +16,7 @@ jobs: # This job only runs for pull request comments if: | - contains( 'ptrendx,ksivaman,schetlur-nv,timmoon10,zlsh80826,mingxu1067,cyanguwa,nzmora-nvidia,galagam,nouiz,denera,sudhakarsingh27,Oleg-Goncharov,', format('{0},', github.actor)) && + contains( 'ptrendx,ksivaman,schetlur-nv,timmoon10,zlsh80826,mingxu1067,cyanguwa,nzmora-nvidia,galagam,nouiz,denera,sudhakarsingh27,Oleg-Goncharov,phu0ngng,', format('{0},', github.actor)) && startsWith(github.event.comment.body, '/te-ci') steps: - name: Check if comment is issued by authorized person From a27264bc0bc6e235a7a43e097010798e1f6ad6f2 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Tue, 16 Apr 2024 19:26:54 -0400 Subject: [PATCH 023/244] Support Low Rank Adaptation (LoRA). (#745) Signed-off-by: Pawel Gadzinski --- tests/jax/test_functions.py | 68 ++++++++ tests/jax/test_praxis_layers.py | 44 +++++ transformer_engine/jax/flax/module.py | 169 ++++++++++++++++++- transformer_engine/jax/flax/transformer.py | 103 +++++++++++ transformer_engine/jax/praxis/module.py | 18 ++ transformer_engine/jax/praxis/transformer.py | 12 ++ 6 files changed, 412 insertions(+), 2 deletions(-) create mode 100644 tests/jax/test_functions.py diff --git a/tests/jax/test_functions.py b/tests/jax/test_functions.py new file mode 100644 index 0000000000..aaa6be77ac --- /dev/null +++ b/tests/jax/test_functions.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
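The new tests below exercise the LoRA helper added to the JAX modules. A minimal sketch of the computation they check, written for the plain 2-D case (shapes are illustrative; _apply_low_rank_adaptation generalizes this contraction with an einsum over arbitrary axes):

    # x: [batch, hidden_in], lora_a: [hidden_in, rank], lora_b: [rank, hidden_out]
    scaling = alpha / rank if alpha is not None else 1.0
    lora_out = scaling * (x @ lora_a @ lora_b)   # added on top of the frozen kernel's output
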
+ +import pytest + +import jax +import jax.numpy as jnp + +from utils import assert_allclose +from transformer_engine.jax.flax.module import _apply_low_rank_adaptation +from transformer_engine.jax.flax.module import _normalize_axes +from transformer_engine.jax.flax.transformer import LoRAScope +from transformer_engine.jax.flax.transformer import _canonicalize_lora_scope + + +class TestLoRA: + + def reference(x, la, lb, pattern, scale): + out = jnp.einsum(pattern, x, la, lb) + return out * scale + + @pytest.mark.parametrize('shape', [(32, 1024), (32, 128, 1024)]) + @pytest.mark.parametrize('dtype', [jnp.float32, jnp.bfloat16]) + @pytest.mark.parametrize('axis_features_pattern', [((-1,), (1024,), '...h,hr,rk->...k'), + ((-1,), (3, 1024), '...h,hkr,krz->...kz')]) + @pytest.mark.parametrize('rank', [32, 16]) + @pytest.mark.parametrize('alpha', [None, 4, 8]) + def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha): + axis, features, pattern = axis_features_pattern + axis = _normalize_axes(axis, len(shape)) + shape_in_axis = tuple(shape[ax] for ax in axis) + + key = jax.random.key(1124) + key, x_key = jax.random.split(key) + x = jax.random.normal(x_key, shape, dtype) + + key, la_key = jax.random.split(key) + la_shape = (*shape_in_axis, *features[:-1], rank) + la = jax.random.normal(la_key, la_shape, dtype) + + key, lb_key = jax.random.split(key) + lb_shape = (*features[:-1], rank, features[-1]) + lb = jax.random.normal(lb_key, lb_shape, dtype) + + out_target = _apply_low_rank_adaptation(x, axis, features, la, lb, alpha) + scale_ref = alpha / rank if alpha is not None else 1.0 + out_ref = TestLoRA.reference(x, la, lb, pattern, scale_ref) + + assert_allclose(out_target, out_ref, dtype=dtype) + + @pytest.mark.parametrize('scope_ref_assert', + [('none', LoRAScope(False, False, False), False), + ('all', LoRAScope(True, True, True), False), + ('qkv_proj', LoRAScope(True, False, False), False), + ('output_proj', LoRAScope(False, True, False), False), + ('mlp', LoRAScope(False, False, True), False), + ('exclude_qkv_proj', LoRAScope(False, True, True), False), + ('exclude_output_proj', LoRAScope(True, False, True), False), + ('exclude_mlp', LoRAScope(True, True, False), False), + ('messing_up', LoRAScope(), True)]) + def test_lora_scope_generator(self, scope_ref_assert): + scope, reference, need_assert = scope_ref_assert + try: + lora_scope = _canonicalize_lora_scope(scope) + assert lora_scope == reference + except AssertionError as ae: + assert need_assert, f"{ae.args}" diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index 43581f1015..dce0263ac7 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -784,6 +784,7 @@ class MultiHeadAttnAttr: NUM_GQA_GROUPS = 'num_gqa_groups' ENABLE_ROPE = 'enable_rotary_pos_emb' ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method' + LORA_SCOPE = 'low_rank_adaptation_scope' ATTRS = [{ USE_BIAS: True, LN_TYPE: 'layernorm', @@ -853,6 +854,22 @@ class MultiHeadAttnAttr: NUM_ATTN_HEADS: 8, NUM_GQA_GROUPS: 4, ATTN_MASK_TYPE: 'causal' + }, { + USE_BIAS: True, + LN_TYPE: 'layernorm', + ZERO_CEN: False, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: 'consecutive', + ATTN_MASK_TYPE: 'padding', + LORA_SCOPE: 'all' + }, { + USE_BIAS: True, + LN_TYPE: 'layernorm', + ZERO_CEN: False, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: 'consecutive', + ATTN_MASK_TYPE: 'causal', + LORA_SCOPE: 'all' }] @@ -883,6 +900,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): attn_mask_type = 
attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE] enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE] rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD] + low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none') fuse_qkv_params = True transpose_batch_sequence = True scale_attn_logits = False @@ -905,6 +923,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): attn_mask_type=attn_mask_type, enable_rotary_pos_emb=enable_rotary_pos_emb, rotary_pos_emb_group_method=rotary_pos_emb_group_method, + low_rank_adaptation_scope=low_rank_adaptation_scope, fuse_qkv_params=fuse_qkv_params, transpose_batch_sequence=transpose_batch_sequence, scale_attn_logits=scale_attn_logits, @@ -926,6 +945,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): attn_mask_type=attn_mask_type, enable_rotary_pos_emb=enable_rotary_pos_emb, rotary_pos_emb_group_method=rotary_pos_emb_group_method, + low_rank_adaptation_scope=low_rank_adaptation_scope, fuse_qkv_params=fuse_qkv_params, transpose_batch_sequence=transpose_batch_sequence, scale_attn_logits=scale_attn_logits, @@ -969,6 +989,7 @@ class TransformerLayerAttr: TRANSPOSE_BS = 'transpose_batch_sequence' ENABLE_ROPE = 'enable_rotary_pos_emb' ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method' + LORA_SCOPE = 'low_rank_adaptation_scope' ATTRS = [{ USE_BIAS: True, LN_TYPE: 'layernorm', @@ -1113,6 +1134,16 @@ class TransformerLayerAttr: ENABLE_ROPE: False, ROPE_GROUP_METHOD: 'consecutive', TRANSPOSE_BS: False + }, { + USE_BIAS: True, + LN_TYPE: 'layernorm', + ZERO_CEN: False, + ACTIVATION: ('gelu',), + LYR_TYPE: TransformerLayerType.ENCODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: 'consecutive', + TRANSPOSE_BS: False, + LORA_SCOPE: 'all' }, { USE_BIAS: True, LN_TYPE: 'layernorm', @@ -1185,6 +1216,16 @@ class TransformerLayerAttr: ENABLE_ROPE: True, ROPE_GROUP_METHOD: 'consecutive', TRANSPOSE_BS: False + }, { + USE_BIAS: True, + LN_TYPE: 'layernorm', + ZERO_CEN: False, + ACTIVATION: ('gelu',), + LYR_TYPE: TransformerLayerType.DECODER, + ENABLE_ROPE: False, + ROPE_GROUP_METHOD: 'consecutive', + TRANSPOSE_BS: False, + LORA_SCOPE: 'all' }] @@ -1219,6 +1260,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): layer_type = attrs[TransformerLayerAttr.LYR_TYPE] enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE] rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD] + low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none') enable_relative_embedding = True relative_embedding = pax_fiddle.Config(RelativePositionBiases, dtype=dtype, @@ -1257,6 +1299,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): enable_relative_embedding=enable_relative_embedding, enable_rotary_pos_emb=enable_rotary_pos_emb, rotary_pos_emb_group_method=rotary_pos_emb_group_method, + low_rank_adaptation_scope=low_rank_adaptation_scope, relative_embedding=relative_embedding, drop_path=drop_path, transpose_batch_sequence=transpose_batch_sequence) @@ -1282,6 +1325,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): rotary_pos_emb_group_method=rotary_pos_emb_group_method, enable_relative_embedding=enable_relative_embedding, relative_embedding=relative_embedding_flax_module, + low_rank_adaptation_scope=low_rank_adaptation_scope, drop_path=drop_path, transpose_batch_sequence=transpose_batch_sequence) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8ddc74ac2e..8ca8edcb0b 100644 --- a/transformer_engine/jax/flax/module.py +++ 
b/transformer_engine/jax/flax/module.py @@ -104,6 +104,31 @@ def _combine_biases(*masks: List[Array]): return mask +def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha): + """Low Rank Adaptation Implementation""" + + assert len(axis) <= 5 + hidden_in_names = 'ijklm'[:len(axis)] + assert len(features) <= 5 + hidden_out_names = 'nopqr'[:len(features)] + rank_name = 's' + + assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2] + rank = lora_a_kernel.shape[-1] + scaling = alpha / rank if alpha is not None else 1.0 + + x_einsum_express = f"...{hidden_in_names}" + lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}" + lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}" + output_einsum_express = f"...{hidden_out_names}" + final_einsum_express = f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}" \ + f"->{output_einsum_express}" + + output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel) + output = output * scaling + return output + + class Softmax(nn.Module): # pylint: disable=too-few-public-methods r""" Applies softmax over a mini-batch of inputs. @@ -355,6 +380,14 @@ class DenseGeneral(TransformerEngineBase): bias_axes: Tuple[str, ...], default = () The name of axes used to shard bias with a corresponding mesh, only used when :attr:`use_bias=True`. + enable_low_rank_adaptation: bool, default = False + Indicate whether to enable low rank adaptation for each linear layer. + low_rank_adaptation_dim: int, default = 32 + The dimension for low rank adaptation, only used when + :attr:`enable_low_rank_adaptation=True` + low_rank_adaptation_alpha: float, default = None + The alpha for computing the scaling factor of LoRA output. + :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. axis: Union[Iterable[int], int], default = -1 An integer tuple with axes to apply the transformation on. @@ -374,6 +407,9 @@ class DenseGeneral(TransformerEngineBase): use_bias: bool = True bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] 
= () + enable_low_rank_adaptation: bool = False + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 transpose_batch_sequence: bool = False @@ -439,6 +475,32 @@ def __call__(self, inputs: Array) -> Array: fp8_meta_pkg=fp8_gemm_pkg, contracting_dims=(axis, contract_ind)) + if self.enable_low_rank_adaptation: + lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1], + self.low_rank_adaptation_dim) + lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1], + self.low_rank_adaptation_dim) + lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape) + lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel', + self.kernel_init, + lora_a_kernel_init_shape, + jnp.float32, + axes=lora_a_kernel_axes) + lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) + lora_a_kernel = lora_a_kernel.astype(self.dtype) + + lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) + lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape) + lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel', + nn.initializers.zeros, + lora_b_kernel_shape, + jnp.float32, + axes=lora_b_kernel_axes) + lora_b_kernel = lora_b_kernel.astype(self.dtype) + + y += _apply_low_rank_adaptation(inputs, axis, features, lora_a_kernel, lora_b_kernel, + self.low_rank_adaptation_alpha) + if bias is not None: bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape y += jnp.reshape(bias, bias_shape) @@ -502,6 +564,14 @@ class LayerNormDenseGeneral(TransformerEngineBase): return_layernorm_output: bool, default = True Indicate whether to return the output of layer normalization. If set False, return None as the second tensor in outputs. + enable_low_rank_adaptation: bool, default = False + Indicate whether to enable low rank adaptation for each linear layer. + low_rank_adaptation_dim: int, default = 32 + The dimension for low rank adaptation, only used when + :attr:`enable_low_rank_adaptation=True` + low_rank_adaptation_alpha: float, default = None + The alpha for computing the scaling factor of LoRA output. + :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. axis: Union[Iterable[int], int], default = -1 An integer tuple with axes to apply the transformation on. layernorm_input_axes: Tuple[str, ...], default = None @@ -541,6 +611,9 @@ class LayerNormDenseGeneral(TransformerEngineBase): bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] 
= () return_layernorm_output: bool = True + enable_low_rank_adaptation: bool = False + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 transpose_batch_sequence: bool = True @@ -650,6 +723,32 @@ def __call__(self, inputs: Array) -> Array: fp8_meta_pkg=fp8_meta_package, contracting_dims=(axis, contract_ind)) + if self.enable_low_rank_adaptation: + lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1], + self.low_rank_adaptation_dim) + lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1], + self.low_rank_adaptation_dim) + lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape) + lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel', + self.kernel_init, + lora_a_kernel_init_shape, + jnp.float32, + axes=lora_a_kernel_axes) + lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) + lora_a_kernel = lora_a_kernel.astype(self.dtype) + + lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) + lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape) + lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel', + nn.initializers.zeros, + lora_b_kernel_shape, + jnp.float32, + axes=lora_b_kernel_axes) + lora_b_kernel = lora_b_kernel.astype(self.dtype) + + z += _apply_low_rank_adaptation(y, axis, features, lora_a_kernel, lora_b_kernel, + self.low_rank_adaptation_alpha) + bias = None if self.use_bias: bias = nn_partitioning.param_with_axes('bias', @@ -745,6 +844,14 @@ class LayerNormMLP(TransformerEngineBase): Dropout probability for the dropout op after the :attr:`activations`. intermediate_hidden_dropout_dims: Sequence[int], default = () Dimensions that will share the same dropout mask for hidden + enable_low_rank_adaptation: bool, default = False + Indicate whether to enable low rank adaptation for each linear layer. + low_rank_adaptation_dim: int, default = 32 + The dimension for low rank adaptation, only used when + :attr:`enable_low_rank_adaptation=True`. + low_rank_adaptation_alpha: float, default = None + The alpha for computing the scaling factor of LoRA output. + :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. axis: Union[Iterable[int], int], default = -1 An integer tuple with axes to apply the transformation on. 
layernorm_input_axes: Tuple[str, ...], default = None @@ -791,6 +898,9 @@ class LayerNormMLP(TransformerEngineBase): intermediate_dropout_rng_name: str = 'dropout' intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () + enable_low_rank_adaptation: bool = False + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 transpose_batch_sequence: bool = True @@ -856,11 +966,13 @@ def is_gelu(acts): use_fused_ln_geglu_mlp = fuse_layernorm \ and (not self.use_bias) and is_geglu(self.activations) \ - and (self.intermediate_dropout_rate < 1e-3) + and (self.intermediate_dropout_rate < 1e-3) \ + and not self.enable_low_rank_adaptation use_fused_ln_gelu_mlp = fuse_layernorm \ and self.use_bias and is_gelu(self.activations) \ - and (self.intermediate_dropout_rate < 1e-3) + and (self.intermediate_dropout_rate < 1e-3) \ + and not self.enable_low_rank_adaptation # LayerNorm if self.enable_layernorm: @@ -999,6 +1111,37 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): fp8_meta_pkg=gemm1_fp8_meta_package, contracting_dims=(axis, contract_ind)) + if self.enable_low_rank_adaptation: + wi_lora_a_kernel_shape = (*kernel_1_shape[:len(axis)], num_activations, + self.low_rank_adaptation_dim) + wi_lora_a_kernel_init_shape = (kernel_1_each_shape[0], num_activations, + self.low_rank_adaptation_dim) + wi_lora_a_kernel_init_each_shape = (kernel_1_each_shape[0], + self.low_rank_adaptation_dim) + wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape) + wi_lora_a_kernel = nn_partitioning.param_with_axes('wi_lora_a_kernel', + kernel_1_init, + num_activations, + -2, + wi_lora_a_kernel_init_each_shape, + jnp.float32, + axes=wi_lora_a_kernel_axes) + wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape) + wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype) + + wi_lora_b_kernel_shape = (num_activations, self.low_rank_adaptation_dim, + self.intermediate_dim) + wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape) + wi_lora_b_kernel = nn_partitioning.param_with_axes('wi_lora_b_kernel', + nn.initializers.zeros, + wi_lora_b_kernel_shape, + jnp.float32, + axes=wi_lora_b_kernel_axes) + wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype) + + x += _apply_low_rank_adaptation(y, axis, intermediate_dim, wi_lora_a_kernel, + wi_lora_b_kernel, self.low_rank_adaptation_alpha) + bias = None if self.use_bias: bias = nn_partitioning.param_with_axes('wi_bias', @@ -1042,6 +1185,28 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): fp8_meta_pkg=gemm2_fp8_meta_package, contracting_dims=(axis, contract_ind)) + if self.enable_low_rank_adaptation: + wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim) + wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape) + wo_lora_a_kernel = nn_partitioning.param_with_axes('wo_lora_a_kernel', + self.kernel_init, + wo_lora_a_kernel_shape, + jnp.float32, + axes=wo_lora_a_kernel_axes) + wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype) + + wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size) + wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape) + wo_lora_b_kernel = nn_partitioning.param_with_axes('wo_lora_b_kernel', + nn.initializers.zeros, + wo_lora_b_kernel_shape, + jnp.float32, + axes=wo_lora_b_kernel_axes) + wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype) + + out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple, wo_lora_a_kernel, + 
wo_lora_b_kernel, self.low_rank_adaptation_alpha) + bias = None if self.use_bias: bias = nn_partitioning.param_with_axes('wo_bias', diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index fcf06aa128..cacb360a27 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -637,6 +637,53 @@ def canonicalize_group_method(gm): return consecutive_impl() +class LoRAScope: # pylint: disable=too-few-public-methods + """LoRA Scope""" + + def __init__(self, qkv_proj=False, output_proj=False, mlp=False): + self.qkv_proj = qkv_proj + self.output_proj = output_proj + self.mlp = mlp + + def __eq__(self, other): + return (self.qkv_proj, self.output_proj, self.mlp) == \ + (other.qkv_proj, other.output_proj, other.mlp) + + +def _canonicalize_lora_scope(scope): + + SCOPE_NONE = 'none' + SCOPE_ALL = 'all' + SCOPE_QKV_PROJ = 'qkv_proj' + SCOPE_OUTPUT_PROJ = 'output_proj' + SCOPE_MLP = 'mlp' + SCOPE_EX_QKV_PROJ = 'exclude_qkv_proj' + SCOPE_EX_OUTPUT_PROJ = 'exclude_output_proj' + SCOPE_EX_MLP = 'exclude_mlp' + + scope = SCOPE_NONE if scope is None else scope + + scope = scope.lower() + + assert scope in [ + SCOPE_NONE, SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_OUTPUT_PROJ, SCOPE_MLP, SCOPE_EX_QKV_PROJ, + SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP + ] + + lora_scope = LoRAScope() + + if scope in [SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP]: + lora_scope.qkv_proj = True + + if scope in [SCOPE_ALL, SCOPE_OUTPUT_PROJ, SCOPE_EX_QKV_PROJ, SCOPE_EX_MLP]: + lora_scope.output_proj = True + + if scope in [SCOPE_ALL, SCOPE_MLP, SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ]: + lora_scope.mlp = True + + return lora_scope + + class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods r""" Multi-head Attention (MHA), including Query, @@ -723,6 +770,15 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods Indicate the method to coupled the coordinates. It should be one of ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2` , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`. + low_rank_adaptation_scope: str, default = 'none' + Indicate the scope to apply low rank adaptation. It should be one of + ['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj'] + low_rank_adaptation_dim: int, default = 32 + The dimension for low rank adaptation, only used when + :attr:`enable_low_rank_adaptation=True` + low_rank_adaptation_alpha: float, default = None + The alpha for computing the scaling factor of LoRA output. + :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. enable_sequence_parallel: bool, default = False Whether to enable sequence parallelism to operations except dot. 
num_heads: int, default = None @@ -777,6 +833,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods enable_rotary_pos_emb: bool = False rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) rotary_pos_emb_group_method: str = 'consecutive' + low_rank_adaptation_scope: str = 'none' + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None dtype: DType = jnp.float32 fuse_qkv_params: bool = True transpose_batch_sequence: bool = True @@ -914,6 +973,8 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp) + lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope) + if self.fuse_qkv_params: if is_qkvpack: qkv_proj, ln_out = LayerNormDenseGeneral( @@ -932,6 +993,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): use_bias=self.use_bias, bias_init=self.bias_init, bias_axes=(W_JOINED_AXES, W_TP_AXES), + enable_low_rank_adaptation=lora_scope.qkv_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, name='qkv', @@ -954,6 +1018,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): use_bias=self.use_bias, bias_init=self.bias_init, bias_axes=(W_TP_AXES,), + enable_low_rank_adaptation=lora_scope.qkv_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, @@ -972,6 +1039,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): use_bias=self.use_bias, bias_init=self.bias_init, bias_axes=(W_JOINED_AXES, W_TP_AXES), + enable_low_rank_adaptation=lora_scope.qkv_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, name='kv', dtype=self.dtype)(inputs_kv) kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj') @@ -986,6 +1056,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): use_bias=self.use_bias, bias_init=self.bias_init, bias_axes=(W_TP_AXES,), + enable_low_rank_adaptation=lora_scope.qkv_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype) query, ln_out = LayerNormDenseGeneral( enable_layernorm=self.input_layernorm, @@ -1002,6 +1075,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): use_bias=self.use_bias, bias_init=self.bias_init, bias_axes=(W_TP_AXES,), + enable_low_rank_adaptation=lora_scope.qkv_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, @@ -1142,6 +1218,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): use_bias=self.use_bias, bias_init=self.bias_init, bias_axes=(W_NO_SHARD_AXES,), + enable_low_rank_adaptation=lora_scope.output_proj, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, name='out')(x) out = checkpoint_name(out, 'out_proj') @@ -1379,6 +1458,16 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods Indicate the method to coupled the coordinates. It should be one of ['consecutive', 'alternate']. 
'alternate' is to pair index :math:`i` with :math:`i + d/2` , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`. + low_rank_adaptation_scope: str, default = 'none' + Indicate the scope to apply low rank adaptation. It should be one of + ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj', + 'exclude_output_proj', 'exclude_mlp'] + low_rank_adaptation_dim: int, default = 32 + The dimension for low rank adaptation, only used when + :attr:`enable_low_rank_adaptation=True` + low_rank_adaptation_alpha: float, default = None + The alpha for computing the scaling factor of LoRA output. + :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. enable_sequence_parallel: bool, default = False Whether to enable sequence parallelism to operations except dot. @@ -1434,6 +1523,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods enable_rotary_pos_emb: bool = False rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) rotary_pos_emb_group_method: str = 'consecutive' + low_rank_adaptation_scope: str = 'none' + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None dtype: DType = jnp.float32 drop_path: float = 0.0 fuse_qkv_params: bool = True @@ -1579,6 +1671,9 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): enable_rotary_pos_emb=self.enable_rotary_pos_emb, rotary_pos_emb_windows=self.rotary_pos_emb_windows, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, + low_rank_adaptation_scope=self.low_rank_adaptation_scope, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, fuse_qkv_params=self.fuse_qkv_params, kernel_init=self.mha_kernel_init, use_bias=self.use_bias, @@ -1646,6 +1741,9 @@ def hidden_dropout(x, deterministic): enable_rotary_pos_emb=self.enable_rotary_pos_emb, rotary_pos_emb_windows=self.rotary_pos_emb_windows, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, + low_rank_adaptation_scope=self.low_rank_adaptation_scope, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, float32_logits=self.float32_attention_logits, scale_attn_logits=self.scale_attn_logits, scaled_query_init=self.scaled_query_init, @@ -1674,6 +1772,8 @@ def hidden_dropout(x, deterministic): mlp_input = with_sharding_constraint_by_logical_axes( mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES)) + lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope) + # MlpBlock residual = mlp_input z, ln_out = LayerNormMLP( @@ -1697,6 +1797,9 @@ def hidden_dropout(x, deterministic): bias_init=self.bias_init, bias_axes_1=(W_JOINED_AXES, W_TP_AXES), bias_axes_2=(W_NO_SHARD_AXES,), + enable_low_rank_adaptation=lora_scope.mlp, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES), dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES), dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES), diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py index 3688b62370..e6372b91dc 100644 --- a/transformer_engine/jax/praxis/module.py +++ b/transformer_engine/jax/praxis/module.py @@ -131,6 +131,9 @@ class Linear(TransformerEngineBaseLayer): use_bias: bool = True bias_init: WeightInit = WeightInit.Constant(0.0) bias_axes: Tuple[str, ...] 
= () + enable_low_rank_adaptation: bool = False + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 transpose_batch_sequence: bool = False sharding_type: ShardingType = ShardingType.SINGLE @@ -147,6 +150,9 @@ def setup(self) -> None: use_bias=self.use_bias, bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), bias_axes=self.bias_axes, + enable_low_rank_adaptation=self.enable_low_rank_adaptation, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, axis=self.axis, dtype=self.dtype, transpose_batch_sequence=self.transpose_batch_sequence) @@ -174,6 +180,9 @@ class LayerNormLinear(TransformerEngineBaseLayer): use_bias: bool = False bias_init: WeightInit = WeightInit.Constant(0.0) bias_axes: Tuple[str, ...] = () + enable_low_rank_adaptation: bool = False + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None return_layernorm_output: bool = True axis: Union[Iterable[int], int] = -1 transpose_batch_sequence: bool = False @@ -201,6 +210,9 @@ def setup(self) -> None: use_bias=self.use_bias, bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), bias_axes=self.bias_axes, + enable_low_rank_adaptation=self.enable_low_rank_adaptation, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, return_layernorm_output=self.return_layernorm_output, axis=self.axis, dtype=self.dtype, @@ -232,6 +244,9 @@ class LayerNormMLP(TransformerEngineBaseLayer): bias_init: WeightInit = WeightInit.Constant(0.0) bias_axes_1: Tuple[str, ...] = () bias_axes_2: Tuple[str, ...] = () + enable_low_rank_adaptation: bool = False + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None return_layernorm_output: bool = True activations: Sequence[Union[str, Callable]] = ('relu',) intermediate_dropout_rate: float = 0.1 @@ -263,6 +278,9 @@ def setup(self) -> None: bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init), bias_axes_1=self.bias_axes_1, bias_axes_2=self.bias_axes_2, + enable_low_rank_adaptation=self.enable_low_rank_adaptation, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, return_layernorm_output=self.return_layernorm_output, activations=self.activations, intermediate_dropout_rate=self.intermediate_dropout_rate, diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py index d0a37e89b8..b68909190b 100644 --- a/transformer_engine/jax/praxis/transformer.py +++ b/transformer_engine/jax/praxis/transformer.py @@ -137,6 +137,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer): enable_rotary_pos_emb: bool = False rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) rotary_pos_emb_group_method: str = 'consecutive' + low_rank_adaptation_scope: str = 'none' + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None fuse_qkv_params: bool = True transpose_batch_sequence: bool = True enable_sequence_parallel: bool = False @@ -208,6 +211,9 @@ def setup(self) -> None: enable_rotary_pos_emb=self.enable_rotary_pos_emb, rotary_pos_emb_windows=self.rotary_pos_emb_windows, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, + low_rank_adaptation_scope=self.low_rank_adaptation_scope, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + 
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, fuse_qkv_params=self.fuse_qkv_params, transpose_batch_sequence=self.transpose_batch_sequence, enable_sequence_parallel=self.enable_sequence_parallel, @@ -262,6 +268,9 @@ class TransformerLayer(TransformerEngineBaseLayer): enable_rotary_pos_emb: bool = False rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) rotary_pos_emb_group_method: str = 'consecutive' + low_rank_adaptation_scope: str = 'none' + low_rank_adaptation_dim: int = 32 + low_rank_adaptation_alpha: float = None enable_relative_embedding: bool = True relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None) drop_path: float = 0.0 @@ -332,6 +341,9 @@ def setup(self) -> None: enable_rotary_pos_emb=self.enable_rotary_pos_emb, rotary_pos_emb_windows=self.rotary_pos_emb_windows, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, + low_rank_adaptation_scope=self.low_rank_adaptation_scope, + low_rank_adaptation_dim=self.low_rank_adaptation_dim, + low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, enable_relative_embedding=self.enable_relative_embedding, relative_embedding=relative_embedding_flax_module, drop_path=self.drop_path, From 7e9dbcaabc4d474efb54012314d436a46c3a252d Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 17 Apr 2024 09:02:41 -0700 Subject: [PATCH 024/244] [PyTorch] Misc fixes for release_v1.6 (#784) * fixes; docs Signed-off-by: Kirthi Shankar Sivamani * Check for FP8 Signed-off-by: Kirthi Shankar Sivamani * Fix LoRa-like use cases Signed-off-by: Kirthi Shankar Sivamani * Reviews Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- .../pytorch/module/layernorm_linear.py | 9 +++++++-- transformer_engine/pytorch/module/layernorm_mlp.py | 8 ++++++-- transformer_engine/pytorch/module/linear.py | 13 ++++++++----- transformer_engine/pytorch/utils.py | 8 ++++++++ 4 files changed, 29 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index bc4c29d308..7d7bb0bbd5 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -28,6 +28,7 @@ cast_if_needed, assert_dim_for_fp8_exec, clear_tensor_data, + requires_grad, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -355,7 +356,11 @@ def forward( ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization ctx.primary_weights_in_fp8 = primary_weights_in_fp8 - ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() + ctx.reduce_and_update_bwd_fp8_tensors = False + if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): + ctx.reduce_and_update_bwd_fp8_tensors = ( + ctx.reduce_and_update_bwd_fp8_tensors or + FP8GlobalStateManager.is_first_fp8_module()) # Row Parallel Linear if parallel_mode == "row" and sequence_parallel: @@ -699,7 +704,7 @@ def backward( else: wgrad = None - if ctx.is_first_module and not is_graph_capturing(): + if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6efb72b8db..9b80ea3a21 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -33,6 +33,7 @@ 
cast_if_needed, assert_dim_for_fp8_exec, clear_tensor_data, + requires_grad, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -544,7 +545,10 @@ def forward( ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization ctx.primary_weights_in_fp8 = primary_weights_in_fp8 - ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() + ctx.reduce_and_update_bwd_fp8_tensors = False + if ctx.fp8 and requires_grad( + inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias): + ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() # Row Parallel Linear if ub_overlap_rs: @@ -1121,7 +1125,7 @@ def backward( else: fc2_wgrad = None - if ctx.is_first_module and not is_graph_capturing(): + if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 8adaab557f..cb2f6871b3 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -26,6 +26,7 @@ assert_dim_for_fp8_exec, clear_tensor_data, init_method_constant, + requires_grad, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -363,7 +364,11 @@ def forward( ctx.requires_dgrad = inp.requires_grad ctx.is_input_fp8 = is_input_fp8 ctx.primary_weights_in_fp8 = primary_weights_in_fp8 - ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() + ctx.reduce_and_update_bwd_fp8_tensors = False + if ctx.fp8 and requires_grad(inp, weight, bias): + ctx.reduce_and_update_bwd_fp8_tensors = ( + ctx.reduce_and_update_bwd_fp8_tensors or + FP8GlobalStateManager.is_first_fp8_module()) # Row Parallel Linear if ub_overlap_rs: @@ -381,7 +386,7 @@ def forward( def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: - if isinstance(grad_output[0], Float8Tensor): + if isinstance(grad_output, Float8Tensor): ctx.fp8_meta["scaling_bwd"].scale_inv[ tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_output._scale_inv @@ -611,7 +616,7 @@ def backward( else: wgrad = None - if ctx.is_first_module and not is_graph_capturing(): + if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( @@ -954,8 +959,6 @@ def forward( * it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced) - is_first_module_in_mha: Optional[bool], default = False - Whether to output in FP8. By default, Linear outputs in inp.dtype. 
""" skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index df750ab1ae..f60f8c29c7 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -8,6 +8,14 @@ import torch +def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: + """Check if any of the given tensors require gradient.""" + for tensor in tensors: + if tensor is not None and tensor.requires_grad: + return True + return False + + def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: """ Trick to deallocate tensor memory when delete operation does not From 4a8a80760aae649884b6d1f24ad543e07d8b025a Mon Sep 17 00:00:00 2001 From: "Pavel Shamis (Pasha)" Date: Wed, 17 Apr 2024 12:02:17 -0500 Subject: [PATCH 025/244] [UB] Adding configurable timeout for userbuffer and improving error reporting for potential hangs (#757) * Improving error reporting and hang detection logic * Adding verbose error reporting in case of UB hang * Adding CE hang detector * Replacing hard-coded timeout with configurable one Signed-off-by: Pasha (Pavel) Shamis * Cleaning up warnings in the code Signed-off-by: Pasha (Pavel) Shamis * Removing unused codes Signed-off-by: Pasha (Pavel) Shamis * Fixing styling issues reported on github Signed-off-by: Pasha (Pavel) Shamis * Addressing lint new line and casting warnings Signed-off-by: Pasha (Pavel) Shamis * Addressing lint warning about the usage of `unsigned long long` Signed-off-by: Pasha (Pavel) Shamis * Removing unused case causing build issues on multi-arch setup Signed-off-by: Pasha (Pavel) Shamis * Post GRDCOPY removal cleanup * Remove cmake check * Remove unused includes Signed-off-by: Pasha (Pavel) Shamis --------- Signed-off-by: Pasha (Pavel) Shamis Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- .../pytorch/csrc/userbuffers/CMakeLists.txt | 8 +- .../csrc/userbuffers/userbuffers-host.cpp | 227 +-- .../pytorch/csrc/userbuffers/userbuffers.cu | 1732 +++-------------- .../pytorch/csrc/userbuffers/userbuffers.h | 38 +- 4 files changed, 334 insertions(+), 1671 deletions(-) diff --git a/transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt b/transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt index 7e89ac135f..5106c25598 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt +++ b/transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt @@ -11,17 +11,11 @@ target_include_directories(transformer_engine_userbuffers PUBLIC # Configure dependencies find_package(MPI REQUIRED) -find_library(GDRCOPY_LIBRARY gdrapi - HINTS "${GDRCOPY_LIBRARY_DIR}" "$ENV{GDRCOPY_LIBRARY_DIR}") -if(NOT GDRCOPY_LIBRARY) - message(FATAL_ERROR "Could not find GDRCopy, please set GDRCOPY_LIBRARY_DIR") -endif() -message(STATUS "Found GDRCopy: ${GDRCOPY_LIBRARY}") target_link_libraries(transformer_engine_userbuffers PUBLIC CUDA::cudart CUDA::cuda_driver MPI::MPI_CXX - ${GDRCOPY_LIBRARY}) + ) target_include_directories(transformer_engine_userbuffers PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp index c62b6ef7f3..c59f84b35f 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp @@ -11,7 +11,6 @@ #include #include #include -#include #include #include #include 
@@ -19,7 +18,6 @@ #include #include #include -#include #define MULTICAST_GB_TOTAL 512 static int oob_bcast(void *comm_context, void *buf, int size, int root) { @@ -123,11 +121,20 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode (*comm)->basecounter[i] = 0; (*comm)->head = 0; (*comm)->tail = 0; - (*comm)->activeproxy = 1; (*comm)->active_nreqs = 0; for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1; + int device_clock = 0; + // 110 sec wait time by default + int sec_timeout = getenv("UB_TIMEOUT") ? atoi(getenv("UB_TIMEOUT")) : 110; + CUDACHECK(cudaDeviceGetAttribute(&device_clock, cudaDevAttrClockRate, cur_dev)); + (*comm)->ub_timeout = 1000ull * device_clock * sec_timeout; + if ((*comm)->myrank == 0) { + printf("UB_TIMEOUT is set to %d sec, %" PRIu64 " cycles, freq: %dkhz\n", + sec_timeout, (*comm)->ub_timeout, device_clock); + } + int ret = 0; // split communicator char host_name[MPI_MAX_PROCESSOR_NAME]; @@ -232,59 +239,12 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode (*comm)->num2_nodes = tensornodes; (*comm)->my2_node = (mynode / datanodes) % tensornodes; (*comm)->first2_node = mynode - (*comm)->my2_node * datanodes; - - char *ib_dev_list; - int ZIONROCE = getenv("NVTE_ZIONROCE") ? atoi(getenv("NVTE_ZIONROCE")) : 0; - int ROCE = getenv("NVTE_ROCE") ? atoi(getenv("NVTE_ROCE")) : 0; - if (ZIONROCE) - ROCE = 1; - int DGX_H100 = device_prop.major == 9; - - switch (mylocal) { - case 0: - ib_dev_list = "mlx5_0:1"; - break; // NOLINT(*) - case 1: - ib_dev_list = (char *)(DGX_H100 ? "mlx5_3:1" : "mlx5_1:1"); // NOLINT(*) - break; // NOLINT(*) - case 2: - ib_dev_list = (char *)(ZIONROCE ? "mlx5_4:1" : DGX_H100 ? "mlx5_4:1" : "mlx5_2:1"); // NOLINT(*) - break; // NOLINT(*) - case 3: - ib_dev_list = (char *)(DGX_H100 ? "mlx5_5:1" : "mlx5_3:1"); // NOLINT(*) - break; // NOLINT(*) - case 4: - ib_dev_list = (char *)(DGX_H100 ? "mlx5_6:1" : "mlx5_6:1"); // NOLINT(*) - break; // NOLINT(*) - case 5: - ib_dev_list = (char *)(DGX_H100 ? "mlx5_9:1" : "mlx5_7:1"); // NOLINT(*) - break; // NOLINT(*) - case 6: - ib_dev_list = (char *)(ZIONROCE ? "mlx5_10:1" : DGX_H100 ? "mlx5_10:1" : "mlx5_8:1"); // NOLINT(*) - break; // NOLINT(*) - case 7: - ib_dev_list = (char *)(DGX_H100 ? 
"mlx5_11:1" : "mlx5_9:1"); // NOLINT(*) - break; // NOLINT(*) - default: - break; - } - (*comm)->fifo = reinterpret_cast(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS)); (*comm)->nblocks = 8; (*comm)->alignblock = 1024 * 512; (*comm)->minblock = 1024 * 2 * 1024; (*comm)->asyncblocks = 16; - CUDACHECK(cudaMallocHost((void **)&(*comm)->hostflags, // NOLINT(*) - (NVTE_MAX_SMS + 100) * sizeof(int))); - for (int i = 0; i < 100 + NVTE_MAX_SMS; i++) - (*comm)->hostflags[i] = 0; - _mm_mfence(); - sleep(1); - - // init_p2p_transport(); - (*comm)->ibnvsize = (*comm)->nvsize; - #define NBUF 2 if ((*comm)->sm_arch >= 9 && (*comm)->ar2_nvsize > 1 && !getenv("UB_SKIPMC")) { // multicast init only for TP ops (____2 operations) @@ -374,6 +334,7 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode #define GPU_PAGE_SIZE (1UL << GPU_PAGE_SHIFT) #define GPU_PAGE_OFFSET (GPU_PAGE_SIZE - 1) #define GPU_PAGE_MASK (~GPU_PAGE_OFFSET) + CUDACHECK(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE)); unsigned int flag = 1; CUDACHECK(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE)); @@ -381,23 +342,6 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode reinterpret_cast(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); using namespace std; - (*comm)->g = gdr_open(); - if ((*comm)->g == NULL) { - fprintf(stderr, "gdrcopy open failed\n"); - return -1; - } - gdr_mh_t mh; - ret = gdr_pin_buffer((*comm)->g, (CUdeviceptr)(*comm)->flags, GPU_PAGE_SIZE, 0, 0, &mh); - if (ret) { - fprintf(stderr, "gdr_pin_buffer failed\n"); - return -1; - } - ret = gdr_map((*comm)->g, mh, (void **)&((*comm)->map_flags), GPU_PAGE_SIZE); // NOLINT(*) - - if (ret) { - fprintf(stderr, "gdr_map failed\n"); - return -1; - } sched_param param; pthread_attr_t attr; pthread_attr_init(&attr); @@ -426,10 +370,6 @@ int create_communicator(communicator **comm) { } void destroy_communicator(communicator *comm) { - comm->activeproxy = 0; - if (!comm->myrank && getenv("NVTE_UBDEBUG")) - printf("waiting for userbuffers proxy thread to exit()\n"); - gdr_close(comm->g); } int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) { @@ -533,7 +473,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * CUCHECK(cuMulticastBindMem(comm->mc_handle, comm->mc_offset, comm->uchandles[hndl][myrank], 0 /*memOffset*/, aligned_size, 0)); comm->memflags[hndl] |= UB_MEM_MC_CREATED; - comm->mc_ptr[hndl] = comm->mc_baseptr + comm->mc_offset; + comm->mc_ptr[hndl] = reinterpret_cast(comm->mc_baseptr) + comm->mc_offset; comm->mc_offset += aligned_size; } else if (!comm->myrank) { printf("UB: warning region %d size %ld MB registered without MC access\n", hndl, @@ -570,146 +510,3 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * return comm->free_region++; } - -int allreduce_userbuff_inplace_gpu(const int handler, const int offset, const int elements, - const int blocksize, communicator *comm, cudaStream_t stream); - -int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, - const int elements, const int blocksize, communicator *comm, - cudaStream_t stream, int op); - -int reducescatter2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, - const int elements, const int blocksize, communicator *comm, - cudaStream_t stream, int op); - -int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, - const int 
elements, const int blocksize, communicator *comm, - cudaStream_t stream, int op); - -void allreduce_nonsharp_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream, int op) { - if (elements < 64) - NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); - // if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call - // launch_mode=%d\n",op,comm->launch_mode); - const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; - int blocksize = elements * 2; - int maxcredit = 0; - const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes; - blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) / - comm->nblocks; // FIXME TUNING - blocksize *= comm->alignblock; - if (blocksize < comm->minblock) - blocksize = comm->minblock; - - maxcredit = (elements * 2 + blocksize - 1) / blocksize; - size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit - if (blocksize > peerblock * ar_nvsize) - blocksize = peerblock * ar_nvsize; - int sms = allreduce2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm, - stream, op); - - if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) { - if (!sms) - return; - comm->fifo[comm->head].optype = op; - comm->fifo[comm->head].basecounter = comm->basecounter[op]; - comm->fifo[comm->head].blocksize = blocksize; - comm->fifo[comm->head].maxcredit = maxcredit; - comm->fifo[comm->head].handler = handler; - comm->fifo[comm->head].offset = offset; - comm->fifo[comm->head].elements = elements; - - int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1); - while (newhead == comm->tail) { - } - comm->head = newhead; - - comm->basecounter[op] += (elements * 2 + blocksize - 1) / blocksize; - } -} - -void allreduce2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { - allreduce_nonsharp_inplace(handler, offset, elements, comm, stream, - userbuffers_allreduceop_nonsharp2); -} - -void allreduce_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { - if (elements < 64) - NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); - allreduce_nonsharp_inplace(handler, offset, elements, comm, stream, - userbuffers_allreduceop_nonsharp); - return; -} - -void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { - if (elements < 64) - NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); - - int op = userbuffers_allreduceop_nonsharp; - const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; - int blocksize = elements * 2; - int maxcredit = 0; - - const int num_nodes = op == userbuffers_allreduceop_nonsharp ? 
comm->num_nodes : comm->num2_nodes; - blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) / - comm->nblocks; // FIXME TUNING - blocksize *= comm->alignblock; - if (blocksize < comm->minblock) - blocksize = comm->minblock; - - maxcredit = (elements * 2 + blocksize - 1) / blocksize; - size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit - if (blocksize > peerblock * ar_nvsize) - blocksize = peerblock * ar_nvsize; - - int sms = reducescatter2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, - comm, stream, op); - - if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) { - if (!sms) - return; - comm->fifo[comm->head].optype = op; - comm->fifo[comm->head].basecounter = comm->basecounter[op]; - comm->fifo[comm->head].blocksize = blocksize; - comm->fifo[comm->head].maxcredit = maxcredit; - comm->fifo[comm->head].handler = handler; - comm->fifo[comm->head].offset = offset; - comm->fifo[comm->head].elements = elements; - - int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1); - while (newhead == comm->tail) { - } - comm->head = newhead; - - comm->basecounter[op] += (elements * 2 + blocksize - 1) / blocksize; - } -} - -void allgather_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { - if (elements < 64) - NVTE_UB_ERROR("Userbuffer comm for given config not implemented."); - int op = userbuffers_allreduceop_nonsharp; - const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; - int blocksize = elements * 2; - int maxcredit = 0; - - const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes; - blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) / - comm->nblocks; // FIXME TUNING - blocksize *= comm->alignblock; - if (blocksize < comm->minblock) - blocksize = comm->minblock; - - maxcredit = (elements * 2 + blocksize - 1) / blocksize; - size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit - if (blocksize > peerblock * ar_nvsize) - blocksize = peerblock * ar_nvsize; - - int sms = allgather2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm, - stream, op); -} diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index 0cf1a091b9..d14cb8a538 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -4,12 +4,8 @@ * See LICENSE for license information. 
************************************************************************/ -#include -#include - #include #include -#include #if __CUDA_ARCH__ >= 800 #include @@ -20,8 +16,12 @@ #include "userbuffers.h" +#include +#include +#include +#include + #define MAX_THREADS 1024 -#define TIMEOUT 200000000000ull #define CUDACHECK(cmd) \ do { \ @@ -35,8 +35,7 @@ #define ATOMIC_CONSUMER(chunk) \ if (counters) { \ if (threadIdx.x == 0 && blockIdx.x == 0) { \ - int old_val; \ - while (0 != (old_val = atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \ + while (0 != (atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \ } \ ((unsigned int *)counters)[chunk] = 1; \ asm volatile("fence.sc.gpu;\n"); \ @@ -54,11 +53,32 @@ // If we expect that producer will be 2B+ messages behind consumer #define CHECK_IDS(producer, consumer) (((unsigned)(producer) - (unsigned)(consumer)) & (~INT_MAX)) +// Strip the path from a full filename +#define FILENAME(file) ({ \ + const char* filename = file; \ + const char* basename = filename; \ + for (const char* ptr = filename; *ptr != '\0'; ptr++) { \ + if (*ptr == '/' || *ptr == '\\') { \ + basename = ptr + 1; \ + } \ + } \ + basename; \ +}) + +// Printf to provide enough information so it is easier to attribute failures +#define UB_PRINT(message, ...) printf("[%s:%s:%d] " message "\n", FILENAME(__FILE__), \ + __FUNCTION__, \ + __LINE__, __VA_ARGS__) + +// Report and error on timeout +#define CHECK_TIMEOUT(t, timeout) ((clock64() - (t)) > timeout) + template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rw(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int lineoffset, - const int numlines, void **commbuff, const int handleridx) { + const int numlines, void **commbuff, const int handleridx, + const uint64_t ub_timeout) { __shared__ int4 *userptr[RANKS]; int *flagptr, physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; @@ -78,9 +98,9 @@ __global__ void __launch_bounds__(MAX_THREADS) userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { - printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, - *flag); + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Allreduce reduce-scatter: SM %d [%d]: expecting %d got %d", myrank, + blockIdx.x, threadIdx.x, reduce_id, *flag); break; } } @@ -132,9 +152,9 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > 2ull * TIMEOUT) { - printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, - *flag); + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Allreduce Gather: SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); break; } } @@ -147,7 +167,8 @@ template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int lineoffset, - const int numlines, void **commbuff, const int handleridx) { + const int numlines, void **commbuff, const int handleridx, + const uint64_t ub_timeout) { __shared__ int4 *userptr[RANKS]; int *flagptr, physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; @@ -166,9 +187,9 @@ __global__ void __launch_bounds__(MAX_THREADS) userptr[threadIdx.x] = 
reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { - printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, - *flag); + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d ]Allreduce reduce-scatter:SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); break; } } @@ -215,9 +236,9 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > 2ull * TIMEOUT) { - printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, - *flag); + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Allreduce gather: SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); break; } } @@ -258,7 +279,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_rs(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int mylineoffset, const int totallines, - void **commbuff, const int handleridx) { + void **commbuff, const int handleridx, + const uint64_t ub_timeout) { __shared__ int4 *userptr[RANKS]; volatile int *flagptr; int physgpu, targetgpu, *myptr; @@ -277,9 +299,9 @@ __global__ void __launch_bounds__(MAX_THREADS) userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { - printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, - *flag); + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); break; } } @@ -333,7 +355,8 @@ __global__ void __launch_bounds__(MAX_THREADS) const int gpustep, const int mylineoffset, const int totallines, const int rowlines, const int skiplines, void **commbuff, - const int handleridx, void *outbuf) { + const int handleridx, void *outbuf, + const uint64_t ub_timeout) { __shared__ int4 *userptr[RANKS]; volatile int *flagptr; int physgpu, targetgpu, *myptr; @@ -352,9 +375,9 @@ __global__ void __launch_bounds__(MAX_THREADS) userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { - printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, - *flag); + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); break; } } @@ -427,8 +450,8 @@ __global__ void __launch_bounds__(MAX_THREADS) clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > TIMEOUT) { - printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, - *flag); + UB_PRINT("Reduce-scatter: SM %d [%d]:expecting %d got %d", blockIdx.x, threadIdx.x, + reduce_id, *flag); break; } } @@ -495,7 +518,7 @@ __global__ void __launch_bounds__(MAX_THREADS) clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { if (clock64() - s > 2ull * TIMEOUT) { - printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, + UB_PRINT("Allgather: SM %d [%d]:expecting %d got %d", blockIdx.x, threadIdx.x, reduce_id, *flag); break; } @@ -510,7 +533,8 @@ 
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rs(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int mylineoffset, const int totallines, - void **commbuff, const int handleridx, float4 *mc_ptr) { + void **commbuff, const int handleridx, float4 *mc_ptr, + const uint64_t ub_timeout) { volatile int *flagptr; int physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; @@ -529,10 +553,10 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&(myptr[targetgpu]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { - printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, - *flag); - break; + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); + break; } } } @@ -596,7 +620,8 @@ __global__ void __launch_bounds__(MAX_THREADS) const int gpustep, const int mylineoffset, const int totallines, const int rowlines, const int skiplines, void **commbuff, - const int handleridx, void *outbuf, float4 *mc_ptr) { + const int handleridx, void *outbuf, float4 *mc_ptr, + const uint64_t ub_timeout) { volatile int *flagptr; int physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; @@ -614,9 +639,9 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&(myptr[targetgpu]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { - printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, - threadIdx.x, reduce_id, *flag); + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, + threadIdx.x, reduce_id, *flag); break; } } @@ -680,7 +705,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_ag(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int mylineoffset, const int totallines, - void **commbuff, const int handleridx, uint4 *mc_ptr) { + void **commbuff, const int handleridx, uint4 *mc_ptr, + const uint64_t ub_timeout) { volatile int *flagptr; int physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; @@ -744,10 +770,10 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > 2ull * TIMEOUT) { - printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, - *flag); - break; + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Allgather: SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, threadIdx.x, + reduce_id, *flag); + break; } } } @@ -764,26 +790,32 @@ template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rs_oop( const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int mylineoffset, const int totallines, const int rowlines, const int skiplines, - void **commbuff, const int handleridx, void *outbuf, float4 *mc_ptr) {} + void **commbuff, const int handleridx, void *outbuf, float4 *mc_ptr, + const uint64_t ub_timeout) {} + template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_ag(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int 
mylineoffset, const int totallines, - void **commbuff, const int handleridx, uint4 *mc_ptr) {} + void **commbuff, const int handleridx, uint4 *mc_ptr, + const uint64_t ub_timeout) {} + template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rs(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int mylineoffset, const int totallines, - void **commbuff, const int handleridx, float4 *mc_ptr) {} + void **commbuff, const int handleridx, float4 *mc_ptr, + const uint64_t ub_timeout) {} #endif template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8( const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int mylineoffset, const int totallines, const int rowlines, const int skiplines, - void **commbuff, const int handleridx, void *outbuf, float *scale) { + void **commbuff, const int handleridx, void *outbuf, float *scale, + const uint64_t ub_timeout) { __shared__ int4 *userptr[RANKS]; volatile int *flagptr; int physgpu, targetgpu, *myptr; @@ -804,8 +836,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { - printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); break; } @@ -862,7 +894,8 @@ __global__ void __launch_bounds__(MAX_THREADS) const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int mylineoffset, const int totallines, const int rowlines, const int skiplines_out, const int skiplines_in, void **commbuff, const int handleridx, - void *outbuf, float *scale, void *counters, const int numchunks, const int atomicindex) { + void *outbuf, float *scale, void *counters, const int numchunks, const int atomicindex, + const uint64_t ub_timeout) { __shared__ int4 *userptr[RANKS]; volatile int *flagptr; int physgpu, targetgpu, *myptr; @@ -892,8 +925,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { - printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); break; } @@ -959,7 +992,8 @@ __global__ void __launch_bounds__(MAX_THREADS) const int gpustep, const int mylineoffset, const int totallines, const int rowlines, const int skiplines, void **commbuff, - const int handleridx, void *outbuf) { + const int handleridx, void *outbuf, + const uint64_t ub_timeout) { __shared__ int4 *userptr[RANKS]; volatile int *flagptr; int physgpu, targetgpu, *myptr; @@ -979,8 +1013,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { - printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]:expecting %d got %d", 
myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); break; } @@ -1030,123 +1064,22 @@ __global__ void __launch_bounds__(MAX_THREADS) *reduceidptr = reduce_id; } // fp16 reduce-scatter kernel (out of place) fp16 -#if 0 -template -__global__ void -__launch_bounds__(MAX_THREADS) -userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic_fp8( - const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, - const int mylineoffset, const int totallines, const int rowlines, const int skiplines, - const int numchunks, void **commbuff, const int handleridx, void* outbuf, void *counters, - float* scale) { - if (counters) { - if ( threadIdx.x == 0 ) { - // spin-lock on counter from producer - int old_val; - while (0 != (old_val = atomicCAS(((unsigned int*)counters), 0, 0) )) {} - - // make sure all threadblocks have read/waited on counters. - int old_val2; - atomicInc(((unsigned int *)counters)+numchunks, gridDim.x-1); - while (0 != (old_val2 = atomicCAS(((unsigned int*)counters)+numchunks, 0, 0) )) {} - - // reset counter for next producer. - ((unsigned int*)counters)[0] = 1; - asm volatile ("fence.sc.gpu;\n"); - } - } - __syncthreads(); - - __shared__ int4* userptr[RANKS]; - volatile int *flagptr; - int physgpu, targetgpu, *myptr; - int *reduceidptr, reduce_id; - int lastSM = 0; - half hscale = (half) *scale; - - if (threadIdx.x < RANKS) { - physgpu = myrank*gpustep+firstrank; - targetgpu = threadIdx.x*gpustep+firstrank; - myptr = (reinterpret_cast(commbuff[physgpu])) + flagoffset; - reduceidptr = myptr-NVTE_MAX_OPS; // +op; - reduce_id =(*reduceidptr)+1; - flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; - if (blockIdx.x == 0) flagptr[physgpu] = reduce_id; - volatile int* flag = (volatile int*)&(myptr[targetgpu]); - userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu+handleridx]); - clock_t s = clock64(); - while (CHECK_IDS(*flag, reduce_id)) { - if (clock64()-s > TIMEOUT) { - printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", - myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); - break; - } - } - } - __syncthreads(); - if (threadIdx.x == 0) { - const int adder = blockIdx.x == 0 ? 
NVTE_MAX_SMS-gridDim.x+1 : 1; - int old_val = atomicAdd(myptr+(NVTE_MAX_NVLINK*2), adder); - if (old_val+adder == NVTE_MAX_SMS*reduce_id) lastSM = 1; - } - - - int warp = blockIdx.x+(threadIdx.x>>5); - int dest[RANKS]; -#pragma unroll - for (int i = 0; i < RANKS; i++) - dest[i] = (i+myrank+warp)&(RANKS-1); - - for (int line = threadIdx.x+blockDim.x*blockIdx.x; - line < totallines; line+=blockDim.x*gridDim.x) { - int4 val[RANKS]; - int index_in = mylineoffset + myrank*(totallines*skiplines/rowlines/2) + - (line/rowlines)*skiplines/2+(line%rowlines); - -#pragma unroll - for (int i = 0; i < RANKS; i++) { - val[i] = userptr[dest[i]][index_in]; - } - - int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}}; - half *s = reinterpret_cast(&sum); - -#pragma unroll - for (int i = 0; i < RANKS; i++) { - fp8type *x = reinterpret_cast(&val[i]); -#pragma unroll - for (int j=0; j < sizeof(int4)/sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]); - } - int hline = 2*line; - int index_out1 = (hline/rowlines)*skiplines+(hline%rowlines); - (reinterpret_cast(outbuf))[index_out1] = sum[0]; - hline++; - int index_out2 = (hline/rowlines)*skiplines+(hline%rowlines); - (reinterpret_cast(outbuf))[index_out2] = sum[1]; - } - - if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; -} // fp16 reduce-scatter kernel (out of place) fp16 -#endif - template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic( const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int mylineoffset, const int totallines, const int rowlines, const int skiplines, const int numchunks, void **commbuff, const int handleridx, - void *outbuf, void *counters) { + void *outbuf, void *counters, const uint64_t ub_timeout) { if (counters) { if (threadIdx.x == 0) { // spin-lock on counter from producer - int old_val; - while (0 != (old_val = atomicCAS(((unsigned int *)counters), 0, 0))) { + while (0 != (atomicCAS(((unsigned int *)counters), 0, 0))) { } // make sure all threadblocks have read/waited on counters. - int old_val2; atomicInc(((unsigned int *)counters) + numchunks, gridDim.x - 1); - while (0 != (old_val2 = atomicCAS(((unsigned int *)counters) + numchunks, 0, 0))) { + while (0 != (atomicCAS(((unsigned int *)counters) + numchunks, 0, 0))) { } // reset counter for next producer. 
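
The hunks above all repeat the same waiting discipline: capture clock64() before spinning on a peer's flag, then abandon the wait with a diagnostic once the elapsed cycles exceed the caller-supplied ub_timeout, instead of spinning forever against the old hard-coded TIMEOUT. Condensed into one illustrative device helper (hypothetical, not part of this patch; it reuses the CHECK_IDS, CHECK_TIMEOUT, and UB_PRINT macros defined earlier in this file, where the real kernels inline this pattern):

__device__ void spin_wait_flag(volatile int *flag, int expected_id,
                               uint64_t ub_timeout, int myrank) {
  long long start = clock64();                 // cycle count at the start of the wait
  while (CHECK_IDS(*flag, expected_id)) {      // peer has not published expected_id yet
    if (CHECK_TIMEOUT(start, ub_timeout)) {    // elapsed cycles exceeded the budget
      UB_PRINT("[%d] expecting %d got %d", myrank, expected_id, *flag);
      break;                                   // report and stop waiting instead of hanging
    }
  }
}
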
@@ -1175,8 +1108,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { - printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); break; } @@ -1232,20 +1165,17 @@ __global__ void __launch_bounds__(MAX_THREADS) const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int mylineoffset, const int totallines, const int rowlines, const int skiplines, const int numchunks, void **commbuff, const int handleridx, - void *outbuf, void *counters) { + void *outbuf, void *counters, const uint64_t ub_timeout) { for (int chunk_i = 0; chunk_i < numchunks; chunk_i++) { if (counters) { if (threadIdx.x == 0) { // spin-lock on counter from producer - int old_val; - while (0 != (old_val = atomicCAS(((unsigned int *)counters) + chunk_i, 0, 0))) { + while (0 != (atomicCAS(((unsigned int *)counters) + chunk_i, 0, 0))) { } // make sure all threadblocks have read/waited on counters. - int old_val2; atomicInc(((unsigned int *)counters) + numchunks + chunk_i, gridDim.x - 1); - while (0 != - (old_val2 = atomicCAS(((unsigned int *)counters) + numchunks + chunk_i, 0, 0))) { + while (0 != (atomicCAS(((unsigned int *)counters) + numchunks + chunk_i, 0, 0))) { } // reset counter for next producer. @@ -1274,8 +1204,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { - printf("[%d] NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", myrank, blockIdx.x, + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Reduce-scatter: SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, threadIdx.x, reduce_id, *flag); break; } @@ -1330,7 +1260,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_ag(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int mylineoffset, const int totallines, - void **commbuff, const int handleridx) { + void **commbuff, const int handleridx, + const uint64_t ub_timeout) { __shared__ int4 *userptr[RANKS]; volatile int *flagptr; int physgpu, targetgpu, *myptr; @@ -1342,7 +1273,6 @@ __global__ void __launch_bounds__(MAX_THREADS) reduceidptr = myptr - NVTE_MAX_OPS; // +op; reduce_id = (*reduceidptr) + 1; flagptr = (reinterpret_cast(commbuff[targetgpu])) + flagoffset; - volatile int *flag = (volatile int *)&(myptr[targetgpu]); userptr[threadIdx.x] = reinterpret_cast(commbuff[targetgpu + handleridx]); clock_t s = clock64(); } @@ -1393,9 +1323,9 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > 2ull * TIMEOUT) { - printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, - *flag); + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Allgather: SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, threadIdx.x, + reduce_id, *flag); break; } } @@ -1407,7 +1337,8 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rw_ag(const int op, const int flagoffset, const int firstrank, const int 
myrank, const int gpustep, const int mylineoffset, const int totallines, - void **commbuff, const int handleridx) { + void **commbuff, const int handleridx, + const uint64_t ub_timeout) { __shared__ int4 *userptr[RANKS]; volatile int *flagptr; int physgpu, targetgpu, *myptr; @@ -1490,784 +1421,15 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > 2ull * TIMEOUT) { - printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id, - *flag); + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("[%d] Allgather: SM %d [%d]:expecting %d got %d", myrank, blockIdx.x, threadIdx.x, + reduce_id, *flag); break; } } } } // fp16 inplace allgather kernel (Volta,Hopper) -template -__global__ void __launch_bounds__(MAX_THREADS) - userbuffers_fp16_sum_inplace_gpu_rr_blocked(const int op, const int flagoffset, - const int firstrank, const int myrank, - const int lineoffset, const int numlines, - void **commbuff, const int handleridx, - const int peerblocklines, int *hostflags, - int *gpuflag, const int numblocks) { - const int basecounter = gpuflag[NVTE_GF_STATE + op]; - -#define REDUCETHREADS (blockDim.x - 32) - - if (threadIdx.x < 32) { - int *flagptr; - if (threadIdx.x < RANKS) { - if (!blockIdx.x) { - flagptr = reinterpret_cast(commbuff[threadIdx.x + firstrank]); - flagptr[flagoffset + myrank + firstrank] = basecounter; - } - volatile int *flag = (volatile int *)&((reinterpret_cast( - commbuff[myrank + firstrank]))[flagoffset + threadIdx.x + firstrank]); - while (CHECK_IDS(*flag, basecounter)) { - } - } - __syncthreads(); - - int startblock = 0, endblock = numblocks; - - for (int nblock = 0; nblock < endblock; nblock++) { - asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); - - if (threadIdx.x == 0) { - __threadfence(); - if (blockIdx.x) - gpuflag[op * NVTE_MAX_SMS * 2 + blockIdx.x] = nblock + basecounter + 1; - } else if (blockIdx.x == 0) { - int expecting = (basecounter + nblock + 1); - if (threadIdx.x < gridDim.x) - while (((volatile int *)gpuflag)[op * NVTE_MAX_SMS * 2 + threadIdx.x] < expecting) { - } - } - if (!blockIdx.x) { - asm volatile("bar.sync 15, %0;" ::"r"(32)); - if (!threadIdx.x) - hostflags[0] = nblock + basecounter + 1; - } - } - - int cachedflag = basecounter; - -#define ALLGATHERFLAG NVTE_GF_IBSHARPDONE - - if (blockIdx.x == 0 && threadIdx.x < RANKS) { - while (cachedflag < basecounter + numblocks) { - int newflag = ((volatile int *)gpuflag)[ALLGATHERFLAG]; - if (newflag == cachedflag) - continue; - cachedflag = newflag; - flagptr[flagoffset + myrank + 32 + firstrank] = cachedflag; - } - } - - if (blockIdx.x == 0 && threadIdx.x == 0) - gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks; - } else { - const int warp = blockIdx.x + (threadIdx.x >> 5); - int4 *userptr[RANKS]; - int4 *userptrmyrank; -#pragma unroll - for (int i = 0; i < RANKS; i++) - userptr[i] = reinterpret_cast( - commbuff[((i + myrank + warp) & (RANKS - 1)) + handleridx + firstrank]); - userptrmyrank = reinterpret_cast(commbuff[myrank + handleridx + firstrank]); - __syncthreads(); - - int blocklineoffset = 0; - - while (blocklineoffset < numlines) { - const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); - const int blocklines = remainder / RANKS; - const int blockstart = lineoffset + blocklineoffset + blocklines * myrank; - - for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < blocklines; - line += 
REDUCETHREADS * gridDim.x) { - int4 val[RANKS]; - -#pragma unroll - for (int i = 0; i < RANKS; i++) { - val[i] = userptr[i][blockstart + line]; - } - - int4 sum = val[0]; - half *s = reinterpret_cast(&sum); - -#pragma unroll - for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); -#pragma unroll - for (int j = 0; j < sizeof(int4) / sizeof(half); j++) - s[j] += x[j]; - } - - userptrmyrank[blockstart + line] = sum; - } // single block loop - - asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); - - blocklineoffset += peerblocklines * RANKS; - } // block loop NVLINK-REDUCESCATTER - const int nwarps = (REDUCETHREADS >> 5) / (RANKS - 1); - const int myblockDim = nwarps << 5; - const int mywarp = ((threadIdx.x - 32) >> 5) / (RANKS - 1); - const int maxthreadIdx = myblockDim * (RANKS - 1) + 32; - const int mydest = (myrank + 1 + ((threadIdx.x - 32) >> 5) % (RANKS - 1)) & (RANKS - 1); - const int mythreadIdx = (mywarp << 5) + (threadIdx.x & 31); - volatile int *flag = (volatile int *)&((reinterpret_cast( - commbuff[myrank + firstrank]))[flagoffset + mydest + 32 + firstrank]); - - int4 *userptrmydest = userptr[((RANKS << 10) + mydest - myrank - warp) & (RANKS - 1)]; - - blocklineoffset = 0; - int gathercounter = basecounter + 1; - while (blocklineoffset < numlines) { - const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); - const int blocklines = remainder / RANKS; - const int blockstart = lineoffset + blocklineoffset; - -#define UNROLL 6 - int4 *myptr = &userptrmyrank[blockstart + blocklines * mydest]; - int4 *peerptr = &userptrmydest[blockstart + blocklines * mydest]; - - if (threadIdx.x < maxthreadIdx) { - const int start_elem = mythreadIdx + myblockDim * blockIdx.x; - const int end_elem = max(start_elem, blocklines); - const int aligned_elem = ((end_elem - start_elem) / (myblockDim * gridDim.x * UNROLL)) * - (myblockDim * gridDim.x * UNROLL); - const int end_aligned = start_elem + aligned_elem; - - if (mythreadIdx == 0) { - while (CHECK_IDS(*flag, gathercounter)) { - } - gathercounter++; - } - - asm volatile("bar.sync %0, %1;" ::"r"(1 + mydest), "r"(myblockDim)); - - for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) { - int4 val[UNROLL]; -#pragma unroll - for (int i = 0; i < UNROLL; i++) - val[i] = peerptr[line + i * myblockDim * gridDim.x]; -#pragma unroll - for (int i = 0; i < UNROLL; i++) - myptr[line + i * myblockDim * gridDim.x] = val[i]; - } - for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x) - myptr[line] = peerptr[line]; - } - blocklineoffset += peerblocklines * RANKS; - } // block loop for NVLINK-ALLGATHER - } // worker warps else block -} // fp16 inplace reduce kernel with SHARP / in blocks - -// threadfence and SMs sync to SM0 -#define SMBAR(offset, block) \ - asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x)); \ - if (threadIdx.x == 0) { \ - __threadfence_system(); \ - if (blockIdx.x) \ - gpuflag[offset + blockIdx.x] = block + basecounter + 1; \ - } else if (blockIdx.x == 0) { \ - int expecting = (basecounter + block + 1); \ - if (threadIdx.x < gridDim.x) \ - while (((volatile int *)gpuflag)[offset + threadIdx.x] < expecting) { \ - } \ - } \ - if (blockIdx.x == 0) \ - asm volatile("bar.sync 15, %0;" ::"r"(32)); - -template -__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_blocked2( - const int op, const int maxcredit, const int headstart, const int myibrank, const int ibranks, - const int commbufoffset, const int flagoffset, const 
int firstrank, const int myrank, - const int gpustep, const int lineoffset, const int numlines, void **commbuff, - const int handleridx, const int peerblocklines, int *hostflags, int *gpuflag, - const int numblocks) { - const int basecounter = gpuflag[NVTE_GF_STATE + op]; - if (threadIdx.x < 32) { - int *flagptr; - volatile int *localflag = (volatile int *)&( - ((int *)commbuff[gpustep * myrank + firstrank])[flagoffset]); // NOLINT(*) - // initial intranode barrier - once - if (threadIdx.x < RANKS) { - if (!blockIdx.x) { - flagptr = reinterpret_cast(commbuff[gpustep * threadIdx.x + firstrank]); - flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter; - } - volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank]; - while (CHECK_IDS(*flag, basecounter)) { - } - } - __syncthreads(); - - for (int nblock = 0; nblock < numblocks + headstart; nblock++) { - if (nblock < numblocks) { - // RS happens here - SMBAR(op * 2 * NVTE_MAX_SMS, nblock); - if (!blockIdx.x && !threadIdx.x) - hostflags[NVTE_HF_NVRSDONE + (op & 1)] = nblock + basecounter + 1; - } - - if (nblock >= headstart) { - for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32) - if (ibflag != myibrank) - while (localflag[NVTE_REG0_IBRS + ibflag] < basecounter + nblock - headstart + 1) { - } - asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x)); - // REDUCE happens here - SMBAR(op * 2 * NVTE_MAX_SMS + NVTE_MAX_SMS, nblock - headstart); - if (!blockIdx.x && !threadIdx.x) - hostflags[NVTE_HF_NVREDUCEDONE + (op & 1)] = nblock + basecounter + 1 - headstart; - } - } - // final part doing NVAG based on responses from NIC-RMW:IBAG - - if (blockIdx.x == 0) { - for (int nblock = 0; nblock < numblocks; nblock++) { - const int expected = basecounter + nblock + 1; - for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32) - if (ibflag != myibrank) - while (localflag[NVTE_REG0_IBAG + ibflag] < expected) { - } - asm volatile("bar.sync 15, %0;" ::"r"(32)); - if (threadIdx.x < RANKS) - flagptr[flagoffset + gpustep * myrank + NVTE_MAX_NVLINK + firstrank] = expected; - } - } - - if (blockIdx.x == 0 && threadIdx.x == 0) - gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks; - } else { // sync warp - // reducethreads - const int warp = blockIdx.x + (threadIdx.x >> 5); - int4 *userptr[RANKS]; - int4 *userptrmyrank; -#pragma unroll - for (int i = 0; i < RANKS; i++) - userptr[i] = reinterpret_cast( - commbuff[((i + myrank + warp) & (RANKS - 1)) * gpustep + handleridx + firstrank]); - userptrmyrank = reinterpret_cast(commbuff[gpustep * myrank + handleridx + firstrank]); - int4 *internalbuf = reinterpret_cast(commbuff[myrank * gpustep + firstrank] + - commbufoffset * sizeof(int)); - __syncthreads(); - - int blocklineoffset = 0, rblocklineoffset = 0; - - for (int nblock = 0; nblock < numblocks + headstart; nblock++) { - // NVRS part(only first numblocks steps) - if (blocklineoffset < numlines) { - const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); - const int blocklines = remainder / RANKS; - const int blockstart = lineoffset + blocklineoffset + blocklines * myrank; - if (RANKS > 1) { - for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < blocklines; - line += REDUCETHREADS * gridDim.x) { - int4 val[RANKS]; - -#pragma unroll - for (int i = 0; i < RANKS; i++) { - val[i] = userptr[i][blockstart + line]; - } - - int4 sum = val[0]; - half *s = reinterpret_cast(&sum); - -#pragma unroll - for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); -#pragma unroll - for (int j = 
0; j < sizeof(int4) / sizeof(half); j++) - s[j] += x[j]; - } - - userptrmyrank[blockstart + line] = sum; - } // single block loop - } - - asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); - blocklineoffset += peerblocklines * RANKS; - } - if (nblock >= headstart) { -#define UNROLLRS 2 - const int remainder = min(numlines - rblocklineoffset, peerblocklines * RANKS); - const int blocklines = remainder / RANKS; - rblocklineoffset += peerblocklines * RANKS; - const int ibblocklines = blocklines / ibranks; - int4 *tempbufptr = &internalbuf[((nblock - headstart) % maxcredit) * peerblocklines]; - const int tempstart = lineoffset + (nblock - headstart) * peerblocklines * RANKS + - myrank * blocklines + ibblocklines * myibrank; - - asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); - - for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < ibblocklines; - line += REDUCETHREADS * gridDim.x) { - int4 val[UNROLLRS]; - -#pragma unroll - for (int i = 0; i < UNROLLRS; i++) - val[i] = i == myibrank ? userptrmyrank[tempstart + line] - : tempbufptr[i * ibblocklines + line]; - - int4 sum = val[0]; - half *s = reinterpret_cast(&sum); - - for (int i = 0; i < ibranks - UNROLLRS; i++) { - val[i % UNROLLRS] = i == myibrank ? userptrmyrank[tempstart + line] - : tempbufptr[i * ibblocklines + line]; - half *x = reinterpret_cast(&val[(i + 1) % UNROLLRS]); -#pragma unroll - for (int j = 0; j < 16; j++) - s[j] += x[j]; - } -#pragma unroll - for (int i = 1; i < UNROLLRS; i++) { - half *x = reinterpret_cast(&val[i]); -#pragma unroll - for (int j = 0; j < 16; j++) - s[j] += x[j]; - } - userptrmyrank[tempstart + line] = sum; - } - - asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); - } - } // nblock loop NVLINK-REDUCESCATTER + IBREDUCE LOCAL COMPUTE - - if (RANKS != 1) { - const int nwarps = (REDUCETHREADS >> 5) / (RANKS - 1); - const int myblockDim = nwarps << 5; - const int mywarp = ((threadIdx.x - 32) >> 5) / (RANKS - 1); - const int maxthreadIdx = myblockDim * (RANKS - 1) + 32; - const int mydest = (myrank + 1 + ((threadIdx.x - 32) >> 5) % (RANKS - 1)) & (RANKS - 1); - const int mythreadIdx = (mywarp << 5) + (threadIdx.x & 31); - volatile int *flag = (volatile int *)&((reinterpret_cast( - commbuff[gpustep * myrank + firstrank]))[flagoffset + gpustep * mydest + NVTE_MAX_NVLINK + - firstrank]); - - int4 *userptrmydest = userptr[((RANKS << 10) + mydest - myrank - warp) & (RANKS - 1)]; - - blocklineoffset = 0; - int gathercounter = basecounter + 1; - while (blocklineoffset < numlines) { - const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); - const int blocklines = remainder / RANKS; - const int blockstart = lineoffset + blocklineoffset; - -#define UNROLL 6 - int4 *myptr = &userptrmyrank[blockstart + blocklines * mydest]; - int4 *peerptr = &userptrmydest[blockstart + blocklines * mydest]; - - if (threadIdx.x < maxthreadIdx) { - const int start_elem = mythreadIdx + myblockDim * blockIdx.x; - const int end_elem = max(start_elem, blocklines); - const int aligned_elem = ((end_elem - start_elem) / (myblockDim * gridDim.x * UNROLL)) * - (myblockDim * gridDim.x * UNROLL); - const int end_aligned = start_elem + aligned_elem; - - if (mythreadIdx == 0) { - while (CHECK_IDS(*flag, gathercounter)) { - } - gathercounter++; - } - - asm volatile("bar.sync %0, %1;" ::"r"(1 + mydest), "r"(myblockDim)); - - for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) { - int4 val[UNROLL]; -#pragma unroll - for (int i = 0; i < UNROLL; i++) - 
val[i] = peerptr[line + i * myblockDim * gridDim.x]; -#pragma unroll - for (int i = 0; i < UNROLL; i++) - myptr[line + i * myblockDim * gridDim.x] = val[i]; - } - for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x) - myptr[line] = peerptr[line]; - } - blocklineoffset += peerblocklines * RANKS; - } // block loop for NVLINK-ALLGATHER - } // RANKS!=1 - } // worker warps else block -} // fp16 inplace reduce kernel with SHARP / in blocks - -template -__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_blocked2_rs( - const int op, const int maxcredit, const int headstart, const int myibrank, const int ibranks, - const int commbufoffset, const int flagoffset, const int firstrank, const int myrank, - const int gpustep, const int lineoffset, const int numlines, void **commbuff, - const int handleridx, const int peerblocklines, int *hostflags, int *gpuflag, - const int numblocks) { - const int basecounter = gpuflag[NVTE_GF_STATE + op]; - if (threadIdx.x < 32) { - int *flagptr; - volatile int *localflag = (volatile int *)&( - ((int *)commbuff[gpustep * myrank + firstrank])[flagoffset]); // NOLINT(*) - // initial intranode barrier - once - if (threadIdx.x < RANKS) { - if (!blockIdx.x) { - flagptr = reinterpret_cast(commbuff[gpustep * threadIdx.x + firstrank]); - flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter; - } - volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank]; - while (CHECK_IDS(*flag, basecounter)) { - } - } - __syncthreads(); - - for (int nblock = 0; nblock < numblocks + headstart; nblock++) { - if (nblock < numblocks) { - // RS happens here - SMBAR(op * 2 * NVTE_MAX_SMS, nblock); - if (!blockIdx.x && !threadIdx.x) - hostflags[NVTE_HF_NVRSDONE + (op & 1)] = nblock + basecounter + 1; - } - - if (nblock >= headstart) { - for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32) - if (ibflag != myibrank) - while (localflag[NVTE_REG0_IBRS + ibflag] < basecounter + nblock - headstart + 1) { - } - asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x)); - // REDUCE happens here - SMBAR(op * 2 * NVTE_MAX_SMS + NVTE_MAX_SMS, nblock - headstart); - } - } - } else { // sync warp - // reducethreads - const int warp = blockIdx.x + (threadIdx.x >> 5); - int4 *userptr[RANKS]; - int4 *userptrmyrank; -#pragma unroll - for (int i = 0; i < RANKS; i++) - userptr[i] = reinterpret_cast( - commbuff[((i + myrank + warp) & (RANKS - 1)) * gpustep + handleridx + firstrank]); - userptrmyrank = reinterpret_cast(commbuff[gpustep * myrank + handleridx + firstrank]); - int4 *internalbuf = reinterpret_cast(commbuff[myrank * gpustep + firstrank] + - commbufoffset * sizeof(int)); - __syncthreads(); - - int blocklineoffset = 0, rblocklineoffset = 0; - - for (int nblock = 0; nblock < numblocks + headstart; nblock++) { - // NVRS part(only first numblocks steps) - if (blocklineoffset < numlines) { - const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); - const int blocklines = remainder / RANKS; - const int blockstart = lineoffset + blocklineoffset + blocklines * myrank; - if (RANKS > 1) { - for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < blocklines; - line += REDUCETHREADS * gridDim.x) { - int4 val[RANKS]; - -#pragma unroll - for (int i = 0; i < RANKS; i++) { - val[i] = userptr[i][blockstart + line]; - } - - int4 sum = val[0]; - half *s = reinterpret_cast(&sum); - -#pragma unroll - for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); -#pragma unroll - for (int j = 0; j < 
sizeof(int4) / sizeof(half); j++) - s[j] += x[j]; - } - - userptrmyrank[blockstart + line] = sum; - } // single block loop - } - - asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); - blocklineoffset += peerblocklines * RANKS; - } - if (nblock >= headstart) { -#define UNROLLRS 2 - const int remainder = min(numlines - rblocklineoffset, peerblocklines * RANKS); - const int blocklines = remainder / RANKS; - rblocklineoffset += peerblocklines * RANKS; - const int ibblocklines = blocklines / ibranks; - int4 *tempbufptr = &internalbuf[((nblock - headstart) % maxcredit) * peerblocklines]; - const int tempstart = lineoffset + (nblock - headstart) * peerblocklines * RANKS + - myrank * blocklines + ibblocklines * myibrank; - - asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); - - for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < ibblocklines; - line += REDUCETHREADS * gridDim.x) { - int4 val[UNROLLRS]; - -#pragma unroll - for (int i = 0; i < UNROLLRS; i++) - val[i] = i == myibrank ? userptrmyrank[tempstart + line] - : tempbufptr[i * ibblocklines + line]; - - int4 sum = val[0]; - half *s = reinterpret_cast(&sum); - - for (int i = 0; i < ibranks - UNROLLRS; i++) { - val[i % UNROLLRS] = i == myibrank ? userptrmyrank[tempstart + line] - : tempbufptr[i * ibblocklines + line]; - half *x = reinterpret_cast(&val[(i + 1) % UNROLLRS]); -#pragma unroll - for (int j = 0; j < 16; j++) - s[j] += x[j]; - } -#pragma unroll - for (int i = 1; i < UNROLLRS; i++) { - half *x = reinterpret_cast(&val[i]); -#pragma unroll - for (int j = 0; j < 16; j++) - s[j] += x[j]; - } - userptrmyrank[tempstart + line] = sum; - } - - asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32)); - } - } // nblock loop NVLINK-REDUCESCATTER + IBREDUCE LOCAL COMPUTE - } // worker warps else block -} // fp16 inplace reduce kernel with SHARP / in blocks - -template -__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_blocked2_ag( - const int op, const int maxcredit, const int headstart, const int myibrank, const int ibranks, - const int commbufoffset, const int flagoffset, const int firstrank, const int myrank, - const int gpustep, const int lineoffset, const int numlines, void **commbuff, - const int handleridx, const int peerblocklines, int *hostflags, int *gpuflag, - const int numblocks) { - const int basecounter = gpuflag[NVTE_GF_STATE + op]; - if (threadIdx.x < 32) { - int *flagptr; - volatile int *localflag = (volatile int *)&( - ((int *)commbuff[gpustep * myrank + firstrank])[flagoffset]); // NOLINT(*) - if (threadIdx.x < RANKS) { - if (!blockIdx.x) { - flagptr = reinterpret_cast(commbuff[gpustep * threadIdx.x + firstrank]); - } - } - __syncthreads(); - if (!blockIdx.x && !threadIdx.x) - hostflags[NVTE_HF_NVREDUCEDONE + (op & 1)] = numblocks + basecounter; - // tell CPU proxy all blocks are done and ready for NVAG - - // final part doing NVAG based on responses from NIC-RMW:IBAG - - if (blockIdx.x == 0) { - for (int nblock = 0; nblock < numblocks; nblock++) { - const int expected = basecounter + nblock + 1; - for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32) - if (ibflag != myibrank) - while (localflag[NVTE_REG0_IBAG + ibflag] < expected) { - } - asm volatile("bar.sync 15, %0;" ::"r"(32)); - if (threadIdx.x < RANKS) - flagptr[flagoffset + gpustep * myrank + NVTE_MAX_NVLINK + firstrank] = expected; - } - } - - if (blockIdx.x == 0 && threadIdx.x == 0) - gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks; - } else { // sync warp - // reducethreads - const 
int warp = blockIdx.x + (threadIdx.x >> 5); - int4 *userptr[RANKS]; - int4 *userptrmyrank; -#pragma unroll - for (int i = 0; i < RANKS; i++) - userptr[i] = reinterpret_cast( - commbuff[((i + myrank + warp) & (RANKS - 1)) * gpustep + handleridx + firstrank]); - userptrmyrank = reinterpret_cast(commbuff[gpustep * myrank + handleridx + firstrank]); - __syncthreads(); - - int blocklineoffset = 0, rblocklineoffset = 0; - - if (RANKS != 1) { - const int nwarps = (REDUCETHREADS >> 5) / (RANKS - 1); - const int myblockDim = nwarps << 5; - const int mywarp = ((threadIdx.x - 32) >> 5) / (RANKS - 1); - const int maxthreadIdx = myblockDim * (RANKS - 1) + 32; - const int mydest = (myrank + 1 + ((threadIdx.x - 32) >> 5) % (RANKS - 1)) & (RANKS - 1); - const int mythreadIdx = (mywarp << 5) + (threadIdx.x & 31); - volatile int *flag = (volatile int *)&((reinterpret_cast( - commbuff[gpustep * myrank + firstrank]))[flagoffset + gpustep * mydest + NVTE_MAX_NVLINK + - firstrank]); - - int4 *userptrmydest = userptr[((RANKS << 10) + mydest - myrank - warp) & (RANKS - 1)]; - - blocklineoffset = 0; - int gathercounter = basecounter + 1; - while (blocklineoffset < numlines) { - const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS); - const int blocklines = remainder / RANKS; - const int blockstart = lineoffset + blocklineoffset; - -#define UNROLL 6 - int4 *myptr = &userptrmyrank[blockstart + blocklines * mydest]; - int4 *peerptr = &userptrmydest[blockstart + blocklines * mydest]; - - if (threadIdx.x < maxthreadIdx) { - const int start_elem = mythreadIdx + myblockDim * blockIdx.x; - const int end_elem = max(start_elem, blocklines); - const int aligned_elem = ((end_elem - start_elem) / (myblockDim * gridDim.x * UNROLL)) * - (myblockDim * gridDim.x * UNROLL); - const int end_aligned = start_elem + aligned_elem; - - if (mythreadIdx == 0) { - while (CHECK_IDS(*flag, gathercounter)) { - } - gathercounter++; - } - - asm volatile("bar.sync %0, %1;" ::"r"(1 + mydest), "r"(myblockDim)); - - for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) { - int4 val[UNROLL]; -#pragma unroll - for (int i = 0; i < UNROLL; i++) - val[i] = peerptr[line + i * myblockDim * gridDim.x]; -#pragma unroll - for (int i = 0; i < UNROLL; i++) - myptr[line + i * myblockDim * gridDim.x] = val[i]; - } - for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x) - myptr[line] = peerptr[line]; - } - blocklineoffset += peerblocklines * RANKS; - } // block loop for NVLINK-ALLGATHER - } // RANKS!=1 - } // worker warps else block -} // fp16 inplace reduce kernel with SHARP / in blocks - -__global__ void userbuffers_fp16_sum_inplace_gpu_null(const int op, int *hostflags, int *gpuflag, - int numblocks) { - const int basecounter = gpuflag[NVTE_GF_STATE + op] + numblocks; - hostflags[0] = basecounter; - gpuflag[NVTE_GF_STATE + op] = basecounter; - while (((volatile int *)gpuflag)[NVTE_GF_IBSHARPDONE] < basecounter) { - } -} - -#define callranks_block(x) \ - if (comm->ar_nvsize == x) \ - userbuffers_fp16_sum_inplace_gpu_rr_blocked<<>>( \ - userbuffers_allreduceop_sharp, NVTE_REG0_OFFSET(comm), comm->ar_firstgpu, comm->ar_nvrank, \ - offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ - handler * comm->nvsize, blocksize / sizeof(int4) / comm->ar_nvsize, \ - reinterpret_cast(comm->hostflags), comm->flags, \ - (elements * 2 + blocksize - 1) / blocksize); - -#define callranks2_block(x) \ - if (ar_nvsize == x) { \ - int numblocks = (elements * 2 + blocksize - 1) / blocksize; \ - 
int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \ - if (headstart > maxcredit) \ - headstart = maxcredit; \ - if (x == 1) \ - headstart = maxcredit; \ - if (headstart > numblocks) \ - headstart = numblocks; \ - if (headstart == 0) \ - headstart = 1; \ - userbuffers_fp16_sum_inplace_gpu_rr_blocked2<<>>( \ - op, maxcredit, headstart, my_node, num_nodes, \ - NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \ - (op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \ - NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \ - offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ - handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \ - reinterpret_cast(comm->hostflags), comm->flags, numblocks); \ - } - -#define callranks2_block_rs(x) \ - if (ar_nvsize == x) { \ - int numblocks = (elements * 2 + blocksize - 1) / blocksize; \ - int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \ - if (headstart > maxcredit) \ - headstart = maxcredit; \ - if (x == 1) \ - headstart = maxcredit; \ - if (headstart > numblocks) \ - headstart = numblocks; \ - if (headstart == 0) \ - headstart = 1; \ - userbuffers_fp16_sum_inplace_gpu_rr_blocked2_rs<<>>( \ - op, maxcredit, headstart, my_node, num_nodes, \ - NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \ - (op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \ - NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \ - offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ - handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \ - reinterpret_cast(comm->hostflags), comm->flags, numblocks); \ - } - -#define callranks2_block_ag(x) \ - if (ar_nvsize == x) { \ - int numblocks = (elements * 2 + blocksize - 1) / blocksize; \ - int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \ - if (headstart > maxcredit) \ - headstart = maxcredit; \ - if (x == 1) \ - headstart = maxcredit; \ - if (headstart > numblocks) \ - headstart = numblocks; \ - if (headstart == 0) \ - headstart = 1; \ - userbuffers_fp16_sum_inplace_gpu_rr_blocked2_ag<<>>( \ - op, maxcredit, headstart, my_node, num_nodes, \ - NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \ - (op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \ - NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \ - offset / 8, elements / 8, reinterpret_cast(comm->gpu_ptrs), \ - handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \ - reinterpret_cast(comm->hostflags), comm->flags, numblocks); \ - } - -#define callranks(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg6 = offset / 8, \ - arg7 = elements / 8; \ - void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ - int arg9 = handler * comm->nvsize; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, \ - reinterpret_cast(comm->use_rr_kernel ? 
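// Illustrative sketch, not part of the patch: the headstart/credit clamping that the
// callranks2_block / callranks2_block_rs / callranks2_block_ag macros above repeat
// inline, written once as a helper. The helper name is hypothetical; maxcredit is the
// number of blocks allowed in flight before the inter-node reduce must consume them.
static inline int compute_headstart(int elements, int blocksize, int maxcredit, int ar_nvsize) {
  int numblocks = (elements * 2 + blocksize - 1) / blocksize;  // ceil-div of the fp16 payload (bytes) into blocks
  int headstart = numblocks - 1;
  if (headstart > maxcredit) headstart = maxcredit;  // never run ahead of the credit window
  if (ar_nvsize == 1) headstart = maxcredit;         // single NVLink rank: no local RS stage to overlap
  if (headstart > numblocks) headstart = numblocks;  // cannot lead by more blocks than exist
  if (headstart == 0) headstart = 1;                 // keep at least one block of lead
  return headstart;
}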
userbuffers_fp16_sum_inplace_gpu_rr \ - : userbuffers_fp16_sum_inplace_gpu_rw), \ - kernelArgs)); \ - } - -#define callranksMC(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg6 = offset / 8, \ - arg7 = elements / 8; \ - void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ - int arg9 = handler * comm->nvsize; \ - void *arg10 = comm->mc_ptr[handler]; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc), kernelArgs)); \ - } - #define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ cudaLaunchAttribute attribute_ub[2]; \ @@ -2279,60 +1441,6 @@ __global__ void userbuffers_fp16_sum_inplace_gpu_null(const int op, int *hostfla cfg.attrs = attribute_ub; \ cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1; -int allreduce_userbuff_inplace_gpu(const int handler, const int offset, const int elements, - const int blocksize, communicator *comm, cudaStream_t stream) { - // schedule GPU kernel only - // CPU/SHARP part is responsibility of caller - const int ar_step = comm->ar2_nvsize; - const int op = userbuffers_allreduceop_nonsharp; - const int ar_nvsize = comm->nvsize; - const int ar_firstgpu = comm->ar_firstgpu; - const int ar_nvrank = comm->ar_nvrank; - if (elements < 8) - return 0; - int sms = sms = comm->sms; - int warps = comm->threads / 32; - if (warps < comm->ar_nvsize) - warps = comm->ar_nvsize; - - if (comm->launch_mode & NVTE_LAUNCH_GPU) { - if (comm->ar_nvsize == 1) - userbuffers_fp16_sum_inplace_gpu_null<<<1, 1, 0, stream>>>( - userbuffers_allreduceop_sharp, reinterpret_cast(comm->hostflags), comm->flags, - (elements * 2 + blocksize - 1) / blocksize); - callranks_block(2) callranks_block(4) callranks_block(8) - } - return sms; -} - -int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, - const int elements, const int blocksize, communicator *comm, - cudaStream_t stream, int op) { - // schedule GPU kernel only - // CPU/SHARP part is responsibility of caller - const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes; - const int my_node = op == userbuffers_allreduceop_nonsharp ? comm->my_node : comm->my2_node; - const int ar_firstgpu = - op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; - const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; - const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; - const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - - if (elements < 8) - return 0; - int sms = ar_nvsize == 1 ? 
2 : comm->sms; - int warps = comm->threads / 32; - if (warps < ar_nvsize) - warps = ar_nvsize; - if (num_nodes > 1) { - callranks2_block(1) callranks2_block(2) callranks2_block(4) callranks2_block(8) - } else { - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - callranks(2) callranks(4) callranks(8) - } - return sms; -} - #define callranks_ag(x) \ if (ar_nvsize == x) { \ int arg1 = op - NVTE_MAX_OPS, \ @@ -2343,11 +1451,12 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons arg6 = offset / 8 + (comm->use_rr_kernel ? 0 : arg4 * arg7); \ void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ int arg9 = handler * comm->nvsize; \ + uint64_t arg10 = comm->ub_timeout; \ void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9)}; \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ CUDACHECK(cudaLaunchKernelExC( \ &cfg, \ reinterpret_cast(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr_ag \ @@ -2366,11 +1475,13 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ int arg9 = handler * comm->nvsize; \ uint4 *arg10 = reinterpret_cast(comm->mc_ptr[handler]); \ + uint64_t arg11 = comm->ub_timeout; \ void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11)}; \ CUDACHECK(cudaLaunchKernelExC( \ &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_ag), kernelArgs)); \ } @@ -2385,11 +1496,12 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons arg6 = offset / 8 + arg4 * arg7; \ void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ int arg9 = handler * comm->nvsize; \ + uint64_t arg10 = comm->ub_timeout; \ void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9)}; \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ CUDACHECK(cudaLaunchKernelExC( \ &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs), kernelArgs)); \ } @@ -2405,11 +1517,13 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ int arg9 = handler * comm->nvsize; \ void *arg10 = comm->mc_ptr[handler]; \ + uint64_t arg11 = comm->ub_timeout; \ void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11)}; \ CUDACHECK(cudaLaunchKernelExC( \ &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_rs), kernelArgs)); \ } @@ -2425,12 +1539,14 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ int arg11 = handler * 
comm->nvsize; \ void *arg12 = output; \ + uint64_t arg13 = comm->ub_timeout; \ void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ - reinterpret_cast(&arg11), reinterpret_cast(&arg12)}; \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13)}; \ CUDACHECK(cudaLaunchKernelExC( \ &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop), \ kernelArgs)); \ @@ -2448,13 +1564,14 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons int arg11 = handler * comm->nvsize; \ void *arg12 = output; \ float *arg13 = scale; \ + uint64_t arg14 = comm->ub_timeout; \ void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ - reinterpret_cast(&arg13)}; \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ CUDACHECK(cudaLaunchKernelExC( \ &cfg, \ reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8), \ @@ -2473,13 +1590,14 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons int arg11 = handler * comm->nvsize; \ void *arg12 = output; \ void *arg13 = comm->mc_ptr[handler]; \ + uint64_t arg14 = comm->ub_timeout; \ void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ - reinterpret_cast(&arg13)}; \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ CUDACHECK(cudaLaunchKernelExC( \ &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_mc_rs_oop), \ kernelArgs)); \ @@ -2500,6 +1618,7 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons float *arg14 = scale; \ void *arg15 = counters; \ int arg16 = numchunks, arg17 = atomicindex; \ + uint64_t arg18 = comm->ub_timeout; \ void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ @@ -2508,7 +1627,7 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ reinterpret_cast(&arg15), reinterpret_cast(&arg16), \ - reinterpret_cast(&arg17)}; \ + reinterpret_cast(&arg17), reinterpret_cast(&arg18)}; \ CUDACHECK(cudaLaunchKernelExC( \ &cfg, \ reinterpret_cast( \ @@ -2527,46 +1646,18 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ int arg11 = handler * comm->nvsize; \ void *arg12 = output; \ - void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ - reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ - reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ - reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ - reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ - reinterpret_cast(&arg11), 
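// Illustrative sketch, not part of the patch: the argument-marshalling pattern shared by
// the callranks_* macros in these hunks. Each kernel parameter is materialized as a local
// and its address appended to a void* array in signature order, so threading the new
// comm->ub_timeout through a kernel means one extra uint64_t local plus one extra trailing
// slot. The kernel and wrapper names below are placeholders, not symbols from this file.
__global__ void example_kernel(int op, int offset, uint64_t ub_timeout) {}

static void launch_example(const cudaLaunchConfig_t *cfg, int op, int offset, uint64_t ub_timeout) {
  void *kernelArgs[] = {reinterpret_cast<void *>(&op), reinterpret_cast<void *>(&offset),
                        reinterpret_cast<void *>(&ub_timeout)};  // order must match the kernel signature
  cudaLaunchKernelExC(cfg, reinterpret_cast<void *>(example_kernel), kernelArgs);
}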
reinterpret_cast(&arg12)}; \ - CUDACHECK(cudaLaunchKernelExC( \ - &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride), \ - kernelArgs)); \ - } - -#if 0 -#define callranks_rs_oop_stride_atomic_fp8(x) \ - if (ar_nvsize == x) { \ - int arg1 = op - NVTE_MAX_OPS, \ - arg2 = NVTE_REG0_OFFSET(comm) - \ - (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ - NVTE_MAX_OPS, \ - arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \ - arg6 = offset / 16, arg8 = rowelements / 8, arg9 = strideelements / 8, arg10 = numchunks; \ - void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ - int arg12 = handler * comm->nvsize; \ - void *arg13 = output; \ - void *arg14 = counters; \ - float *arg15 = scale; \ + uint64_t arg13 = comm->ub_timeout; \ void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ - reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ - reinterpret_cast(&arg15)}; \ + reinterpret_cast(&arg13)}; \ CUDACHECK(cudaLaunchKernelExC( \ - &cfg, \ - reinterpret_cast( \ - userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic_fp8), \ + &cfg, reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride), \ kernelArgs)); \ } -#endif #define callranks_rs_oop_stride_atomic(x) \ if (ar_nvsize == x) { \ @@ -2580,13 +1671,15 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons int arg12 = handler * comm->nvsize; \ void *arg13 = output; \ void *arg14 = counters; \ + uint64_t arg15 = comm->ub_timeout; \ void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ - reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ + reinterpret_cast(&arg15)}; \ CUDACHECK(cudaLaunchKernelExC( \ &cfg, \ reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic), \ @@ -2605,13 +1698,15 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons int arg12 = handler * comm->nvsize; \ void *arg13 = output; \ void *arg14 = counters; \ + uint64_t arg15 = comm->ub_timeout; \ void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ - reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ + reinterpret_cast(&arg15)}; \ CUDACHECK( \ cudaLaunchKernelExC(&cfg, \ reinterpret_cast( \ @@ -2619,47 +1714,12 @@ int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons kernelArgs)); \ } -int reducescatter2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, - const int elements, const int blocksize, communicator *comm, - cudaStream_t stream, int op) { - // schedule GPU kernel only - // CPU/SHARP part is 
responsibility of caller - - const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes; - const int my_node = op == userbuffers_allreduceop_nonsharp ? comm->my_node : comm->my2_node; - const int ar_firstgpu = - op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; - const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; - const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; - const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - - if (elements < 8) - return 0; - int sms = ar_nvsize == 1 ? 2 : comm->sms; - int warps = comm->threads / 32; - if (warps < ar_nvsize) - warps = ar_nvsize; - - if (num_nodes > 1) { - callranks2_block_rs(1) callranks2_block_rs(2) callranks2_block_rs(4) callranks2_block_rs(8) - } else { - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) - } else { - callranks_rs(2) callranks_rs(4) callranks_rs(8) - } - } - return sms; -} - void reducescatter2_userbuff_strided(void *output, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, communicator *comm, cudaStream_t stream) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; - const int blocksize = elements * 2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; @@ -2683,7 +1743,6 @@ void reducescatter2_userbuff_strided_atomic(void *output, const int handler, con cudaStream_t stream) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; - const int blocksize = elements * 2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; @@ -2702,36 +1761,6 @@ void reducescatter2_userbuff_strided_atomic(void *output, const int handler, con callranks_rs_oop_stride_atomic(8) } -#if 0 - template - void reducescatter2_userbuff_strided_atomic_fp8( - void* output, float *scale, const int handler, const int offset, const int rowelements, - const int colelements, const int strideelements, const int numchunks, void *counters, - communicator* comm, cudaStream_t stream) { - const int elements = rowelements*colelements; - const int op = userbuffers_allreduceop_nonsharp2; - const int blocksize = elements; - const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? - comm->ar_firstgpu : comm->ar2_firstgpu; - const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? - 1 : comm->ar2_nvsize; - const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? - comm->ar_nvsize : comm->ar2_nvsize; - const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? - comm->ar_nvrank : comm->ar2_nvrank; - - assert(comm->sm_arch >= 9); - if (elements < 128) return; - int sms = ar_nvsize == 1 ? 
2 : comm->sms; - int warps = comm->threads/32; - if (warps < ar_nvsize) warps = ar_nvsize; - - SETUP_LAUNCH_CONFIG(sms, warps*32, stream); - callranks_rs_oop_stride_atomic_fp8(2) - callranks_rs_oop_stride_atomic_fp8(4) - callranks_rs_oop_stride_atomic_fp8(8) - } -#endif template void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, @@ -2742,7 +1771,6 @@ void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, c communicator *comm, cudaStream_t stream) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; - const int blocksize = elements; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; @@ -2771,6 +1799,7 @@ void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, cons output, scale, handler, offset, rowelements, colelements, strideelements_out, strideelements_in, 1, numchunks, counters /*nullptr*/, comm, stream); } + template void reducescatter2_userbuff_strided_multiatomic_fp8( void *output, float *scale, const int handler, const int offset, const int rowelements, @@ -2788,7 +1817,6 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler cudaStream_t stream) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; - const int blocksize = elements * 2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; @@ -2803,56 +1831,13 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler warps = ar_nvsize; SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - // if(comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { - // //callranks_rs_oopMC(2) - // //callranks_rs_oopMC(4) - // //callranks_rs_oopMC(8) - // } else { - // if(comm->memflags[handler] & NVTE_UB_MEM_UC_CONTIG) { - // //callranks_rs_oopUCPTR(2) - // //callranks_rs_oopUCPTR(4) - // //callranks_rs_oopUCPTR(8) - // } else { callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4) callranks_rs_oop_stride_multiatomic(8) - // } - //} -} - -int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset, - const int elements, const int blocksize, communicator *comm, - cudaStream_t stream, int op) { - // schedule GPU kernel only - // CPU/SHARP part is responsibility of caller - - const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes; - const int my_node = op == userbuffers_allreduceop_nonsharp ? comm->my_node : comm->my2_node; - const int ar_firstgpu = - op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; - const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; - const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize; - const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank; - - if (elements < 8) - return 0; - int sms = ar_nvsize == 1 ? 
2 : comm->sms; - int warps = comm->threads / 32; - if (warps < ar_nvsize) - warps = ar_nvsize; - - if (num_nodes > 1) { - callranks2_block_ag(1) callranks2_block_ag(2) callranks2_block_ag(4) callranks2_block_ag(8) - } else { - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - callranks_ag(2) callranks_ag(4) callranks_ag(8) - } - return sms; } void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream) { const int op = userbuffers_allreduceop_nonsharp2; - const int blocksize = elements * 2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; @@ -2892,7 +1877,6 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream) { const int op = userbuffers_allreduceop_nonsharp2; - const int blocksize = elements * 2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; @@ -2919,7 +1903,6 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons cudaStream_t stream) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; - const int blocksize = elements * 2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize; @@ -2952,7 +1935,6 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const communicator *comm, cudaStream_t stream) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; - const int blocksize = elements; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 
1 : comm->ar2_nvsize; @@ -2980,92 +1962,36 @@ void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, - cudaStream_t stream = 0); + cudaStream_t stream); template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, - cudaStream_t stream = 0); -#if 0 -template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( - void* output, float *scale, const int handler, const int offset, - const int rowelements, const int colelements, const int strideelements, - const int numchunks, void *counters, communicator* comm, cudaStream_t stream = 0); -#endif + cudaStream_t stream); + template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements_out, const int strideelements_in, - const int numchunks, void *counters, communicator *comm, cudaStream_t stream = 0); + const int numchunks, void *counters, communicator *comm, cudaStream_t stream); + template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>( void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements_out, const int strideelements_in, - const int numchunks, void *counters, communicator *comm, cudaStream_t stream = 0); -__global__ void __launch_bounds__(MAX_THREADS) - kuserbuffers_pullsendrecv(int myrank, int peer, int *recv_id, int *send_flagptr, - int *recv_flagptr, int4 *srcptr, int4 *dstptr, const int lines) { - if (blockIdx.x == 0 && threadIdx.x == 0) { - atomicAdd_system(send_flagptr, 1); - } - -#define UNROLLCOPY 8 - const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; - const int end_elem = lines; - const int aligned_elem = (end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1)); - const int end_aligned = start_elem + aligned_elem; - - if (threadIdx.x == 0) { - const int signal_id = (*recv_id) + 1; - volatile int *flag = (volatile int *)recv_flagptr; - clock_t s = clock64(); - while (CHECK_IDS(*flag, signal_id)) { - if (clock64() - s > TIMEOUT) { - printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, - *flag); - break; - } - } - if (lines == 0) { - *recv_id = signal_id; - return; - } // otherwise need an extra kernel - } - __syncthreads(); - - if (end_elem <= start_elem) - return; - - for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) { - int4 val[UNROLLCOPY]; -#pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) - val[i] = srcptr[line + i * blockDim.x * gridDim.x]; -#pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) - dstptr[line + i * blockDim.x * gridDim.x] = val[i]; - } - for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x) - dstptr[line] = srcptr[line]; -} + const int numchunks, void *counters, communicator *comm, cudaStream_t stream); __global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) { atomicAdd_system(flagptr, 1); } __global__ void kuserbuffers_inc(int *id) { - const int signal_id = (*id) + 1; - *id = signal_id; -} - -__global__ void kuserbuffers_proxysend(int *id, int *hostflag) { - const int signal_id = (*id) + 1; - *hostflag = signal_id; - 
*id = signal_id; + atomicAdd(id, 1); } __global__ void kuserbuffers_dummy(void) {} __global__ void __launch_bounds__(MAX_THREADS) - kuserbuffers_pullrecv(int myrank, int peer, int *recv_id, int *flagptr, int4 *srcptr, - int4 *dstptr, const int lines) { + kuserbuffers_pullrecv(int myrank, int peer, int nvrank, int nvpeer, int *recv_id, int *flagptr, + int4 *srcptr, int4 *dstptr, const int lines, + uint64_t ub_timeout) { #define UNROLLCOPY 8 const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; const int end_elem = lines; @@ -3077,9 +2003,9 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)flagptr; clock_t s = clock64(); while (CHECK_IDS(*flag, signal_id)) { - if (clock64() - s > TIMEOUT) { - printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, - *flag); + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("pullrecv [grank dst:%d global src:%d][nvrank(GPU) dst: %d src: %d]: expected %d," + " observed %d", myrank, peer, nvrank, nvpeer, signal_id, *flag); break; } } @@ -3138,7 +2064,12 @@ __global__ void __launch_bounds__(MAX_THREADS) } } -__global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *flagptr, int adder) { +#define CHECK_CE(ce_start, ce_end) ((ce_start) != nullptr && (ce_end) != nullptr && \ + *(ce_start) != *(ce_end)) + +__global__ void kuserbuffers_pushrecv(int myrank, int peer, int nvrank, int nvpeer, int *recv_id, + int *flagptr, int adder, uint64_t ub_timeout, + int *ce_start_ptr, int *ce_end_ptr) { const int signal_id = (*recv_id) + adder; *recv_id = signal_id; volatile int *flag = (volatile int *)flagptr; @@ -3146,8 +2077,12 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *f return; clock_t s = clock64(); while (CHECK_IDS(*flag, signal_id)) { - if (clock64() - s > TIMEOUT) { - printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, *flag); + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("pushrecv [grank dst:%d global src:%d][nvrank(GPU) dst: %d src: %d] : " + "expected %d, observed %d", myrank, peer, nvrank, nvpeer, signal_id, *flag); + if (CHECK_CE(ce_start_ptr, ce_end_ptr)) + UB_PRINT("pushrecv: CE deadlock DETECTED: %d (ce_start) != %d (ce_end)\n", + *ce_start_ptr, *ce_end_ptr); return; } } @@ -3155,8 +2090,9 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *f __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv(int *send_id, int *send_flagptr, int4 *srcptr, int4 *dstptr, - const int lines, int myrank, int peer, int *recv_id, - int *recv_flagptr, int adder) { + const int lines, int send_peer, int recv_peer, int *recv_id, + int *recv_flagptr, int adder, uint64_t ub_timeout, + int nv_send, int nv_recv, int *ce_start_ptr, int *ce_end_ptr) { if (lines) { const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; const int end_elem = lines; @@ -3197,9 +2133,13 @@ __global__ void __launch_bounds__(MAX_THREADS) return; clock_t s = clock64(); while (CHECK_IDS(*flag, signal_id)) { - if (clock64() - s > TIMEOUT) { - printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, - *flag); + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("pushsendrecv [sending peer:%d receiving peer:%d][nvrank(GPU) sending peer: %d" + " receiving peer: %d]: expected %d, observed %d", + send_peer, recv_peer, nv_send, nv_recv, signal_id, *flag); + if (CHECK_CE(ce_start_ptr, ce_end_ptr)) + UB_PRINT("pushrecv: CE deadlock DETECTED: %d (ce_start) != %d (ce_end)\n", + 
*ce_start_ptr, *ce_end_ptr); return; } } @@ -3208,8 +2148,10 @@ __global__ void __launch_bounds__(MAX_THREADS) __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_atomic(int *send_id, int *send_flagptr, int4 *srcptr, int4 *dstptr, - const int lines, int myrank, int peer, int *recv_id, - int *recv_flagptr, int adder, void *counters) { + const int lines, int send_peer, int recv_peer, int *recv_id, + int *recv_flagptr, int adder, void *counters, + uint64_t ub_timeout, int nv_send, int nv_recv, + int *ce_start_ptr, int *ce_end_ptr) { if (lines) { const int start_elem = threadIdx.x + blockDim.x * blockIdx.x; const int end_elem = lines; @@ -3246,12 +2188,15 @@ __global__ void __launch_bounds__(MAX_THREADS) const int signal_id = (*recv_id) + adder; *recv_id = signal_id; volatile int *flag = (volatile int *)recv_flagptr; - // if(*flag>=signal_id) return; clock_t s = clock64(); while (CHECK_IDS(*flag, signal_id)) { - if (clock64() - s > TIMEOUT) { - printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, - *flag); /*return;*/ + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("pushsendrecv atomic [sending peer:%d receiving peer:%d][nvrank(GPU) sending peer:" + " %d receiving peer: %d]: expected %d, observed %d", + send_peer, recv_peer, nv_send, nv_recv, signal_id, *flag); /*return;*/ + if (CHECK_CE(ce_start_ptr, ce_end_ptr)) + UB_PRINT("pushsendrecv atomic: CE deadlock DETECTED: %d (ce_start) != %d (ce_end)\n", + *ce_start_ptr, *ce_end_ptr); } } @@ -3265,13 +2210,14 @@ __global__ void __launch_bounds__(MAX_THREADS) __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiatomic(int *send_id, int *send_flagptr, int4 *srcptr, - int4 *dstptr, const int lines, int myrank, int peer, - int *recv_id, int *recv_flagptr, int adder, + int4 *dstptr, const int lines, int send_peer, + int recv_peer, int *recv_id, int *recv_flagptr, int adder, void *counters, int nchunks, int send_stride, - int recv_stride, bool shuffle) { + int recv_stride, bool shuffle, + uint64_t ub_timeout, int nv_send, int nv_recv) { for (int chunk_i = 0; chunk_i < nchunks - 1; chunk_i++) { - int send_chunk_id = shuffle ? chunk_i : (nchunks + myrank - chunk_i) % nchunks; - int recv_chunk_id = shuffle ? chunk_i + 1 : (nchunks + myrank - chunk_i - 1) % nchunks; + int send_chunk_id = shuffle ? chunk_i : (nchunks + send_peer - chunk_i) % nchunks; + int recv_chunk_id = shuffle ? chunk_i + 1 : (nchunks + send_peer - chunk_i - 1) % nchunks; int send_offset = (send_chunk_id * send_stride) / 16; int recv_offset = ((shuffle ? recv_chunk_id : send_chunk_id) * recv_stride) / 16; @@ -3313,12 +2259,14 @@ __global__ void __launch_bounds__(MAX_THREADS) const int signal_id = (*recv_id) + adder; *recv_id = signal_id; volatile int *flag = (volatile int *)recv_flagptr; - // if(*flag>=signal_id) return; clock_t s = clock64(); while (CHECK_IDS(*flag, signal_id)) { - if (clock64() - s > TIMEOUT) { - printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, - *flag); /*return;*/ + if (CHECK_TIMEOUT(s, ub_timeout)) { + UB_PRINT("pushsendrecv multiatomic [sending peer:%d receiving peer:%d][nvrank(GPU)" + " sending peer: %d receiving peer: %d]: expected %d, observed %d", + send_peer, recv_peer, nv_send, nv_recv, signal_id, *flag); /*return;*/ + // CE mode is not supported for multi-atomic, so there is no need to check for a deadlock + return; } } } @@ -3334,9 +2282,8 @@ __global__ void __launch_bounds__(MAX_THREADS) // sync all CTAs before moving to next chunk. 
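// Illustrative sketch, not part of the patch: the wait-with-timeout idiom that replaces the
// old hard-coded TIMEOUT in the push/pull kernels above. CHECK_IDS, CHECK_TIMEOUT, CHECK_CE
// and UB_PRINT are macros defined elsewhere in this file; the helper name is hypothetical,
// and ub_timeout is the clock64()-tick budget now passed in from comm->ub_timeout.
__device__ void spin_on_flag(volatile int *flag, int expected, uint64_t ub_timeout,
                             int *ce_start, int *ce_end) {
  clock_t start = clock64();
  while (CHECK_IDS(*flag, expected)) {        // loop while the peer has not advanced the flag
    if (CHECK_TIMEOUT(start, ub_timeout)) {   // configurable budget instead of a fixed constant
      UB_PRINT("wait timed out: expected %d, observed %d", expected, *flag);
      if (CHECK_CE(ce_start, ce_end))         // counters valid and unequal => copy engine never finished
        UB_PRINT("CE deadlock DETECTED: %d (ce_start) != %d (ce_end)", *ce_start, *ce_end);
      return;                                 // report and bail out instead of hanging the kernel
    }
  }
}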
if (threadIdx.x == 0) { - int old_val2; atomicInc(((unsigned int *)counters) + nchunks + chunk_i, gridDim.x - 1); - while (0 != (old_val2 = atomicCAS(((unsigned int *)counters) + nchunks + chunk_i, 0, 0))) { + while (0 != (atomicCAS(((unsigned int *)counters) + nchunks + chunk_i, 0, 0))) { } } __syncthreads(); @@ -3352,50 +2299,56 @@ __global__ void __launch_bounds__(MAX_THREADS) } \ } while (0) +// Return TRUE if two ranks share the same NV domain #define INTRANODE(peer) ((peer / comm->nvsize) == (comm->myrank / comm->nvsize)) +// Index corresponds to the type of flag: +// 0 - Send index counter +// 1 - CE start index counter +// 2 - CE end index counter +#define GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsth, index) \ + ((reinterpret_cast((comm)->peer_ptr[0][(peerlocal)])) + \ + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + \ + (comm)->myrank * NVTE_MAX_REGIONS + (dsth) + \ + (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ + sizeof(int))) + +// Index corresponds to the type of flag: +// 0 - Receive index counter +// 1 - CE start index counter +// 2 - CE end index counter +#define GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsth, index) \ + ((reinterpret_cast((comm)->mem_ptr[0])) + \ + ((NVTE_REG0_OFFSET(comm) + \ + NVTE_REG0_RECV + (recv_peer) * NVTE_MAX_REGIONS + \ + (dsth) + (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ + sizeof(int))) + void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, const int peer, cudaStream_t stream) { - int peerlocal = peer % comm->nvsize; - void *flagptr = - (comm->peer_ptr[0][peerlocal]) + - ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + comm->myrank * NVTE_MAX_REGIONS + dsthandler) * - sizeof(int)); - bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); - bool intranode = INTRANODE(peer); - if (!intranode && (comm->launch_mode & NVTE_LAUNCH_CPU)) { - comm->fifo[comm->head].optype = userbuffers_sendop; - comm->fifo[comm->head].basecounter = comm->basecounter[userbuffers_sendop]; - comm->fifo[comm->head].handler = srchandler; - comm->fifo[comm->head].offset = srcoffset; - comm->fifo[comm->head].handler2 = dsthandler; - comm->fifo[comm->head].offset2 = dstoffset; - comm->fifo[comm->head].elements = bytes; - comm->fifo[comm->head].peer = peer; - - int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1); - while (newhead == comm->tail) { - } - comm->head = newhead; - comm->basecounter[userbuffers_sendop] += 1; - } - if (!intranode && (comm->launch_mode & NVTE_LAUNCH_GPU)) { - kuserbuffers_proxysend<<<1, 1, 0, stream>>>(&(comm->flags[NVTE_GF_STATE + userbuffers_sendop]), - comm->hostflags + userbuffers_sendop); - return; - } + int peerlocal = peer % comm->nvsize; + void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 0); + void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 1); + void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 2); + bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); + + assert(INTRANODE(peer)); + if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; if (comm->push == 0) { kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]), reinterpret_cast(flagptr)); } else { - void *srcptr = (comm->mem_ptr[srchandler]) + srcoffset; - void *dstptr = (comm->peer_ptr[dsthandler][peerlocal]) + dstoffset; + void *srcptr = reinterpret_cast(comm->mem_ptr[srchandler]) + srcoffset; + void *dstptr = reinterpret_cast(comm->peer_ptr[dsthandler][peerlocal]) + 
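// Illustrative sketch, not part of the patch: the offset arithmetic behind the new
// GET_SEND_PTR_BY_INDEX / GET_RECV_PTR_BY_INDEX macros, written out as a function.
// reg0_offset and recv_offset stand in for NVTE_REG0_OFFSET(comm) and NVTE_REG0_RECV;
// `index` selects one of the three per-(sender, region) counters: 0 = send/recv id,
// 1 = CE start, 2 = CE end, each group NVTE_MAX_NVLINK * NVTE_MAX_REGIONS ints apart.
static inline char *flag_addr(char *region0_base, int reg0_offset, int recv_offset,
                              int sender_rank, int dsthandler, int index) {
  const size_t slot = (size_t)reg0_offset + recv_offset + sender_rank * NVTE_MAX_REGIONS +
                      dsthandler + (size_t)index * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS;
  return region0_base + slot * sizeof(int);  // flags are ints, so scale the slot index
}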
dstoffset; - if (comm->use_ce) + if (comm->use_ce) { + kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_start_ptr)); CUDACHECK(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); + kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); + } SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); int *arg1 = &comm->send_id[peer], *arg2 = reinterpret_cast(flagptr); int4 *arg3 = reinterpret_cast(srcptr), *arg4 = reinterpret_cast(dstptr); @@ -3414,19 +2367,20 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); int send_peerlocal = send_peer % comm->nvsize; int recv_peerlocal = recv_peer % comm->nvsize; - void *flagptr_send = - (comm->peer_ptr[0][send_peerlocal]) + - ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + comm->myrank * NVTE_MAX_REGIONS + dsthandler) * - sizeof(int)); - void *flagptr_recv = - (comm->mem_ptr[0]) + - ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + recv_peer * NVTE_MAX_REGIONS + dsthandler) * - sizeof(int)); - - void *send_srcptr = (comm->mem_ptr[srchandler]) + send_offset; - void *send_dstptr = (comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset; - if (comm->use_ce) + void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0); + void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1); + void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2); + void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0); + + void *send_srcptr = reinterpret_cast(comm->mem_ptr[srchandler]) + send_offset; + void *send_dstptr = reinterpret_cast(comm->peer_ptr[dsthandler][send_peerlocal]) + + send_offset; + + if (comm->use_ce) { + kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_start_ptr)); CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); + kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); + } SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); int *arg1 = &comm->send_id[send_peer]; @@ -3434,19 +2388,30 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size int4 *arg3 = reinterpret_cast(send_srcptr); int4 *arg4 = reinterpret_cast(send_dstptr); int arg5 = signalonly ? 0 : bytes / 16; - int arg6 = comm->myrank; + int arg6 = send_peer; int arg7 = recv_peer; int *arg8 = &comm->recv_id[recv_peer * NVTE_MAX_REGIONS + dsthandler]; int *arg9 = reinterpret_cast(flagptr_recv); int arg10 = signalonly ? 1 : comm->sms; + uint64_t arg11 = comm->ub_timeout; + int arg12 = send_peerlocal; + int arg13 = recv_peerlocal; + int *arg14 = reinterpret_cast(comm->use_ce ? + GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1): + nullptr); + int *arg15 = reinterpret_cast(comm->use_ce ? 
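// Illustrative sketch, not part of the patch: how the patched send paths pair the copy-engine
// transfer with the new CE start/end counters, so a timed-out receiver can distinguish
// "CE copy never completed" from "sender never ran". kuserbuffers_inc and CUDACHECK are the
// helpers used in this file; the wrapper name below is hypothetical.
static void ce_copy_with_markers(void *dstptr, void *srcptr, size_t bytes,
                                 int *ce_start_flag, int *ce_end_flag, cudaStream_t stream) {
  kuserbuffers_inc<<<1, 1, 0, stream>>>(ce_start_flag);  // mark: copy enqueued on this stream
  CUDACHECK(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
  kuserbuffers_inc<<<1, 1, 0, stream>>>(ce_end_flag);    // runs only after the CE copy drains
  // On the receive side, CHECK_CE(ce_start, ce_end) firing after a timeout means the two
  // counters diverged, i.e. the cudaMemcpyAsync above is still stuck on the copy engine.
}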
+ GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2): + nullptr); void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), reinterpret_cast(&arg3), reinterpret_cast(&arg4), reinterpret_cast(&arg5), reinterpret_cast(&arg6), reinterpret_cast(&arg7), reinterpret_cast(&arg8), - reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; + reinterpret_cast(&arg9), reinterpret_cast(&arg10), + reinterpret_cast(&arg11), reinterpret_cast(&arg12), + reinterpret_cast(&arg13), reinterpret_cast(&arg14), + reinterpret_cast(&arg15)}; CUDACHECK( cudaLaunchKernelExC(&cfg, reinterpret_cast(kuserbuffers_pushsendrecv), kernelArgs)); - //} } void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, @@ -3458,19 +2423,18 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, int send_peerlocal = send_peer % comm->nvsize; int recv_peerlocal = recv_peer % comm->nvsize; - void *flagptr_send = - (comm->peer_ptr[0][send_peerlocal]) + - ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + comm->myrank * NVTE_MAX_REGIONS + dsthandler) * - sizeof(int)); - void *flagptr_recv = - (comm->mem_ptr[0]) + - ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + recv_peer * NVTE_MAX_REGIONS + dsthandler) * - sizeof(int)); - - void *send_srcptr = (comm->mem_ptr[srchandler]) + send_offset; - void *send_dstptr = (comm->peer_ptr[dsthandler][send_peerlocal]) + send_offset; + void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0); + void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 1); + void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 2); + void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0); + + void *send_srcptr = reinterpret_cast(comm->mem_ptr[srchandler]) + send_offset; + void *send_dstptr = reinterpret_cast(comm->peer_ptr[dsthandler][send_peerlocal]) + + send_offset; if (comm->use_ce) { + kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_start_ptr)); CUDACHECK(cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); + kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); } SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); @@ -3479,18 +2443,29 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, int4 *arg3 = reinterpret_cast(send_srcptr); int4 *arg4 = reinterpret_cast(send_dstptr); int arg5 = signalonly ? 0 : bytes / 16; - int arg6 = comm->myrank; + int arg6 = send_peer; int arg7 = recv_peer; int *arg8 = &comm->recv_id[recv_peer * NVTE_MAX_REGIONS + dsthandler]; int *arg9 = reinterpret_cast(flagptr_recv); int arg10 = signalonly ? 1 : comm->sms; void *arg11 = counters; + int arg12 = comm->ub_timeout; + int arg13 = send_peerlocal; + int arg14 = recv_peerlocal; + int *arg15 = reinterpret_cast(comm->use_ce ? + GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 1) : + nullptr); + int *arg16 = reinterpret_cast(comm->use_ce ? 
+ GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 2) : + nullptr); void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), reinterpret_cast(&arg3), reinterpret_cast(&arg4), reinterpret_cast(&arg5), reinterpret_cast(&arg6), reinterpret_cast(&arg7), reinterpret_cast(&arg8), reinterpret_cast(&arg9), reinterpret_cast(&arg10), - reinterpret_cast(&arg11)}; + reinterpret_cast(&arg11), reinterpret_cast(&arg12), + reinterpret_cast(&arg13), reinterpret_cast(&arg14), + reinterpret_cast(&arg15), reinterpret_cast(&arg16)}; CUDACHECK(cudaLaunchKernelExC(&cfg, reinterpret_cast(kuserbuffers_pushsendrecv_atomic), kernelArgs)); } @@ -3501,17 +2476,12 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler const int recv_peer, const int nchunks, void *counters, bool shuffle, cudaStream_t stream) { assert(comm->push && comm->use_ce == 0); + // CE is not supported int send_peerlocal = send_peer % comm->nvsize; int recv_peerlocal = recv_peer % comm->nvsize; - void *flagptr_send = - (comm->peer_ptr[0][send_peerlocal]) + - ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + comm->myrank * NVTE_MAX_REGIONS + dsthandler) * - sizeof(int)); - void *flagptr_recv = - (comm->mem_ptr[0]) + - ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + recv_peer * NVTE_MAX_REGIONS + dsthandler) * - sizeof(int)); + void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0); + void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0); SETUP_LAUNCH_CONFIG(comm->sms, 1024, stream); @@ -3530,6 +2500,9 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler int arg13 = send_stride; int arg14 = recv_stride; bool arg15 = shuffle; + uint64_t arg16 = comm->ub_timeout; + int arg17 = send_peerlocal; + int arg18 = recv_peerlocal; void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), reinterpret_cast(&arg3), reinterpret_cast(&arg4), reinterpret_cast(&arg5), reinterpret_cast(&arg6), @@ -3537,95 +2510,33 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler reinterpret_cast(&arg9), reinterpret_cast(&arg10), reinterpret_cast(&arg11), reinterpret_cast(&arg12), reinterpret_cast(&arg13), reinterpret_cast(&arg14), - reinterpret_cast(&arg15)}; + reinterpret_cast(&arg15), reinterpret_cast(&arg16), + reinterpret_cast(&arg17), reinterpret_cast(&arg18)}; CUDACHECK(cudaLaunchKernelExC( - &cfg, reinterpret_cast(kuserbuffers_pushsendrecv_multiatomic), kernelArgs)); -} - -__global__ void __launch_bounds__(MAX_THREADS) - kuserbuffers_alltoall(void **baseflagptrs, int flagoffset, int4 *basesrcptr, void **dstptrs, - size_t dstoffset, const int lines, const int myrank) { - if (blockIdx.x == myrank) - return; - int4 *dstptr = reinterpret_cast(dstptrs[blockIdx.x] + dstoffset); - int *flagptr = reinterpret_cast(baseflagptrs[blockIdx.x] + flagoffset); - const size_t myblockoffset = blockIdx.x * lines; - int4 *srcptr = basesrcptr + myblockoffset; - dstptr += myblockoffset; - - if (lines) { - const int start_elem = threadIdx.x; - const int end_elem = lines; - const int aligned_elem = ((end_elem - start_elem) & (~(blockDim.x * UNROLLCOPY - 1))); - const int end_aligned = start_elem + aligned_elem; - if (end_elem > start_elem) { - for (int line = start_elem; line < end_aligned; line += blockDim.x * UNROLLCOPY) { - int4 val[UNROLLCOPY]; -#pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) - val[i] = srcptr[line + i * blockDim.x]; -#pragma unroll - for (int i = 0; i < UNROLLCOPY; i++) - dstptr[line + i * blockDim.x] = val[i]; - } 
- for (int line = end_aligned; line < end_elem; line += blockDim.x) - dstptr[line] = srcptr[line]; - } - __syncthreads(); - if (threadIdx.x) - return; - __threadfence_system(); - atomicAdd(flagptr, 1); - - } else { - atomicAdd(flagptr, 1); - } -} - -void userbuffers_alltoall_send(const int srchandler, const size_t srcoffset, const int dsthandler, - const size_t dstoffset, const size_t bytes, communicator *comm, - cudaStream_t stream) { - if (comm->launch_mode & NVTE_LAUNCH_CPU) { - comm->fifo[comm->head].optype = userbuffers_alltoall; - comm->fifo[comm->head].basecounter = comm->basecounter[userbuffers_alltoall]; - comm->fifo[comm->head].handler = srchandler; - comm->fifo[comm->head].offset = srcoffset; - comm->fifo[comm->head].handler2 = dsthandler; - comm->fifo[comm->head].offset2 = dstoffset; - comm->fifo[comm->head].elements = bytes; - - int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1); - while (newhead == comm->tail) { - } - comm->head = newhead; - comm->basecounter[userbuffers_alltoall] += 1; - } - if (comm->launch_mode & NVTE_LAUNCH_GPU) - kuserbuffers_proxysend<<<1, 1, 0, stream>>>( - &(comm->flags[NVTE_GF_STATE + userbuffers_alltoall]), - comm->hostflags + userbuffers_alltoall); + &cfg, reinterpret_cast(kuserbuffers_pushsendrecv_multiatomic), kernelArgs)); } void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, const int peer, cudaStream_t stream) { - int peerlocal = peer % comm->nvsize; - void *flagptr = - (comm->mem_ptr[0]) + - ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + peer * NVTE_MAX_REGIONS + dsthandler) * - sizeof(int)); + int peerlocal = peer % comm->nvsize; + void *flagptr = GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 0); bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); - bool intranode = INTRANODE(peer); + + assert(INTRANODE(peer)); + if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; - if (comm->push == 0 && intranode) { - void *dstptr = (comm->mem_ptr[dsthandler]) + dstoffset; - void *srcptr = (comm->peer_ptr[srchandler][peerlocal]) + srcoffset; + if (comm->push == 0) { + void *dstptr = reinterpret_cast(comm->mem_ptr[dsthandler]) + dstoffset; + void *srcptr = reinterpret_cast(comm->peer_ptr[srchandler][peerlocal]) + srcoffset; kuserbuffers_pullrecv<<sms, signalonly ? 1 : 1024, 0, stream>>>( - comm->myrank, peer, &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), + comm->myrank, peer, comm->nvrank, + peerlocal, &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast(flagptr), reinterpret_cast(srcptr), - reinterpret_cast(dstptr), signalonly ? 0 : bytes / 16); + reinterpret_cast(dstptr), signalonly ? 0 : bytes / 16, + comm->ub_timeout); if (!signalonly) kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler])); if (comm->use_ce) { @@ -3633,22 +2544,17 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds } } else { kuserbuffers_pushrecv<<<1, 1, 0, stream>>>( - comm->myrank, peer, &comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler], - reinterpret_cast(flagptr), signalonly || !intranode ? 1 : comm->sms); + comm->myrank, peer, comm->nvrank, peerlocal, + &comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler], + reinterpret_cast(flagptr), signalonly || comm->sms, + comm->ub_timeout, + reinterpret_cast(comm->use_ce ? + GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 1) : nullptr), + reinterpret_cast(comm->use_ce ? 
+ GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2) : nullptr)); } } -void userbuffers_alltoall_recv(communicator *comm, cudaStream_t stream) { - void *flagptr = - (comm->mem_ptr[0]) + - ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * userbuffers_alltoall) * sizeof(int)); - - if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) - return; - kuserbuffers_pushrecv<<<1, 1, 0, stream>>>(comm->myrank, -1, reinterpret_cast(flagptr + 4), - reinterpret_cast(flagptr), comm->nranks - 1); -} - // producer static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) { // Decrement atomic val to signal current output tile finish @@ -3666,8 +2572,7 @@ static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) { static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) { // Wait for producer to change the val to 0, which signal producer ready if (blockIdx.x == 0 && threadIdx.x == 0) { - int old_val; - while (0 != (old_val = atomicCAS((unsigned int *)atomic_ptr + chunk_i, 0, 0))) { + while (0 != (atomicCAS((unsigned int *)atomic_ptr + chunk_i, 0, 0))) { } ((unsigned int *)atomic_ptr)[chunk_i] = 1; asm volatile("fence.sc.gpu;\n"); @@ -3678,9 +2583,8 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) { static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i, int num_chunks) { // Wait for producer to change the val to 0, which signal producer ready if (blockIdx.x == 0 && threadIdx.x == 0) { - int old_val; for (int i = first_chunk_i; i < num_chunks; i++) { - while (0 != (old_val = atomicCAS((unsigned int *)atomic_ptr + i, 0, 0))) { + while (0 != (atomicCAS((unsigned int *)atomic_ptr + i, 0, 0))) { } ((unsigned int *)atomic_ptr)[i] = 1; asm volatile("fence.sc.gpu;\n"); diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h index 1306636881..8d4a887f52 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h @@ -12,7 +12,6 @@ #include "cuda_runtime.h" #include #include -#include "gdrapi.h" #include #define NVTE_MAX_REGIONS 16 @@ -32,10 +31,6 @@ #define NVTE_UB_MEM_MC_CREATED 2 #define NVTE_UB_MEM_ALLOCATED 4 -#ifdef UCP -#include -#endif - // region 0 flag offsets #define NVTE_REG0_OPFLAGS 1024 #define NVTE_REG0_RECV (NVTE_REG0_OPFLAGS * userbuffers_op_types) @@ -43,7 +38,8 @@ #define NVTE_REG0_OFFSET(comm) ((2 * NVTE_MAX_REGIONS) * NVTE_MAX_NVLINK \ + NVTE_REG0_SINGLENODE * 2 + NVTE_MAX_PEERS) #define NVTE_REG0_COMMBUFFER 0 -#define NVTE_REG0_FLAGS (NVTE_REG0_RECV + NVTE_MAX_PEERS * NVTE_MAX_REGIONS) +// x3 for [flagptr, ce_start_ptr, ce_end_ptr] +#define NVTE_REG0_FLAGS (NVTE_REG0_RECV + NVTE_MAX_PEERS * NVTE_MAX_REGIONS * 3) #define NVTE_REG0_IBRS 32 #define NVTE_REG0_IBAG 512 @@ -122,16 +118,11 @@ struct communicator { // max value for running block counters in hostflags int basecounter[userbuffers_op_types]; // NOLINT(*) - int *hostflags; int *flags, *map_flags; - gdr_t g; - struct sharp_coll_context *sharp_coll_context; - struct sharp_coll_comm *sharp_coll_comm; void *mem_mr[NVTE_MAX_REGIONS]; ub_request *fifo; - volatile int activeproxy; int nblocks, alignblock, minblock, asyncblocks, active_nreqs; ub_request active_req[userbuffers_op_types]; // NOLINT(*) int padding[7]; @@ -142,10 +133,9 @@ struct communicator { MPI_Request mpihndl[NVTE_MAX_SHARP]; MPI_Comm comm_inter, // reduction group communicator (subset of the nodes) along GPU rail comm_intra; // full intranode (all ndev GPUS) - 
int ibnvsize; // can be used to fake smaller or larger nvlink domain to use ib instead of nvlink - // or force MNNVL int *send_id, *recv_id; int mydev; + uint64_t ub_timeout; }; typedef struct communicator communicator; @@ -185,23 +175,9 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * SHARP and NSO/MNNVL) */ -void allreduce_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream = 0); -// for DP distributed optimizer, only nonSHARP multinode is implemented & calls must come in pairs -// ordered -void allgather_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream = 0); -void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream = 0); - -void allreduce2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream = 0); // for TP-parallelism, only single node is implemented void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream = 0); -void allgather2_userbuff_inplace_sliced(const int handler, const int offset, const int elements, - communicator *comm, const int slice_id, const int nslices, - cudaStream_t stream = 0); /* each Rank input is allgather2_userbuff_inplace: offset+myrank*elements @@ -231,14 +207,6 @@ void reducescatter2_userbuff_stridedoutput_fp8(void* output, float* scale, const template void reducescatter2_userbuff_fp8(void* output, float* scale, const int handler, const int offset, const int elements, communicator* comm, cudaStream_t stream = 0); -#if 0 -template -void reducescatter2_userbuff_strided_atomic_fp8(void* output, float *scale, const int handler, - const int offset, const int rowelements, - const int colelements, const int strideelements, - const int numchunks, void *counters, - communicator* comm, cudaStream_t stream = 0); -#endif template void reducescatter2_userbuff_strided_atomic_fp8(void* output, float *scale, const int handler, const int offset, const int rowelements, From fc2a8bc14218e95bc8dd979b4243500645f49543 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 17 Apr 2024 19:40:39 -0500 Subject: [PATCH 026/244] [PyTorch] Fix for type checking failure on custom callables (#790) fix type checking in checkpointing to assume that there must be TE modules in custom callables Signed-off-by: Alp Dener Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/distributed.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 08da93587d..caaef91985 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -521,8 +521,11 @@ def has_te_modules(network): for module in network.modules(): if any(isinstance(module, te_class) for te_class in te_classes_list): return True + return False - return False + # Cannot check for TE modules inside a custom class/callable that's not a torch.nn.Module, + # so just assume that it has TE modules just to be safe. 
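+    # For example, a plain callable class or a functools.partial that wraps TE layers
+    # is not a torch.nn.Module, so the .modules() traversal above has no way to
+    # discover the TE modules it may contain.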
+ return True def checkpoint( From df28cea6cf9d226f9323bb495ca502320656cd88 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 18 Apr 2024 10:05:37 -0500 Subject: [PATCH 027/244] [JAX] Fixing CI failure due to incorrect use of `static_argnums` in jax.jit (#785) * fixed static argnums for jax.jit in single gpu encoder test, changed warning filtering for pytest Signed-off-by: Alp Dener * propagating the fix to the JAX mnist example Signed-off-by: Alp Dener * fixed missing space ibetween flags i QAA scripts Signed-off-by: Alp Dener * added TE warnings into the ignore list Signed-off-by: Alp Dener --------- Signed-off-by: Alp Dener Signed-off-by: Pawel Gadzinski --- .../jax/encoder/test_single_gpu_encoder.py | 2 +- examples/jax/mnist/test_single_gpu_mnist.py | 2 +- qa/L0_jax_unittest/test.sh | 9 +++--- qa/L1_jax_distributed_unittest/test.sh | 2 +- tests/jax/pytest.ini | 28 +++++++++++++++++++ 5 files changed, 36 insertions(+), 7 deletions(-) create mode 100644 tests/jax/pytest.ini diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 85e03342b2..ae5304628f 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -55,7 +55,7 @@ def __call__(self, x, mask, disable_dropout=False): return x -@partial(jax.jit, static_argnums=6) +@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4, 5)) def train_step(state, inputs, masks, labels, var_collect, rngs): """Computes gradients, loss and accuracy for a single batch.""" diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index dc28a9fd46..f9824ae000 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -74,7 +74,7 @@ def loss_fn(var_collect, disable_dropout=False): return grads, loss, accuracy -@partial(jax.jit, static_argnums=2) +@partial(jax.jit, static_argnums=(0, 1)) def update_model(state, grads): """Update model params and FP8 meta.""" state = state.apply_gradients(grads=grads[PARAMS_KEY]) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 9f20769045..b640e3ee4f 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -5,14 +5,15 @@ set -xe : ${TE_PATH:=/opt/transformerengine} -pytest -Wignore -v $TE_PATH/tests/jax -k 'not distributed' + +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' pip install -r $TE_PATH/examples/jax/mnist/requirements.txt pip install -r $TE_PATH/examples/jax/encoder/requirements.txt -pytest -Wignore -v $TE_PATH/examples/jax/mnist +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py -pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 51512d0744..1966f35208 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -5,5 +5,5 @@ set -xe : 
${TE_PATH:=/opt/transformerengine} -pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_* +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* diff --git a/tests/jax/pytest.ini b/tests/jax/pytest.ini new file mode 100644 index 0000000000..4da88e1476 --- /dev/null +++ b/tests/jax/pytest.ini @@ -0,0 +1,28 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +[pytest] +filterwarnings= + ignore:sharding_type of.*:DeprecationWarning + ignore:major_sharding_type of.*:DeprecationWarning + ignore:Fused attention is not enabled.*:UserWarning + ignore:The hookimpl.*:DeprecationWarning + ignore:xmap is an experimental feature and probably has bugs! + ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning + ignore:can't resolve package from __spec__ or __package__:ImportWarning + ignore:Using or importing the ABCs.*:DeprecationWarning + ignore:numpy.ufunc size changed + ignore:.*experimental feature + ignore:The distutils.* is deprecated.*:DeprecationWarning + ignore:backend and device argument on jit is deprecated.*:DeprecationWarning + ignore:ml_dtypes.float8_e4m3b11 is deprecated. + ignore:np.find_common_type is deprecated.*:DeprecationWarning + ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning + ignore:The numpy.array_api submodule is still experimental.*:UserWarning + ignore:case not machine-readable.*:UserWarning + ignore:not machine-readable.*:UserWarning + ignore:Special cases found for .* but none were parsed.*:UserWarning + ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning + ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning + ignore:The host_callback APIs are deprecated .*:DeprecationWarning From 346e7da2e70eddbeff10230c233ce9cdf11356c3 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 19 Apr 2024 10:48:15 -0700 Subject: [PATCH 028/244] NVRTC kernels for cast-transpose (#258) * Add NVRTC kernels for cast-transpose Signed-off-by: Tim Moon * Update copyright year Signed-off-by: Tim Moon * Add noop flag to NVRTC cast-transpose kernel Signed-off-by: Tim Moon * Apply suggestions from code review Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Pawel Gadzinski --- tests/cpp/operator/test_cast_transpose.cu | 5 +- transformer_engine/common/CMakeLists.txt | 6 +- .../common/transpose/cast_transpose.cu | 686 ++++++++---------- .../common/transpose/rtc/cast_transpose.cu | 129 ++++ .../common/transpose/transpose.cu | 181 +++-- 5 files changed, 550 insertions(+), 457 deletions(-) create mode 100644 transformer_engine/common/transpose/rtc/cast_transpose.cu diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu index 4a548ddf6f..8c168c76f4 100644 --- a/tests/cpp/operator/test_cast_transpose.cu +++ b/tests/cpp/operator/test_cast_transpose.cu @@ -81,7 +81,10 @@ std::vector> test_cases = {{2048, 12288}, {65536, 128}, {256, 256}, {120, 2080}, - {8, 8}}; + {8, 8}, + {1, 3221}, // Prime 456 + {2333, 1}, // Prime 345 + {1481, 677}}; // Primes 234, 123 } // namespace class CTTestSuite : public ::testing::TestWithParam #include -#include -#include -#include -#include "../utils.cuh" -#include "../common.h" - -namespace transformer_engine { - -template -inline __device__ void cast_and_transpose_regs(const IVec 
(&in)[nvec_out], - OVec (&out_trans)[nvec_in], - typename OVec::type *output_cast_tile, - const size_t current_place, - const size_t stride, - CType &max, // NOLINT(*) - const CType scale, - const bool valid_store) { - using T = typename OVec::type; - using OVecC = Vec; -#pragma unroll - for (unsigned int i = 0; i < nvec_out; ++i) { - OVecC out_cast; -#pragma unroll - for (unsigned int j = 0; j < nvec_in; ++j) { - const CType tmp = static_cast(in[i].data.elt[j]); - const T elt_o = T(scale * tmp); - - out_cast.data.elt[j] = elt_o; - out_trans[j].data.elt[i] = elt_o; // thread tile transpose - - __builtin_assume(max >= 0); - max = fmaxf(fabsf(tmp), max); - } - if (full_tile || valid_store) { - out_cast.store_to(output_cast_tile, current_place + stride * i); - } - } -} +#include -// STUFF TO TUNE -constexpr unsigned int n_warps_per_tile = 4; +#include -constexpr unsigned int max_threads_per_block = 256; -static_assert(n_warps_per_tile * THREADS_PER_WARP <= max_threads_per_block); -constexpr unsigned int cast_transpose_num_threads = n_warps_per_tile * THREADS_PER_WARP; +#include "../common.h" +#include "../util/rtc.h" +#include "../util/string.h" +#include "../utils.cuh" -template -__global__ void -__launch_bounds__(cast_transpose_num_threads) -cast_transpose_kernel(const IType * const input, - const CType * const noop, - OType * const output_c, - OType * const output_t, - const CType * const scale_ptr, - CType * const amax, - const size_t row_length, - const size_t num_rows, - const size_t num_tiles) { - if (noop != nullptr && noop[0] == 1.0f) return; +namespace transformer_engine { - using IVec = Vec; - using OVec = Vec; - - extern __shared__ char scratch[]; - - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; - const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP); - const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + - warp_id / n_warps_per_tile; - if (tile_id >= num_tiles) return; - const size_t tile_id_x = tile_id % num_tiles_x; - const size_t tile_id_y = tile_id / num_tiles_x; - - const IType * const my_input_tile = input + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - THREADS_PER_WARP; - OType * const my_output_c_tile = output_c + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - THREADS_PER_WARP; - OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out + - tile_id_x * num_rows * nvec_in) * - THREADS_PER_WARP; - OVec * const my_scratch = reinterpret_cast(scratch) + - (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * - (THREADS_PER_WARP + 1); - - IVec in[2][nvec_out]; - const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; - constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; - OVec out_space[n_iterations][nvec_in]; - - const size_t stride = row_length / nvec_in; - const size_t output_stride = num_rows / nvec_out; - size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; - unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - CType max = 0; - const CType scale = scale_ptr != nullptr ? 
*scale_ptr : 1; -#pragma unroll - for (unsigned int i = 0; i < nvec_out; ++i) { - in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); - } -#pragma unroll - for (unsigned int i = 0; i < n_iterations; ++i) { - const size_t current_place = current_stride + my_place; - const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - const unsigned int current_in = (i + 1) % 2; - if (i < n_iterations - 1) { -#pragma unroll - for (unsigned int j = 0; j < nvec_out; ++j) { - in[current_in][j].load_from(my_input_tile, - current_stride + my_place_in + stride * (nvec_out + j)); - } +namespace { + +// String with RTC kernel implementation +#include "string_code_transpose_rtc_cast_transpose_cu.h" + +// Hard-coded kernel parameters +using CType = float; +constexpr size_t warps_per_tile = 4; +constexpr size_t block_size = THREADS_PER_WARP * warps_per_tile; + +/* Performance heuristics for optimized kernel parameters */ +struct KernelConfig { + /** Vector load size */ + size_t load_size = 0; + /** Vector store size to transposed output */ + size_t store_size = 0; + + /* Whether config is valid */ + bool valid = false; + /* Number of CUDA blocks */ + size_t num_blocks = 0; + + /* Number of active SMs */ + size_t active_sm_count = 0; + /* Elements per L1 cache load */ + size_t elements_per_load = 0; + /* Elements per L1 cache store to cast output*/ + size_t elements_per_store_c = 0; + /* Elements per L1 cache store to transposed output */ + size_t elements_per_store_t = 0; + + KernelConfig(size_t row_length, + size_t num_rows, + size_t itype_size, + size_t otype_size, + size_t load_size_, + size_t store_size_) + : load_size{load_size_} + , store_size{store_size_} { + // Check that tiles are correctly aligned + constexpr size_t cache_line_size = 128; + if (load_size % itype_size != 0 + || store_size % otype_size != 0 + || cache_line_size % itype_size != 0 + || cache_line_size % otype_size != 0) { + return; } - OVec out_trans[nvec_in]; // NOLINT(*) - cast_and_transpose_regs(in[current_in ^ 1], out_trans, my_output_c_tile, - current_place, stride, max, scale, true); -#pragma unroll - for (unsigned int j = 0; j < nvec_in; ++j) { - out_space[i][j].data.vec = out_trans[j].data.vec; + const size_t row_tile_elements = load_size * THREADS_PER_WARP / itype_size; + const size_t col_tile_elements = store_size * THREADS_PER_WARP / otype_size; + valid = (row_length % row_tile_elements == 0 + && num_rows % col_tile_elements == 0); + if (!valid) { + return; } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += nvec_out * stride; - } - for (unsigned int i = 0; i < nvec_in; ++i) { -#pragma unroll - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - - j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; - } - __syncthreads(); - my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - current_stride = i * output_stride + - warp_id_in_tile * n_iterations * output_stride * nvec_in; - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, - current_stride + my_place); - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += output_stride * nvec_in; - } - __syncthreads(); + // Number of CUDA blocks + num_blocks = (row_length / row_tile_elements) * (num_rows / col_tile_elements); + + // Parameters for performance model + constexpr size_t 
warps_per_sm = 16; // Rough estimate for saturated SMs + active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm), + static_cast(cuda::sm_count())); + elements_per_load = (std::min(cache_line_size, row_tile_elements * itype_size) + / itype_size); + elements_per_store_c = (std::min(cache_line_size, row_tile_elements * otype_size) + / otype_size); + elements_per_store_t = (std::min(cache_line_size, col_tile_elements * otype_size) + / otype_size); } - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) atomicMaxFloat(amax, max); + /* Compare by estimated cost */ + bool operator<(const KernelConfig &other) const { + if (this->valid && other.valid) { + // cost ~ (1/elements_per_load + // + 1/elements_per_store_c + // + 1/elements_per_store_t) / active_sms + // Note: Integer arithmetic ensures stable ordering + const auto &l1 = this->elements_per_load; + const auto &sc1 = this->elements_per_store_c; + const auto &st1 = this->elements_per_store_t; + const auto &p1 = this->active_sm_count; + const auto &l2 = other.elements_per_load; + const auto &sc2 = other.elements_per_store_c; + const auto &st2 = other.elements_per_store_t; + const auto &p2 = other.active_sm_count; + const auto scale = l1 * sc1 * st1 * p1 * l2 * sc2 * st2 * p2; + const auto cost1 = (scale/l1 + scale/sc1 + scale/st1) / p1; + const auto cost2 = (scale/l2 + scale/sc2 + scale/st2) / p2; + return cost1 < cost2; + } else { + return this->valid && !other.valid; + } } -} +}; -template +template __global__ void -__launch_bounds__(cast_transpose_num_threads) -cast_transpose_kernel_notaligned(const IType * const input, - const CType * const noop, - OType * const output_c, - OType * const output_t, - const CType * const scale_ptr, - CType * const amax, - const size_t row_length, - const size_t num_rows, - const size_t num_tiles) { +__launch_bounds__(block_size) +cast_transpose_general_kernel(const IType * __restrict__ const input, + const CType * __restrict__ const noop, + OType * __restrict__ const output_c, + OType * __restrict__ const output_t, + const CType * __restrict__ const scale_ptr, + CType * __restrict__ const amax_ptr, + const size_t row_length, + const size_t num_rows) { if (noop != nullptr && noop[0] == 1.0f) return; + // Vectorized load/store sizes + constexpr size_t nvec_in = load_size / sizeof(IType); + constexpr size_t nvec_out = store_size / sizeof(OType); using IVec = Vec; - using OVec = Vec; - - extern __shared__ char scratch[]; - - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; - const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) / - (nvec_in * THREADS_PER_WARP); - const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + - warp_id / n_warps_per_tile; - if (tile_id >= num_tiles) return; - const size_t tile_id_x = tile_id % num_tiles_x; - const size_t tile_id_y = tile_id / num_tiles_x; - - const IType * const my_input_tile = input + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - THREADS_PER_WARP; - OType * const my_output_c_tile = output_c + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - THREADS_PER_WARP; - OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out + - tile_id_x * num_rows * nvec_in) * - THREADS_PER_WARP; - const size_t stride = row_length / nvec_in; - const size_t output_stride = num_rows / nvec_out; - const size_t row_length_rest 
= stride - tile_id_x * THREADS_PER_WARP; - const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; - const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP - : row_length_rest; - const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP - : row_height_rest; - - OVec * const my_scratch = reinterpret_cast(scratch) + - (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * - (THREADS_PER_WARP + 1); - - IVec in[2][nvec_out]; - const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; - constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; - OVec out_space[n_iterations][nvec_in]; - - size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; - unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - CType max = 0; - const CType scale = scale_ptr != nullptr ? *scale_ptr : 1; - { - const bool valid_load = my_place < tile_length && - warp_id_in_tile * n_iterations < tile_height; -#pragma unroll - for (unsigned int i = 0; i < nvec_out; ++i) { - if (valid_load) { - in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); - } else { - in[0][i].clear(); - } - } - } -#pragma unroll - for (unsigned int i = 0; i < n_iterations; ++i) { - const size_t current_place = current_stride + my_place; - const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - const unsigned int current_in = (i + 1) % 2; - if (i < n_iterations - 1) { - const bool valid_load = my_place_in < tile_length && - warp_id_in_tile * n_iterations + i + 1 < tile_height; -#pragma unroll - for (unsigned int j = 0; j < nvec_out; ++j) { - if (valid_load) { - in[current_in][j].load_from(my_input_tile, - current_stride + my_place_in + stride * (nvec_out + j)); - } else { - in[current_in][j].clear(); + using OVecT = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr size_t bdimx = THREADS_PER_WARP; + constexpr size_t bdimy = warps_per_tile; + const size_t tid = threadIdx.x; + const size_t tidx = tid % bdimx; + const size_t tidy = tid / bdimx; + const size_t bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles + constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out; + constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in; + + // Position of tile within tensor + const size_t num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m; + const size_t tile_id_m = bid % num_tiles_m; + const size_t tile_id_n = bid / num_tiles_m; + const size_t tile_row = tile_id_m * tile_dim_m; + const size_t tile_col = tile_id_n * tile_dim_n; + + // Number of nvec_out x nvec_in subtiles for each thread to + // load/store + constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile; + + // FP8 factors + const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr; + CType amax = 0; + + // Load input and store to registers + // Note: Each thread loads num_iterations subtiles, computes amax, + // casts type, and transposes in registers. 
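+  // This statically-compiled fallback handles tensors whose dimensions are not an
+  // exact multiple of the tile size: every load and store below is guarded by a
+  // row/column bounds check, at the cost of scalar (non-vectorized) accesses.
+  // Transposed results are staged through a padded shared-memory tile
+  // (THREADS_PER_WARP x THREADS_PER_WARP+1) so that each warp writes contiguous
+  // spans of the transposed output without shared-memory bank conflicts.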
+ OVecT local_output_t[nvec_in][num_iterations]; + #pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; + #pragma unroll + for (size_t i2 = 0; i2 < nvec_out; ++i2) { + const size_t row = tile_row + i1 * nvec_out + i2; + const size_t col = tile_col + j1 * nvec_in; + if (row < num_rows) { + #pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { + if (col + j2 < row_length) { + const CType in = input[row * row_length + col + j2]; + const OType out = OType(in * scale); + __builtin_assume(amax >= 0); + amax = fmaxf(fabsf(in), amax); + output_c[row * row_length + col + j2] = out; + local_output_t[j2][iter].data.elt[i2] = out; } } + } } - OVec out_trans[nvec_in]; // NOLINT(*) - const bool valid_store = my_place < tile_length && - warp_id_in_tile * n_iterations + i < tile_height; - cast_and_transpose_regs(in[current_in ^ 1], out_trans, my_output_c_tile, - current_place, stride, max, scale, valid_store); -#pragma unroll - for (unsigned int j = 0; j < nvec_in; ++j) { - out_space[i][j].data.vec = out_trans[j].data.vec; - } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += nvec_out * stride; } - for (unsigned int i = 0; i < nvec_in; ++i) { -#pragma unroll - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - - j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; + // Copy transposed output from registers to global memory + __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1]; + #pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { + #pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; + shared_output_t[j1][i1] = local_output_t[j2][iter]; } __syncthreads(); - my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - current_stride = i * output_stride + - warp_id_in_tile * n_iterations * output_stride * nvec_in; - for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { - const bool valid_store = my_place < tile_height; - if (valid_store) { - my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, - current_stride + my_place); + #pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidx; + const size_t j1 = tidy + iter * bdimy; + const size_t row = tile_row + i1 * nvec_out; + const size_t col = tile_col + j1 * nvec_in + j2; + if (col < row_length) { + #pragma unroll + for (size_t i2 = 0; i2 < nvec_out; ++i2) { + if (row + i2 < num_rows) { + output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2]; + } + } } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += output_stride * nvec_in; } __syncthreads(); } - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) atomicMaxFloat(amax, max); + // Reduce amax over block + if (amax_ptr != nullptr) { + amax = reduce_max(amax, tidy); + if (threadIdx.x == 0) { + atomicMaxFloat(amax_ptr, amax); + } } } +} // namespace + void cast_transpose(const Tensor &input, const Tensor &noop, - Tensor *cast_output, - Tensor *transposed_output, + Tensor *cast_output_, + Tensor *transposed_output_, cudaStream_t stream) { - CheckInputTensor(input, "cast_transpose_input"); - CheckOutputTensor(*cast_output, 
"cast_output"); - CheckOutputTensor(*transposed_output, "transposed_output"); - - // Number of elements in tensor - auto numel = [] (const Tensor &tensor) -> size_t { - size_t acc = 1; - for (const auto& dim : tensor.data.shape) { - acc *= dim; - } - return acc; - }; + Tensor &cast_output = *cast_output_; + Tensor &transposed_output = *transposed_output_; + // Check no-op flag if (noop.data.dptr != nullptr) { - NVTE_CHECK(numel(noop) == 1, - "Expected 1 element, ", - "but found ", numel(noop), "."); + size_t numel = 1; + for (const auto& dim : noop.data.shape) { + numel *= dim; + } + NVTE_CHECK(numel == 1, "Expected 1 element, but found ", numel, "."); NVTE_CHECK(noop.data.dtype == DType::kFloat32); NVTE_CHECK(noop.data.dptr != nullptr); } + + // Check tensor dims + CheckInputTensor(input, "cast_transpose_input"); + CheckOutputTensor(cast_output, "cast_output"); + CheckOutputTensor(transposed_output, "transposed_output"); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); - NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); - NVTE_CHECK(input.data.shape == cast_output->data.shape, - "Input and C output must have the same shape."); + NVTE_CHECK(cast_output.data.shape.size() == 2, "Cast output must have 2 dimensions."); + NVTE_CHECK(transposed_output.data.shape.size() == 2, + "Transposed output must have 2 dimensions."); const size_t row_length = input.data.shape[1]; const size_t num_rows = input.data.shape[0]; - - NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output."); - NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); - - NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, - "C and T outputs need to have the same type."); - NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, - "C and T outputs need to share amax tensor."); - NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, - "C and T outputs need to share scale tensor."); - -// Launch specific cast-transpose kernel -#define LAUNCH_KERNEL(kernel, nvec_in, nvec_out, n_tiles, n_blocks, InputType, OutputType) \ - do { \ - cudaFuncSetAttribute(kernel, \ - cudaFuncAttributePreferredSharedMemoryCarveout, \ - 100); \ - kernel \ - <<), \ - stream>>>( \ - reinterpret_cast(input.data.dptr), \ - reinterpret_cast(noop.data.dptr), \ - reinterpret_cast(cast_output->data.dptr), \ - reinterpret_cast(transposed_output->data.dptr), \ - reinterpret_cast(cast_output->scale.dptr), \ - reinterpret_cast(cast_output->amax.dptr), \ - row_length, num_rows, n_tiles); \ - } while (false) - -// Launch cast-transpose kernel for given vector sizes -#define LAUNCH_KERNEL_VEC_SIZES(load_size, store_size, InputType, OutputType) \ - do { \ - constexpr int nvec_in = load_size / sizeof(InputType); \ - constexpr int nvec_out = store_size / sizeof(OutputType); \ - \ - NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); \ - NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); \ - \ - const size_t n_tiles = get_n_tiles(load_size, store_size); \ - const size_t n_blocks = get_n_blocks(n_tiles); \ - \ - const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && \ - num_rows % (nvec_out * THREADS_PER_WARP) == 0; \ - \ - if (full_tile) { \ - LAUNCH_KERNEL(cast_transpose_kernel, \ - nvec_in, nvec_out, n_tiles, n_blocks, \ - InputType, OutputType); \ - } else { \ - 
LAUNCH_KERNEL(cast_transpose_kernel_notaligned, \ - nvec_in, nvec_out, n_tiles, n_blocks, \ - InputType, OutputType); \ - } \ - } while (false) + NVTE_CHECK(cast_output.data.shape[0] == num_rows, "Wrong dimension of cast output."); + NVTE_CHECK(cast_output.data.shape[1] == row_length, "Wrong dimension of cast output."); + NVTE_CHECK(transposed_output.data.shape[0] == row_length, + "Wrong dimension of transposed output."); + NVTE_CHECK(transposed_output.data.shape[1] == num_rows, + "Wrong dimension of transposed output."); + + // Check tensor pointers + NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); + NVTE_CHECK(cast_output.data.dptr != nullptr, "Cast output is not allocated."); + NVTE_CHECK(transposed_output.data.dptr != nullptr, "Transposed output is not allocated."); + NVTE_CHECK(cast_output.data.dtype == transposed_output.data.dtype, + "Cast and transposed output types must match."); + NVTE_CHECK(cast_output.amax.dptr == transposed_output.amax.dptr, + "Cast and transposed outputs need to share amax tensor."); + NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr, + "Cast and transposed outputs need to share scale tensor."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType, - - // Estimate number of SMs - // Note: H100 has 132 SMs, A100 has 108 SMs. - // Note: Directly querying number of SMs with cudaGetDeviceProperties is - // slow (>1 ms). Consider querying once and caching. - const int n_sms = 128; - - // Helper functions to get kernel configuration - auto get_n_tiles = [=] (size_t load_size, size_t store_size) -> int { - constexpr size_t threads_per_warp = static_cast(THREADS_PER_WARP); - size_t nvec_in = load_size / sizeof(InputType); - size_t nvec_out = store_size / sizeof(OutputType); - size_t n_tiles = DIVUP(row_length, nvec_in * threads_per_warp) * - DIVUP(num_rows, nvec_out * threads_per_warp); - return n_tiles; - }; - auto get_n_blocks = [=] (size_t n_tiles) -> int { - size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP; - size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block); - return n_blocks; - }; - - // Estimate optimal vector sizes and run - // Note: Consider reducing to 2B or 1B loads/stores for - // sufficiently small matrices. Need to consider whether reduced - // cache efficiency is worth increased SM utilization. Also need - // to keep in mind whether datatype can fit. 
- const size_t estimated_n_tiles = get_n_tiles(8, 8); - const size_t estimated_n_blocks = get_n_blocks(estimated_n_tiles); - if (estimated_n_blocks >= n_sms) { - LAUNCH_KERNEL_VEC_SIZES(8, 8, InputType, OutputType); - } else { - LAUNCH_KERNEL_VEC_SIZES(4, 4, InputType, OutputType); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output.data.dtype, OutputType, + constexpr const char *itype_name = TypeInfo::name; + constexpr const char *otype_name = TypeInfo::name; + constexpr size_t itype_size = sizeof(InputType); + constexpr size_t otype_size = sizeof(OutputType); + + // Choose between runtime-compiled or statically-compiled kernel + const bool aligned = (row_length % THREADS_PER_WARP == 0 + && num_rows % THREADS_PER_WARP == 0); + if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel + // Pick kernel config + std::vector kernel_configs; + kernel_configs.reserve(16); + auto add_config = [&](size_t load_size, size_t store_size) { + kernel_configs.emplace_back(row_length, num_rows, + itype_size, otype_size, + load_size, store_size); + }; + add_config(8, 8); + add_config(4, 8); add_config(8, 4); + add_config(4, 4); + add_config(2, 8); add_config(8, 2); + add_config(2, 4); add_config(4, 2); + add_config(2, 2); + add_config(1, 8); add_config(8, 1); + add_config(1, 4); add_config(4, 1); + add_config(1, 2); add_config(2, 1); + add_config(1, 1); + const auto &kernel_config = *std::min_element(kernel_configs.begin(), + kernel_configs.end()); + NVTE_CHECK(kernel_config.valid, "invalid kernel config"); + const size_t load_size = kernel_config.load_size; + const size_t store_size = kernel_config.store_size; + const size_t num_blocks = kernel_config.num_blocks; + + // Compile NVRTC kernel if needed and launch + auto& rtc_manager = rtc::KernelManager::instance(); + const std::string kernel_label = concat_strings("cast_transpose" + ",itype=", itype_name, + ",otype=", otype_name, + ",load_size=", load_size, + ",store_size=", store_size); + if (!rtc_manager.is_compiled(kernel_label)) { + std::string code = string_code_transpose_rtc_cast_transpose_cu; + code = regex_replace(code, "__ITYPE__", itype_name); + code = regex_replace(code, "__OTYPE__", otype_name); + code = regex_replace(code, "__LOAD_SIZE__", load_size); + code = regex_replace(code, "__STORE_SIZE__", store_size); + code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); + code = regex_replace(code, "__BLOCK_SIZE__", block_size); + rtc_manager.compile(kernel_label, + "cast_transpose_optimized_kernel", + code, + "transformer_engine/common/transpose/rtc/cast_transpose.cu"); + } + rtc_manager.launch(kernel_label, + num_blocks, block_size, 0, stream, + static_cast(input.data.dptr), + reinterpret_cast(noop.data.dptr), + static_cast(cast_output.data.dptr), + static_cast(transposed_output.data.dptr), + static_cast(cast_output.scale.dptr), + static_cast(cast_output.amax.dptr), + row_length, num_rows); + } else { // Statically-compiled general kernel + constexpr size_t load_size = 4; + constexpr size_t store_size = 4; + constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP; + constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP; + const int num_blocks = (DIVUP(row_length, row_tile_size) + * DIVUP(num_rows, col_tile_size)); + cast_transpose_general_kernel + <<>>( + static_cast(input.data.dptr), + reinterpret_cast(noop.data.dptr), + static_cast(cast_output.data.dptr), + static_cast(transposed_output.data.dptr), + static_cast(cast_output.scale.dptr), + static_cast(cast_output.amax.dptr), + 
row_length, num_rows); } - ); // NOLINT(*) ); // NOLINT(*) - -#undef LAUNCH_KERNEL -#undef LAUNCH_KERNEL_VEC_SIZES } } // namespace transformer_engine diff --git a/transformer_engine/common/transpose/rtc/cast_transpose.cu b/transformer_engine/common/transpose/rtc/cast_transpose.cu new file mode 100644 index 0000000000..d503581718 --- /dev/null +++ b/transformer_engine/common/transpose/rtc/cast_transpose.cu @@ -0,0 +1,129 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "utils.cuh" + +using namespace transformer_engine; + +namespace { + +// Parameters +using CType = float; +using IType = __ITYPE__; +using OType = __OTYPE__; +constexpr size_t load_size = __LOAD_SIZE__; +constexpr size_t store_size = __STORE_SIZE__; +constexpr size_t warps_per_tile = __WARPS_PER_TILE__; +constexpr size_t block_size = __BLOCK_SIZE__; + +} // namespace + +__global__ void +__launch_bounds__(block_size) +cast_transpose_optimized_kernel(const IType * __restrict__ const input, + const CType * __restrict__ const noop, + OType * __restrict__ const output_c, + OType * __restrict__ const output_t, + const CType * __restrict__ const scale_ptr, + CType * __restrict__ const amax_ptr, + const size_t row_length, + const size_t num_rows) { + if (noop != nullptr && noop[0] == 1.0f) return; + + // Vectorized load/store sizes + constexpr size_t nvec_in = load_size / sizeof(IType); + constexpr size_t nvec_out = store_size / sizeof(OType); + using IVec = Vec; + using OVecC = Vec; + using OVecT = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr size_t bdimx = THREADS_PER_WARP; + constexpr size_t bdimy = warps_per_tile; + const size_t tid = threadIdx.x; + const size_t tidx = tid % bdimx; + const size_t tidy = tid / bdimx; + const size_t bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles + constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out; + constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in; + + // Position of tile within tensor + const size_t num_tiles_m = num_rows / tile_dim_m; + const size_t tile_id_m = bid % num_tiles_m; + const size_t tile_id_n = bid / num_tiles_m; + const size_t tile_row = tile_id_m * tile_dim_m; + const size_t tile_col = tile_id_n * tile_dim_n; + + // Number of nvec_out x nvec_in subtiles for each thread to + // load/store + constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile; + + // FP8 factors + const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr; + CType amax = 0; + + // Load input to registers and transpose + // Note: Each thread loads num_iterations subtiles, computes amax, + // casts type, and transposes in registers. 
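+  // Unlike the general fallback kernel, this runtime-compiled variant assumes both
+  // tensor dimensions are exact multiples of the tile size (enforced on the host when
+  // a KernelConfig is selected), so it can use vectorized IVec/OVecC loads and stores
+  // with no per-element bounds checks.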
+ OVecT local_output_t[nvec_in][num_iterations]; + #pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; + #pragma unroll + for (size_t i2 = 0; i2 < nvec_out; ++i2) { + const size_t row = tile_row + i1 * nvec_out + i2; + const size_t col = tile_col + j1 * nvec_in; + IVec local_input; + OVecC local_output_c; + local_input.load_from(&input[row * row_length + col]); + #pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { + const CType in = static_cast(local_input.data.elt[j2]); + const OType out = OType(in * scale); + __builtin_assume(amax >= 0); + amax = fmaxf(fabsf(in), amax); + local_output_c.data.elt[j2] = out; + local_output_t[j2][iter].data.elt[i2] = out; + } + local_output_c.store_to(&output_c[row * row_length + col]); + } + } + + // Copy from registers to shared memory to global memory + __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1]; + #pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { + #pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; + shared_output_t[j1][i1] = local_output_t[j2][iter]; + } + __syncthreads(); + #pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidx; + const size_t j1 = tidy + iter * bdimy; + const size_t row = tile_row + i1 * nvec_out; + const size_t col = tile_col + j1 * nvec_in + j2; + shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]); + } + __syncthreads(); + } + + // Reduce amax over block + if (amax_ptr != nullptr) { + amax = reduce_max(amax, tidy); + if (threadIdx.x == 0) { + atomicMaxFloat(amax_ptr, amax); + } + } +} diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index 3ab83b944b..c0a1a7fbcf 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ b/transformer_engine/common/transpose/transpose.cu @@ -6,13 +6,15 @@ #include #include + +#include + #include -#include -#include + #include "../common.h" -#include "../utils.cuh" -#include "../util/string.h" #include "../util/rtc.h" +#include "../util/string.h" +#include "../utils.cuh" namespace transformer_engine { @@ -25,7 +27,80 @@ namespace { constexpr size_t warps_per_tile = 4; constexpr size_t block_size = THREADS_PER_WARP * warps_per_tile; -} // namespace +/* Performance heuristics for optimized kernel parameters */ +struct KernelConfig { + /** Vector load size */ + size_t load_size; + /** Vector store size */ + size_t store_size; + + /* Whether config is valid */ + bool valid = false; + /* Number of CUDA blocks */ + size_t num_blocks = 0; + + /* Number of active SMs */ + size_t active_sm_count = 0; + /* Elements per L1 cache load */ + size_t elements_per_load = 0; + /* Elements per L1 cache store */ + size_t elements_per_store = 0; + + KernelConfig(size_t row_length, + size_t num_rows, + size_t type_size, + size_t load_size_, + size_t store_size_) + : load_size{load_size_} + , store_size{store_size_} { + // Check that tiles are correctly aligned + constexpr size_t cache_line_size = 128; + if (load_size % type_size != 0 + || store_size % type_size != 0 + || cache_line_size % type_size != 0) { + return; + } + const size_t row_tile_elements = load_size * THREADS_PER_WARP / type_size; + const size_t col_tile_elements = store_size * THREADS_PER_WARP / type_size; + valid = (row_length % row_tile_elements == 0 + && num_rows % col_tile_elements == 0); + if (!valid) { + 
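+      // Tensor dims are not an exact multiple of this tile size; leave the config
+      // marked invalid so a smaller load/store combination is preferred instead.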
return; + } + + // Number of CUDA blocks + num_blocks = (row_length / row_tile_elements) * (num_rows / col_tile_elements); + + // Parameters for performance model + constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs + active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm), + static_cast(cuda::sm_count())); + elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size) + / type_size); + elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size) + / type_size); + } + + /* Compare by estimated cost */ + bool operator<(const KernelConfig &other) const { + if (this->valid && other.valid) { + // cost ~ (1/elements_per_load + 1/elements_per_store) / active_sms + // Note: Integer arithmetic ensures stable ordering + const auto &l1 = this->elements_per_load; + const auto &s1 = this->elements_per_store; + const auto &p1 = this->active_sm_count; + const auto &l2 = other.elements_per_load; + const auto &s2 = other.elements_per_store; + const auto &p2 = other.active_sm_count; + const auto scale = l1 * s1 * p1 * l2 * s2 * p2; + const auto cost1 = (scale/l1 + scale/s1) / p1; + const auto cost2 = (scale/l2 + scale/s2) / p2; + return cost1 < cost2; + } else { + return this->valid && !other.valid; + } + } +}; template __global__ void @@ -127,6 +202,8 @@ transpose_general_kernel(const Type * __restrict__ const input, } } +} // namespace + void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, @@ -170,82 +247,36 @@ void transpose(const Tensor &input, const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0); if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel - // Determine kernel config - size_t load_size = 8; - size_t store_size = 8; - auto is_tile_aligned = [&](size_t load_size_, size_t store_size_) -> bool { - return (row_length % (load_size / type_size * THREADS_PER_WARP) == 0 - && num_rows % (store_size / type_size * THREADS_PER_WARP) == 0); + // Pick kernel config + std::vector kernel_configs; + kernel_configs.reserve(16); + auto add_config = [&](size_t load_size, size_t store_size) { + kernel_configs.emplace_back(row_length, num_rows, type_size, + load_size, store_size); }; - auto num_blocks = [&](size_t load_size_, size_t store_size_) -> int { - const size_t row_tile_size = load_size_ / type_size * THREADS_PER_WARP; - const size_t col_tile_size = store_size_ / type_size * THREADS_PER_WARP; - return (row_length / row_tile_size) * (num_rows / col_tile_size); - }; - do { - const int sm_count = cuda::sm_count(); - - // Try maximizing SM occupancy without sacrificing cache - // efficiency - // Note: 32 threads/warp access 128B L1 cache line, so 4B - // loads/stores achieve full cache efficiency - if constexpr (type_size > 4) break; - if (is_tile_aligned(load_size, store_size) - && num_blocks(load_size, store_size) >= 4*sm_count) { - break; - } - load_size = 4; store_size = 8; - if (is_tile_aligned(load_size, store_size) - && num_blocks(load_size, store_size) >= 4*sm_count) { - break; - } - load_size = 4; store_size = 4; - if (is_tile_aligned(load_size, store_size) - && num_blocks(load_size, store_size) >= sm_count) { - break; - } - - // Simple performance model to balance SM occupancy and cache - // efficiency - auto cost = [&](int load_size_, int store_size_) -> double { - int active_sms = std::min(sm_count, num_blocks(load_size_, store_size_)); - // Amortize memory accesses over 128B L1 cache line - int elements_per_load = std::min(128, load_size_) / 
type_size; - int elements_per_store = std::min(128, store_size_) / type_size; - return (1.0 / elements_per_load + 1.0 / elements_per_store) / active_sms; - }; - if constexpr (type_size > 2) break; - if (is_tile_aligned(load_size, store_size) - && cost(2, 4) >= cost(load_size, store_size)) { - break; - } - load_size = 2; store_size = 4; - if (is_tile_aligned(load_size, store_size) - && cost(2, 2) >= cost(load_size, store_size)) { - break; - } - load_size = 2; store_size = 2; - if constexpr (type_size > 1) break; - if (is_tile_aligned(load_size, store_size) - && cost(1, 2) >= cost(load_size, store_size)) { - break; - } - load_size = 1; store_size = 2; - if (is_tile_aligned(load_size, store_size) - && cost(1, 1) >= cost(load_size, store_size)) { - break; - } - load_size = 1; store_size = 1; - } while (false); - NVTE_CHECK(is_tile_aligned(load_size, store_size), - "memory accesses are not properly aligned"); + add_config(8, 8); + add_config(4, 8); add_config(8, 4); + add_config(4, 4); + add_config(2, 8); add_config(8, 2); + add_config(2, 4); add_config(4, 2); + add_config(2, 2); + add_config(1, 8); add_config(8, 1); + add_config(1, 4); add_config(4, 1); + add_config(1, 2); add_config(2, 1); + add_config(1, 1); + const auto &kernel_config = *std::min_element(kernel_configs.begin(), + kernel_configs.end()); + NVTE_CHECK(kernel_config.valid, "invalid kernel config"); + const size_t load_size = kernel_config.load_size; + const size_t store_size = kernel_config.store_size; + const size_t num_blocks = kernel_config.num_blocks; // Compile NVRTC kernel if needed and launch auto& rtc_manager = rtc::KernelManager::instance(); const std::string kernel_label = concat_strings("transpose" ",type=", type_name, ",load_size=", load_size, - ",store_size", store_size); + ",store_size=", store_size); if (!rtc_manager.is_compiled(kernel_label)) { std::string code = string_code_transpose_rtc_transpose_cu; code = regex_replace(code, "__TYPE__", type_name); @@ -259,7 +290,7 @@ void transpose(const Tensor &input, "transformer_engine/common/transpose/rtc/transpose.cu"); } rtc_manager.launch(kernel_label, - num_blocks(load_size, store_size), block_size, 0, stream, + num_blocks, block_size, 0, stream, static_cast(input.data.dptr), static_cast(noop.data.dptr), static_cast(output.data.dptr), From 5b9e2e4cf0c057405c1ea2733ab1303b7f20ca86 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 19 Apr 2024 11:28:30 -0700 Subject: [PATCH 029/244] [PyTorch] Stop storing fused weight tensor in linear modules (#719) * Support noop concat without providing full tensor Stop storing fused buffers in linear modules. 
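The "noop concat" described here can be sketched in plain PyTorch. The helper below is illustrative only (hypothetical name, forward logic only; gradient handling is what the _NoopCatFunc autograd function in this patch adds on top): if the inputs already sit back-to-back in one storage with matching dtype, device and strides, it returns a strided view over the first tensor's storage, otherwise it falls back to an ordinary out-of-place torch.cat.

import torch

def maybe_noop_cat(tensors, dim=0):
    """Illustrative helper, not the TE implementation: zero-copy view if possible, else torch.cat."""
    first = tensors[0]
    step = first.stride(dim) * first.element_size()
    expected_ptr = first.data_ptr() + first.size(dim) * step
    for t in tensors[1:]:
        if (t.dtype != first.dtype or t.device != first.device
                or t.stride() != first.stride() or t.data_ptr() != expected_ptr):
            return torch.cat(tensors, dim=dim)          # out-of-place fallback
        expected_ptr += t.size(dim) * step
    out_size = list(first.size())
    out_size[dim] = sum(t.size(dim) for t in tensors)
    out = first.new()
    out.set_(first.untyped_storage(), first.storage_offset(), out_size, first.stride())
    return out

buf = torch.arange(12.0).reshape(6, 2)
parts = torch.split(buf, [2, 4])                        # adjacent views of one storage
fused = maybe_noop_cat(list(parts))
assert fused.data_ptr() == buf.data_ptr()               # no copy was made
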
Signed-off-by: Tim Moon * Debug noop cat func Signed-off-by: Tim Moon * Construct TE modules in tests with correct dtypes Signed-off-by: Tim Moon * Add tolerances to numerical tests Signed-off-by: Tim Moon * Use plain PyTorch concat when exporting to ONNX Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_numerics.py | 344 +++++++++------ tests/pytorch/test_sanity.py | 396 ++++++++++-------- transformer_engine/pytorch/module/_common.py | 146 ++++--- .../pytorch/module/layernorm_linear.py | 96 ++--- transformer_engine/pytorch/module/linear.py | 97 +++-- 5 files changed, 597 insertions(+), 482 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 0cda82e0c4..90cfce8a6f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -4,7 +4,7 @@ import math import os -from typing import List, Optional +from typing import Dict, List, Optional import pytest import copy @@ -79,19 +79,26 @@ def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() -def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool: - """Ensures two lists are equal.""" - assert len(l1) == len(l2), "Unequal number of outputs." - failed = False - failed_tensors = "" - for i, (t1, t2) in enumerate(zip(l1, l2)): - if not torch.equal(t1, t2): - failed = True - failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n" - assert not failed, "Output mismatches in:\n" + failed_tensors +def dtype_tols(dtype: torch.dtype) -> Dict[str, float]: + """Estimated numerical error for a datatype + Based on tolerances for torch.testing.assert_close. -def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) -> bool: + """ + if dtype == torch.float32: + return dict(rtol=1.3e-6, atol=1e-5) + if dtype == torch.float16: + return dict(rtol=1e-3, atol=1e-5) + if dtype == torch.bfloat16: + return dict(rtol=1.6e-2, atol=1e-5) + raise ValueError(f"Unsuppored dtype ({dtype})") + + +def assert_allclose( + l1: List[torch.Tensor], + l2: List[torch.Tensor], + atol: float, +) -> bool: """Ensures two lists are equal.""" assert len(l1) == len(l2), "Unequal number of outputs." 
for i, (t1, t2) in enumerate(zip(l1, l2)): @@ -424,13 +431,16 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False output_layernorm=False, params_dtype=dtype, fuse_qkv_params=True, + device="cuda", ) - .cuda() ) te_inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) @@ -464,7 +474,20 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False) outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True) - assert_all_equal(outputs, outputs_recompute) + + # Check that results match + tols = dtype_tols(dtype) + if dtype in (torch.float16, torch.bfloat16): + tols["atol"] = 1e-4 + if fp8 or fp8_model_params: + tols.update(dict(rtol=0.125, atol=0.0675)) + for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)): + torch.testing.assert_close( + test, + ref, + msg=f"Mismatch in tensor {i}", + **tols, + ) def _test_e2e_full_recompute( @@ -481,8 +504,7 @@ def _test_e2e_full_recompute( output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) with fp8_model_init(enabled=fp8 and fp8_model_params): - block = ( - TransformerLayer( + block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, config.num_attention_heads, @@ -496,13 +518,15 @@ def _test_e2e_full_recompute( output_layernorm=False, params_dtype=dtype, fuse_qkv_params=True, - ) - .cuda() + device="cuda", ) te_inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=use_reentrant - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=use_reentrant, + ) if use_reentrant: te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) @@ -566,7 +590,19 @@ def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, # Reset bias+GELU fusion flag to avoid contaminating other tests del os.environ["NVTE_BIAS_GELU_NVFUSION"] - assert_all_equal(outputs, outputs_recompute, names=names) + # Check that results match + tols = dtype_tols(dtype) + if dtype in (torch.float16, torch.bfloat16): + tols["atol"] = 1e-3 + if fp8 or fp8_model_params: + tols.update(dict(rtol=0.125, atol=0.0675)) + for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)): + torch.testing.assert_close( + test, + ref, + msg=f"Mismatch in tensor {i}", + **tols, + ) def _test_e2e_checkpointing_get_model(config, dtype): @@ -574,22 +610,20 @@ def _test_e2e_checkpointing_get_model(config, dtype): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - return ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - layernorm_epsilon=config.eps, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.embed, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - params_dtype=dtype, - ) - .cuda() + return TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + 
init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + params_dtype=dtype, + device="cuda", ) @@ -597,8 +631,11 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= reset_rng_states() te_inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) te_inp_hidden_states.retain_grad() block = _test_e2e_checkpointing_get_model(config, dtype) @@ -666,15 +703,29 @@ def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) - assert_all_equal(outputs, outputs_checkpoint) + + # Check that results match + tols = dtype_tols(dtype) + if dtype in (torch.float16, torch.bfloat16): + tols.update(dict(rtol=2e-2, atol=2e-3)) + for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)): + torch.testing.assert_close( + test, + ref, + msg=f"Mismatch in tensor {i}", + **tols, + ) def _test_e2e_gpt_accuracy(block, bs, dtype, config): reset_rng_states() inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) inp_hidden_states.retain_grad() inp_attn_mask = get_causal_attn_mask(config.seq_len) @@ -705,12 +756,12 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): layernorm_epsilon=config.eps, attention_dropout=0.1, hidden_dropout=0.1, + params_dtype=dtype, fuse_qkv_params=True, qkv_weight_interleaved=False, parallel_attention_mlp=parallel_attention_mlp, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -765,8 +816,11 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): reset_rng_states() inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) inp_hidden_states.retain_grad() inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None @@ -799,11 +853,11 @@ def test_mha_accuracy(dtype, bs, model, mask_type): config.hidden_size, config.num_attention_heads, fuse_qkv_params=True, + params_dtype=dtype, qkv_weight_interleaved=False, input_layernorm=False, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -838,8 +892,11 @@ def _test_granular_accuracy(block, bs, dtype, config): reset_rng_states() inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) inp_hidden_states.retain_grad() out = block(inp_hidden_states) @@ -857,10 +914,16 @@ def _test_granular_accuracy(block, bs, dtype, config): def _test_dpa_accuracy(block, bs, dtype, config): reset_rng_states() - mask = torch.triu(torch.ones(config.seq_len, config.seq_len, device="cuda"), diagonal=1).bool() + mask = torch.triu(torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1) query, key, value = [ - torch.randn(config.seq_len, bs, 
config.num_attention_heads, - config.embed, dtype=dtype, requires_grad=True).cuda() for _ in range(3)] + torch.randn( + (config.seq_len, bs, config.num_attention_heads, config.embed), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + for _ in range(3) + ] query.retain_grad() key.retain_grad() @@ -921,9 +984,9 @@ def test_linear_accuracy(dtype, bs, model): config.hidden_size, 4 * config.hidden_size, bias=True, + params_dtype=dtype, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -932,9 +995,9 @@ def test_linear_accuracy(dtype, bs, model): config.hidden_size, 4 * config.hidden_size, bias=True, + device="cuda", + dtype=dtype, ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -965,10 +1028,10 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): RMSNorm( config.hidden_size, eps=eps, - zero_centered_gamma=zero_centered_gamma + params_dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -1009,10 +1072,10 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): LayerNorm( config.hidden_size, eps=eps, - zero_centered_gamma=zero_centered_gamma + params_dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -1058,10 +1121,10 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere config.eps, bias=True, normalization=normalization, + params_dtype=dtype, zero_centered_gamma=zero_centered_gamma, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -1112,9 +1175,9 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): 4 * config.hidden_size, activation=activation, normalization=normalization, + params_dtype=dtype, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -1229,11 +1292,11 @@ def test_gpt_cuda_graph(dtype, bs, model): hidden_dropout=0.1, attention_dropout=0.1, kv_channels=config.embed, + params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, + device="cuda", ) - .to(dtype=dtype) - .cuda() ) graphed_block = copy.deepcopy(block) @@ -1257,28 +1320,29 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) with fp8_model_init(enabled=fp8_model_params): - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - layernorm_epsilon=config.eps, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.embed, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - params_dtype=dtype, - fuse_qkv_params=True, - ) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + params_dtype=dtype, + fuse_qkv_params=True, + device="cuda", ) te_inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) @@ -1306,7 
+1370,18 @@ def test_gpt_fp8_parameters(dtype, bs, model): outputs = _test_gpt_fp8_parameters(bs, dtype, config, False) outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) - assert_all_equal(outputs, outputs_fp8_params) + + # Check that results match + tols = dict(rtol=0.125, atol=0.0675) + for i, (ref, test) in enumerate(zip(outputs, outputs_fp8_params)): + torch.testing.assert_close( + test, + ref, + msg=f"Mismatch in tensor {i}", + rtol=0.125, + atol=0.0675, + ) + @pytest.mark.parametrize("dtype", param_types) @@ -1323,54 +1398,53 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): # other layer. Set `*dropout` values to 0 to make sure the forward pass # is identical to the other layer. torch.manual_seed(0) - block_sbhd = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - layernorm_epsilon=config.eps, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0, - attention_dropout=0, - kv_channels=config.embed, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - attn_input_format="sbhd" - ) - .to(dtype=dtype) - .cuda() + block_sbhd = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0, + attention_dropout=0, + kv_channels=config.embed, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + device="cuda", + attn_input_format="sbhd", ) # Set `torch.manual_seed` to make sure the weights are identical to the # other layer. Set `*dropout` values to 0 to make sure the forward pass # is identical to the other layer. 
torch.manual_seed(0) - block_bshd = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - layernorm_epsilon=config.eps, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0, - attention_dropout=0, - kv_channels=config.embed, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - attn_input_format="bshd" - ) - .to(dtype=dtype) - .cuda() + block_bshd = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0, + attention_dropout=0, + kv_channels=config.embed, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + device="cuda", + attn_input_format="bshd", ) for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()): assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical" x_sbhd = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).to(dtype).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) x_bshd = x_sbhd.transpose(0,1).contiguous() @@ -1384,7 +1458,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): torch.manual_seed(0) y_bshd = block_bshd(x_bshd) - assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()]) + # Check that results match + torch.testing.assert_close( + y_bshd, + y_sbhd.transpose(0,1).contiguous(), + ) @pytest.mark.parametrize("dtype", param_types) @@ -1424,10 +1502,10 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, num_attention_heads=H, attn_input_format=input_format, layer_number=layer_number, - attention_dropout = 0.0 + attention_dropout = 0.0, + params_dtype=dtype, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) else: @@ -1437,9 +1515,9 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, num_attention_heads=H, qkv_format=input_format, layer_number=layer_number, - attention_dropout = 0.0 + attention_dropout = 0.0, + params_dtype=dtype, ) - .to(dtype=dtype) .cuda() .eval() ) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index e91e464fa4..9f8c8f73cb 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -172,10 +172,18 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, dtype=torch.float32, requires_grad=True - ).cuda() + (config.seq_len, config.batch_size, config.hidden_size), + dtype=torch.float32, + device="cuda", + requires_grad=True, + ) te_inp_hidden_states.retain_grad() - te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool() + te_inp_attn_mask = torch.randint( + 2, + (1, 1, config.seq_len, config.seq_len), + dtype=torch.bool, + device="cuda", + ) if skip_wgrad: _disable_wgrads(block) @@ -198,9 +206,17 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() - te_inp_attn_mask = 
torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool() + (config.seq_len, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + te_inp_attn_mask = torch.randint( + 2, + (1, 1, config.seq_len, config.seq_len), + dtype=torch.bool, + device="cuda", + ) if skip_wgrad: _disable_wgrads(block) @@ -227,8 +243,11 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): te_inp_hidden_states = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) if skip_wgrad: _disable_wgrads(block) @@ -250,10 +269,18 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) - te_inp_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5 + te_inp_attn_mask = torch.randint( + 2, + (config.batch_size, 1, 1, config.seq_len), + dtype=torch.bool, + device="cuda", + ) if skip_wgrad: _disable_wgrads(block) @@ -268,10 +295,24 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() - te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool() - enc_dec_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5 + (config.seq_len, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + te_inp_attn_mask = torch.randint( + 2, + (1, 1, config.seq_len, config.seq_len), + dtype=torch.bool, + device="cuda", + ) + + enc_dec_attn_mask = torch.randint( + 2, + (config.batch_size, 1, 1, config.seq_len), + dtype=torch.bool, + device="cuda", + ) if skip_wgrad: _disable_wgrads(block) @@ -294,8 +335,11 @@ def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.") te_inp = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=not skip_dgrad - ).cuda() + (config.seq_len, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=not skip_dgrad, + ) if skip_wgrad: _disable_wgrads(block) @@ -315,8 +359,10 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad) pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.") te_inp = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, requires_grad=True - ).cuda() + (config.seq_len, config.batch_size, config.hidden_size), + device="cuda", + requires_grad=True, + ) te_inp.retain_grad() with torch.autocast(device_type="cuda", enabled=True, dtype=dtype): @@ -371,16 +417,14 @@ def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad, sigma = 0.023 init_method = init_method_normal(sigma) - block = ( - 
LayerNormLinear( - config.hidden_size, - config.hidden_size * 3, - init_method=init_method, - zero_centered_gamma=zero_centered_gamma, - normalization=normalization, - ) - .to(dtype=dtype) - .cuda() + block = LayerNormLinear( + config.hidden_size, + config.hidden_size * 3, + init_method=init_method, + zero_centered_gamma=zero_centered_gamma, + normalization=normalization, + params_dtype=dtype, + device="cuda", ) _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad) @@ -402,12 +446,12 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): sigma = 0.023 output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - Linear( - config.hidden_size, config.hidden_size, init_method=output_layer_init_method - ) - .to(dtype=dtype) - .cuda() + block = Linear( + config.hidden_size, + config.hidden_size, + init_method=output_layer_init_method, + params_dtype=dtype, + device="cuda", ) _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad) @@ -435,18 +479,16 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad, init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - LayerNormMLP( - config.hidden_size, - 4 * config.hidden_size, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - zero_centered_gamma=zero_centered_gamma, - activation=activation, - normalization=normalization, - ) - .to(dtype=dtype) - .cuda() + block = LayerNormMLP( + config.hidden_size, + 4 * config.hidden_size, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + zero_centered_gamma=zero_centered_gamma, + activation=activation, + normalization=normalization, + params_dtype=dtype, + device="cuda", ) _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad) @@ -477,26 +519,24 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad, init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - zero_centered_gamma=zero_centered_gamma, - bias=bias, - activation=activation, - normalization=normalization, - parallel_attention_mlp=parallel_attention_mlp, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + zero_centered_gamma=zero_centered_gamma, + bias=bias, + activation=activation, + normalization=normalization, + device="cuda", + parallel_attention_mlp=parallel_attention_mlp, ) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload) @@ -546,24 +586,22 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * 
config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - apply_residual_connection_post_layernorm=True, - output_layernorm=True, - zero_centered_gamma=zero_centered_gamma, - self_attn_mask_type="padding", - normalization=normalization, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=True, + output_layernorm=True, + zero_centered_gamma=zero_centered_gamma, + self_attn_mask_type="padding", + normalization=normalization, + device="cuda", ) _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad) @@ -607,24 +645,22 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - layer_type="decoder", - zero_centered_gamma=zero_centered_gamma, - normalization=normalization, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + layer_type="decoder", + zero_centered_gamma=zero_centered_gamma, + normalization=normalization, + device="cuda", ) _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad) @@ -665,19 +701,17 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - ) - .to(dtype=torch.float32) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=torch.float32, + device="cuda", ) _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad) @@ -700,22 +734,20 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - 
attention_dropout=0.1, - kv_channels=config.kv_channels, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - drop_path_rate=1.0, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + drop_path_rate=1.0, + device="cuda", ) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) @@ -738,22 +770,20 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - fuse_qkv_params=True, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + fuse_qkv_params=True, + device="cuda", ) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) @@ -777,24 +807,22 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - zero_centered_gamma=zero_centered_gamma, - fuse_qkv_params=True, - fuse_wgrad_accumulation=True, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + zero_centered_gamma=zero_centered_gamma, + fuse_qkv_params=True, + fuse_wgrad_accumulation=True, + device="cuda", ) _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad) @@ -820,30 +848,28 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - 
apply_residual_connection_post_layernorm=False, - output_layernorm=False, - zero_centered_gamma=zero_centered_gamma, - fuse_qkv_params=True, - normalization=normalization, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + zero_centered_gamma=zero_centered_gamma, + fuse_qkv_params=True, + normalization=normalization, + device="cuda", ) _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad) def test_model_multiple_cast(): - a = torch.zeros((16,16)).cuda() + a = torch.zeros((16,16), device="cuda") m = Linear(16,32) y = m(a) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 79798d2ff0..ab6455649c 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -10,6 +10,7 @@ import torch from .. import cpp_extensions as tex +from ..export import is_in_onnx_export_mode from ..fp8 import get_fp8_te_dtype from ..utils import get_default_init_method @@ -99,32 +100,79 @@ def _apply_normalization(inputmat:torch.Tensor, class _NoopCatFunc(torch.autograd.Function): - """No-op concatenate tensors along dim 0 + """Concatenate tensors, doing a no-op if possible - `full_tensor` is assumed to already be the concatenation of - `tensors`, i.e. they occupy the same memory with the correct - offsets. + See _noop_cat. """ @staticmethod def forward( - ctx, - split_ranges: List[Tuple[int, int]], - full_tensor: torch.Tensor, + ctx: Any, + dim: int, *tensors: Tuple[torch.Tensor, ...], ) -> torch.Tensor: - # pylint: disable=unused-argument + + # Check first tensor + if not tensors: + raise ValueError("Attempted to concatenate 0 tensors") + num_dims = tensors[0].dim() + if not -num_dims <= dim < num_dims: + raise ValueError( + "Attempted to concatenate tensor " + f"with shape {list(tensors[0].size())} along dim {dim}" + ) + dim %= num_dims + + # Check remaining tensors + out_shape = list(tensors[0].size()) + split_ranges = [(0, tensors[0].size(dim))] + for tensor in tensors[1:]: + in_shape = list(tensor.size()) + if ( + len(in_shape) != num_dims + or in_shape[:dim] != out_shape[:dim] + or in_shape[dim+1:] != out_shape[dim+1:] + ): + raise ValueError( + "Attempted to concatenate tensors with shapes " + f"{[list(tensor.size()) for tensor in tensors]} " + f"along dim {dim}" + ) + split_start = out_shape[dim] + split_end = split_start + in_shape[dim] + out_shape[dim] = split_end + split_ranges.append((split_start, split_end)) + + # Save state for backward + ctx.dim = dim ctx.split_ranges = split_ranges - assert not full_tensor.requires_grad, "Concatenated tensor should not require gradient" - out = full_tensor.new() + + # Out-of-place concatenation if needed + dtype = tensors[0].dtype + device = tensors[0].device + strides = tensors[0].stride() + data_ptr_stride = strides[dim] * tensors[0].element_size() + data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * data_ptr_stride + for tensor in tensors[1:]: + if ( + tensor.dtype != dtype + or tensor.device != device + or tensor.stride() != strides + or tensor.data_ptr() != data_ptr + ): + return torch.cat(tensors, dim=dim) + data_ptr += tensor.size(dim) * data_ptr_stride + + # No-op concatenation + out 
= tensors[0].new() out.set_( - full_tensor.untyped_storage(), - full_tensor.storage_offset(), - full_tensor.size(), - full_tensor.stride(), + tensors[0].untyped_storage(), + tensors[0].storage_offset(), + out_shape, + strides, ) - out.requires_grad = True + out.requires_grad = any(tensor.requires_grad for tensor in tensors) return out @staticmethod @@ -132,64 +180,32 @@ def backward( ctx, grad_output: torch.Tensor, ) -> Tuple[Optional[torch.Tensor], ...]: - grads = [ - grad_output[split_start:split_end] - for split_start, split_end in ctx.split_ranges - ] - return None, None, *grads + grad_inputs = [] + for split_start, split_end in ctx.split_ranges: + slices = [slice(None)] * grad_output.dim() + slices[ctx.dim] = slice(split_start, split_end) + grad_inputs.append(grad_output[tuple(slices)]) + return None, *grad_inputs def _noop_cat( tensors: List[torch.Tensor], - full_tensor: torch.Tensor, + dim: int = 0, ) -> torch.Tensor: - """Concatenate tensors along dim 0, doing a no-op if possible - - If `full_tensor` is already the concatenation of `tensors`, i.e. - they occupy the same memory region with the correct offsets, then - no copies are performed. Otherwise the buffers in all the tensors - are reallocated so that another call would result in a no-op. + """Concatenate tensors, doing a no-op if possible - In the backward pass, gradients to `partial_tensors` will just be - tensor views. + If tensors are already concatenated in memory, a tensor view of + that memory region will be returned. Otherwise the tensors will be + concatenated out-of-place, as usual. """ - - # Determine split points - split_ranges = [] - full_tensor_shape = full_tensor.size() - offset = 0 - for tensor in tensors: - tensor_shape = tensor.size() - if tensor_shape[1:] != full_tensor_shape[1:]: - raise ValueError( - f"Attempting to concatenate tensor with shape={list(tensor_shape)} " - f"into a tensor with shape={list(full_tensor_shape)}" - ) - split_start = offset - offset += tensor_shape[0] - split_end = offset - split_ranges.append((split_start, split_end)) - if offset != full_tensor_shape[0]: - raise ValueError( - f"Attempting to concatenate tensors with total shape[0]={offset} " - f"into a tensor with shape[0]={full_tensor_shape[0]}" - ) - - # Reallocate buffers if no-op concat isn't possible - need_to_reallocate = False - for tensor, (split_start, _) in zip(tensors, split_ranges): - if tensor.data_ptr() != full_tensor[split_start].data_ptr(): - need_to_reallocate = True - break - if need_to_reallocate: - with torch.no_grad(): - full_tensor.data = torch.cat(tensors) - for tensor, (split_start, split_end) in zip(tensors, split_ranges): - tensor.data = full_tensor[split_start:split_end] - - # Perform no-op concat - return _NoopCatFunc.apply(split_ranges, full_tensor, *tensors) + if not tensors: + raise ValueError("Attempted to concatenate 0 tensors") + if len(tensors) == 1: + return tensors[0] + if is_in_onnx_export_mode(): + return torch.cat(tensors, dim=dim) + return _NoopCatFunc.apply(dim, *tensors) @dataclass diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7d7bb0bbd5..75a8ad857e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -926,17 +926,20 @@ def __init__( else: self.layer_norm_bias = None - self.weight_tensor = torch.empty( - self.out_features, self.in_features, - device=device, dtype=params_dtype) - + # Contiguous buffers for params + weight_tensor = 
torch.empty( + self.out_features, + self.in_features, + device=device, + dtype=params_dtype, + ) + bias_tensor = None if self.use_bias: - self.bias_tensor = torch.empty( + bias_tensor = torch.empty( self.out_features, device=device, - dtype=params_dtype) - else: - self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device) + dtype=params_dtype, + ) # Configure parameter splits self.weight_names = [] @@ -982,7 +985,11 @@ def __init__( ) self.parameter_split_sizes[i] = size // self.tp_size - # Construct parameters from weight and bias buffers + # Construct weight parameters + # Note: Register weights together so that they are adjacent to + # each other in LayerNormLinear.parameters(). This makes it + # more likely that they will stay contiguous if the weights + # are manipulated externally, e.g. by FSDP. offset = 0 for i, split_size in enumerate(self.parameter_split_sizes): split_start = offset @@ -998,32 +1005,30 @@ def __init__( ) # Construct weight parameter - weight = self.weight_tensor - if is_subview: - weight = weight[split_start:split_end] - weight = torch.nn.Parameter(weight) - self.register_parameter(self.weight_names[i], weight, - init_fn=init_method, - get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT) - - # Construct bias parameter if needed - if self.use_bias: - bias = self.bias_tensor - if is_subview: - bias = bias[split_start:split_end] - bias = torch.nn.Parameter(bias) - self.register_parameter(self.bias_names[i], bias, - init_fn=init_method_constant(0.0)) - else: - bias = torch.Tensor().to(dtype=params_dtype, device=device) - setattr(self, self.bias_names[i], bias) + self.register_parameter( + self.weight_names[i], + torch.nn.Parameter(weight_tensor[split_start:split_end]), + init_fn=init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) - # Concatenated tensors are not needed if not splitting - # into multiple parameters - if not is_subview: - del self.weight_tensor - del self.bias_tensor + # Construct bias parameters if needed + if self.use_bias: + offset = 0 + for i, split_size in enumerate(self.parameter_split_sizes): + split_start = offset + offset += split_size + split_end = offset + self.register_parameter( + self.bias_names[i], + torch.nn.Parameter(bias_tensor[split_start:split_end]), + init_fn=init_method_constant(0.0), + ) + else: + for name in self.bias_names: + bias = torch.Tensor().to(dtype=params_dtype, device=device) + setattr(self, name, bias) if self.primary_weights_in_fp8: self.init_fp8_metadata() @@ -1150,24 +1155,15 @@ def forward( "Need to run inside fp8_autocast region when weights are stored in FP8." 
# Get concatenated weight and bias tensors - if len(self.parameter_split_sizes) == 1: - weight_tensor = getattr(self, self.weight_names[0]) - bias_tensor = getattr(self, self.bias_names[0]) - elif torch.is_grad_enabled(): - weight_tensor = _noop_cat( - [getattr(self, name) for name in self.weight_names], - self.weight_tensor, + weight_tensor = _noop_cat( + [getattr(self, name) for name in self.weight_names], + ) + if self.use_bias: + bias_tensor = _noop_cat( + [getattr(self, name) for name in self.bias_names], ) - if self.use_bias: - bias_tensor = _noop_cat( - [getattr(self, name) for name in self.bias_names], - self.bias_tensor, - ) - else: - bias_tensor = getattr(self, self.bias_names[0]) # Unused else: - weight_tensor = self.weight_tensor - bias_tensor = self.bias_tensor + bias_tensor = getattr(self, self.bias_names[0]) # Unused # Fetch the fp8 weights placeholders (for linear/gemm) weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index cb2f6871b3..b48987f34c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -777,14 +777,20 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel - self.weight_tensor = torch.empty( - self.out_features, self.in_features, - device=device, dtype=params_dtype) - + # Contiguous buffers for params + weight_tensor = torch.empty( + self.out_features, + self.in_features, + device=device, + dtype=params_dtype, + ) + bias_tensor = None if self.use_bias: - self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype) - else: - self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device) + bias_tensor = torch.empty( + self.out_features, + device=device, + dtype=params_dtype, + ) # Configure parameter splits self.weight_names = [] @@ -830,7 +836,11 @@ def __init__( ) self.parameter_split_sizes[i] = size // self.tp_size - # Construct parameters from weight and bias buffers + # Construct weight parameters + # Note: Register weights together so that they are adjacent to + # each other in Linear.parameters(). This makes it more likely + # that they will stay contiguous if the weights are + # manipulated externally, e.g. by FSDP. 
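The note above relies on a simple storage property, sketched below with hypothetical tensor names in plain PyTorch: parameters registered as slices of one contiguous buffer keep sharing that buffer and stay adjacent in memory, but if external code later reallocates one of them, the adjacency is lost and a fused view is no longer possible, which is why the forward pass now decides at call time whether concatenation can be a no-op instead of caching a fused tensor.

import torch

buf = torch.empty(6, 4)                      # hypothetical contiguous weight buffer
w_q = torch.nn.Parameter(buf[0:2])           # three split parameters viewing the buffer
w_k = torch.nn.Parameter(buf[2:4])
w_v = torch.nn.Parameter(buf[4:6])

# The views share the buffer's storage and sit at consecutive offsets.
assert w_q.untyped_storage().data_ptr() == buf.untyped_storage().data_ptr()
assert w_k.data_ptr() == w_q.data_ptr() + w_q.numel() * w_q.element_size()
assert w_v.data_ptr() == w_k.data_ptr() + w_k.numel() * w_k.element_size()

# If something external swaps out one parameter's storage, adjacency is lost
# and a later concatenation has to fall back to an out-of-place torch.cat.
w_k.data = w_k.data.clone()
assert w_k.data_ptr() != w_q.data_ptr() + w_q.numel() * w_q.element_size()
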
offset = 0 for i, split_size in enumerate(self.parameter_split_sizes): split_start = offset @@ -846,32 +856,30 @@ def __init__( ) # Construct weight parameter - weight = self.weight_tensor - if is_subview: - weight = weight[split_start:split_end] - weight = torch.nn.Parameter(weight) - self.register_parameter(self.weight_names[i], weight, - init_fn=init_method, - get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT) - - # Construct bias parameter if needed - if self.use_bias: - bias = self.bias_tensor - if is_subview: - bias = bias[split_start:split_end] - bias = torch.nn.Parameter(bias) - self.register_parameter(self.bias_names[i], bias, - init_fn=init_method_constant(0.0)) - else: - bias = torch.Tensor().to(dtype=params_dtype, device=device) - setattr(self, self.bias_names[i], bias) + self.register_parameter( + self.weight_names[i], + torch.nn.Parameter(weight_tensor[split_start:split_end]), + init_fn=init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) - # Concatenated tensors are not needed if not splitting - # into multiple parameters - if not is_subview: - del self.weight_tensor - del self.bias_tensor + # Construct bias parameters if needed + if self.use_bias: + offset = 0 + for i, split_size in enumerate(self.parameter_split_sizes): + split_start = offset + offset += split_size + split_end = offset + self.register_parameter( + self.bias_names[i], + torch.nn.Parameter(bias_tensor[split_start:split_end]), + init_fn=init_method_constant(0.0), + ) + else: + for name in self.bias_names: + bias = torch.Tensor().to(dtype=params_dtype, device=device) + setattr(self, name, bias) if self.primary_weights_in_fp8: self.init_fp8_metadata() @@ -974,24 +982,15 @@ def forward( is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha # Get concatenated weight and bias tensors - if len(self.parameter_split_sizes) == 1: - weight_tensor = getattr(self, self.weight_names[0]) - bias_tensor = getattr(self, self.bias_names[0]) - elif torch.is_grad_enabled(): - weight_tensor = _noop_cat( - [getattr(self, name) for name in self.weight_names], - self.weight_tensor, + weight_tensor = _noop_cat( + [getattr(self, name) for name in self.weight_names], + ) + if self.use_bias: + bias_tensor = _noop_cat( + [getattr(self, name) for name in self.bias_names], ) - if self.use_bias: - bias_tensor = _noop_cat( - [getattr(self, name) for name in self.bias_names], - self.bias_tensor, - ) - else: - bias_tensor = getattr(self, self.bias_names[0]) # Unused else: - weight_tensor = self.weight_tensor - bias_tensor = self.bias_tensor + bias_tensor = getattr(self, self.bias_names[0]) # Unused # Fetch the fp8 weights placeholders (for linear/gemm) weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad( From 3ba02f160f9b7f5e0a3af2843af4bfdcb702a875 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Fri, 19 Apr 2024 16:11:27 -0400 Subject: [PATCH 030/244] [JAX] Allow multi-dims for dgamma and dbeta in LN descriptor. (#780) * Allow multi-dims for dgamma and dbeta in LN descriptor. 
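A second part of this commit (the examples/jax hunks further down) removes static_argnums from the jax.jit-decorated train/update steps, whose positional arguments are arrays and pytrees of arrays. A minimal, self-contained sketch of the constraint behind that fix (illustrative function, not the example scripts themselves):

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(0,))   # marks the first argument as static
def scaled_sum(x, factor):
    return factor * jnp.sum(x)

x = jnp.ones(8)
# scaled_sum(x, 2.0) raises ValueError: static arguments are hashed into the
# compilation-cache key, and jax arrays (or pytrees containing them) are not
# hashable. Passing them as ordinary traced arguments works:
print(jax.jit(lambda x, factor: factor * jnp.sum(x))(x, 2.0))   # -> 16.0
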
Signed-off-by: Ming Huang * Fix the jit error in examples/jax Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Signed-off-by: Pawel Gadzinski --- .../jax/encoder/test_single_gpu_encoder.py | 2 +- examples/jax/mnist/test_single_gpu_mnist.py | 2 +- transformer_engine/jax/cpp_extensions.py | 25 ++++---- transformer_engine/jax/csrc/modules.cpp | 60 +++++++++++-------- transformer_engine/jax/csrc/modules.h | 16 +++-- 5 files changed, 57 insertions(+), 48 deletions(-) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index ae5304628f..b892437925 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -55,7 +55,7 @@ def __call__(self, x, mask, disable_dropout=False): return x -@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4, 5)) +@partial(jax.jit) def train_step(state, inputs, masks, labels, var_collect, rngs): """Computes gradients, loss and accuracy for a single batch.""" diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index f9824ae000..ae74a66337 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -74,7 +74,7 @@ def loss_fn(var_collect, disable_dropout=False): return grads, loss, accuracy -@partial(jax.jit, static_argnums=(0, 1)) +@partial(jax.jit) def update_model(state, grads): """Update model params and FP8 meta.""" state = state.apply_gradients(grads=grads[PARAMS_KEY]) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index 08bcb94239..3356aafef5 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -385,8 +385,8 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): hidden_size, wkspace_aval.size, barrier_aval.size, - 0, # no dgamma_part in FWD pass - 0, # no dbeta_part in BWD pass + (0,), # no dgamma_part in FWD pass + (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), @@ -464,7 +464,6 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): f"Enforcing no sharding of parameters hidden dim! 
" \ ) - x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) g_sharding = NamedSharding(mesh, PartitionSpec(None)) b_sharding = NamedSharding(mesh, PartitionSpec(None)) @@ -589,8 +588,8 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): hidden_size, wkspace_aval.size, barrier_aval.size, - dgamma_part_aval.size, - dbeta_part_aval.size, + dgamma_part_aval.shape, + dbeta_part_aval.shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), @@ -791,8 +790,8 @@ def lowering(ctx, x, gamma, *, epsilon): hidden_size, wkspace_aval.size, barrier_aval.size, - 0, # no dgamma_part in FWD pass - 0, # no dbeta_part in BWD pass + (0,), # no dgamma_part in FWD pass + (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), @@ -968,8 +967,8 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): hidden_size, wkspace_aval.size, barrier_aval.size, - dgamma_part_aval.size, - 0, # no dbeta_part for RMSnorm + dgamma_part_aval.shape, + (0,), # no dbeta_part for RMSnorm jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), @@ -3588,8 +3587,8 @@ def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_cen hidden_size, wkspace_aval.size, barrier_aval.size, - 0, # no dgamma_part in FWD pass - 0, # no dbeta_part in BWD pass + (0,), # no dgamma_part in FWD pass + (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), @@ -3840,8 +3839,8 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): hidden_size, wkspace_aval.size, barrier_aval.size, - 0, # no dgamma_part in FWD pass - 0, # no dbeta_part in BWD pass + (0,), # no dgamma_part in FWD pass + (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), diff --git a/transformer_engine/jax/csrc/modules.cpp b/transformer_engine/jax/csrc/modules.cpp index 1c4c468d51..4ac6fa58b1 100644 --- a/transformer_engine/jax/csrc/modules.cpp +++ b/transformer_engine/jax/csrc/modules.cpp @@ -71,17 +71,28 @@ pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector &shap return PackOpaque(desc); } -pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size, - size_t wkspace_size, size_t barrier_size, - size_t *dgamma_part_sizes, size_t *dbeta_part_sizes, - DType x_dtype, DType w_dtype, DType wkspace_dtype, - DType barrier_dtype, DType dgamma_part_dtype, - DType dbeta_part_dtype, bool zero_centered_gamma, - float eps, int sm_margin) { - return PackOpaque(CustomCallNormDescriptor{ - batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes, dbeta_part_sizes, - x_dtype, w_dtype, wkspace_dtype, barrier_dtype, dgamma_part_dtype, dbeta_part_dtype, - zero_centered_gamma, eps, sm_margin}); +pybind11::bytes PackCustomCallNormDescriptor( + size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size, + const std::vector &dgamma_part_shape, const std::vector &dbeta_part_shape, + DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype, + DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) { + CustomCallNormDescriptor desc; + desc.batch_size = batch_size; + 
desc.hidden_size = hidden_size; + desc.wkspace_size = wkspace_size; + desc.barrier_size = barrier_size; + desc.dgamma_part_shape.from_vector(dgamma_part_shape); + desc.dbeta_part_shape.from_vector(dbeta_part_shape); + desc.x_dtype = x_dtype; + desc.w_dtype = w_dtype; + desc.wkspace_dtype = wkspace_dtype; + desc.barrier_dtype = barrier_dtype; + desc.dgamma_part_dtype = dgamma_part_dtype; + desc.dbeta_part_dtype = dbeta_part_dtype; + desc.zero_centered_gamma = zero_centered_gamma; + desc.eps = eps; + desc.sm_margin = sm_margin; + return PackOpaque(desc); } pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size, @@ -529,7 +540,7 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid } void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size, - size_t barrier_size, size_t *dgamma_part_sizes, size_t *dbeta_part_sizes, + size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape, bool zero_centered_gamma, float eps, void *input, DType in_dtype, void *weight, DType w_dtype, void *ograd, void *workspace, DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu, @@ -563,14 +574,14 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); auto barrier_shape = std::vector{barrier_size}; auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype); - auto dgamma_part_shape = std::vector{dgamma_part_sizes[0], dgamma_part_sizes[1]}; - auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape, dgamma_dtype); + auto dgamma_part_tensor = + TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype); if (is_layer_norm) { auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype); auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype); - auto dbeta_part_shape = std::vector{dbeta_part_sizes[0], dbeta_part_sizes[1]}; - auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape, dbeta_dtype); + auto dbeta_part_tensor = + TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype); layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), @@ -664,8 +675,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; auto barrier_size = desc.barrier_size; - auto *dgamma_part_sizes = desc.dgamma_part_sizes; - auto *dbeta_part_sizes = desc.dbeta_part_sizes; + auto dgamma_part_shape = desc.dgamma_part_shape; + auto dbeta_part_shape = desc.dbeta_part_shape; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; @@ -689,8 +700,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, auto *dgamma_part = buffers[10]; auto *dbeta_part = buffers[11]; - LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes, - dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight, + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, + dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, dbeta_part_dtype, stream); @@ -786,8 +797,9 @@ void 
RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; auto barrier_size = desc.barrier_size; - auto dgamma_part_sizes = desc.dgamma_part_sizes; - size_t dbeta_part_sizes[2] = {0, 0}; + auto dgamma_part_shape = desc.dgamma_part_shape; + Shape dbeta_part_shape; + dbeta_part_shape.from_vector({0, 0}); auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; @@ -797,8 +809,8 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; - LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes, - dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight, + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, + dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, dbeta_part_dtype, stream); diff --git a/transformer_engine/jax/csrc/modules.h b/transformer_engine/jax/csrc/modules.h index e392931d04..04f0039b02 100644 --- a/transformer_engine/jax/csrc/modules.h +++ b/transformer_engine/jax/csrc/modules.h @@ -69,8 +69,8 @@ struct CustomCallNormDescriptor { size_t hidden_size; size_t wkspace_size; size_t barrier_size; - size_t *dgamma_part_sizes; // 2D tensor - size_t *dbeta_part_sizes; // 2D tensor + Shape dgamma_part_shape; + Shape dbeta_part_shape; DType x_dtype; DType w_dtype; DType wkspace_dtype; @@ -82,13 +82,11 @@ struct CustomCallNormDescriptor { int sm_margin; }; -pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size, - size_t wkspace_size, size_t barrier_size, - size_t *dgamma_part_sizes, size_t *dbeta_part_sizes, - DType x_dtype, DType w_dtype, DType wkspace_dtype, - DType barrier_dtype, DType dgamma_part_dtype, - DType dbeta_part_dtype, bool zero_centered_gamma, - float eps, int sm_margin); +pybind11::bytes PackCustomCallNormDescriptor( + size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size, + const std::vector &dgamma_part_shape, const std::vector &dbeta_part_shape, + DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype, + DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin); struct SoftmaxDescriptor { size_t batch_size; From fab53a4c1856fca3c0a1ec2995e74db6c25049e1 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 22 Apr 2024 09:22:54 -0700 Subject: [PATCH 031/244] [PyTorch] Remove unnecessary Pylint overrides (#794) * Remove unnecessary Pylint overrides Signed-off-by: Tim Moon * Fixes to lint Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Tim Moon Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/__init__.py | 8 +++--- transformer_engine/pytorch/cpu_offload.py | 29 ++++++++++++++------- transformer_engine/pytorch/float8_tensor.py | 2 -- transformer_engine/pytorch/fp8.py | 20 +++++++------- transformer_engine/pytorch/module/base.py | 2 +- transformer_engine/pytorch/softmax.py | 21 ++++++++------- 6 files changed, 46 insertions(+), 36 deletions(-) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py 
index 4c513339a0..eccde1d530 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -3,6 +3,8 @@ # See LICENSE for license information. """Transformer Engine bindings for pyTorch""" +import torch + from .module import LayerNormLinear from .module import Linear from .module import LayerNormMLP @@ -32,8 +34,8 @@ onnx_rmsnorm_fwd, onnx_rmsnorm_fwd_fp8 ) + try: - import torch torch._dynamo.config.error_on_nested_jit_trace = False -except: # pylint: disable=bare-except - pass +except AttributeError: + pass # error_on_nested_jit_trace was added in PyTorch 2.2.0 diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index b2635bb9bf..0890ca5875 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -3,8 +3,10 @@ # See LICENSE for license information. """Functionality for CPU offloading of tensors saved for backward pass.""" -from typing import Any +from __future__ import annotations from contextlib import nullcontext +from typing import Any, Dict, Optional + import torch from .float8_tensor import Float8Tensor @@ -99,10 +101,17 @@ class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook): and `tensor_pop` interface. How the offload-handler manages the offloading, recovering or prefetching timing is transparent to this hook. """ - def __init__(self, offload_handler, handler_extra_kwargs={}, debug=False) -> None: # pylint: disable=dangerous-default-value - self.debug = debug - self.offload_handler = offload_handler - self.handler_extra_kwargs = handler_extra_kwargs + def __init__( + self, + offload_handler: OffloadHandler, + handler_extra_kwargs: Optional[Dict[str,Any]] = None, + debug: bool = False, + ) -> None: + if handler_extra_kwargs is None: + handler_extra_kwargs = {} + self.debug: bool = debug + self.offload_handler: OffloadHandler = offload_handler + self.handler_extra_kwargs: Dict[str,Any] = handler_extra_kwargs super().__init__() def on_save_for_backward(self, tensor: torch.Tensor) -> Any: @@ -290,10 +299,10 @@ def get_tensor_buf_for_offloaded_tensor(self, tensor, tensor_tag): allocate_new_buf = True else: tensor_buf = id_buf_map[tensor_id] - if not (tensor_buf.size() == tensor.size() and tensor_buf.dtype == tensor.dtype): # pylint: disable=simplifiable-if-statement - allocate_new_buf = True - else: - allocate_new_buf = False # in this case, reuse the old buffer + allocate_new_buf = ( + tensor_buf.size() != tensor.size() + or tensor_buf.dtype != tensor.dtype + ) if allocate_new_buf: # supposed to only execute once @@ -491,7 +500,7 @@ def tensor_need_offloading_checker_activations(tensor): def tensor_need_offloading_checker_weights(tensor): return hasattr(tensor, "weight_offloading") - def tensor_need_offloading_checker_all(tensor): # pylint: disable=unused-argument + def tensor_need_offloading_checker_all(tensor): return (hasattr(tensor,"activation_offloading") or hasattr(tensor, "weight_offloading")) if offload_activations and offload_weights: diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index f93d6ae5cb..bbcbc2839c 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -730,8 +730,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): return None # Slice op - # TODO Consider additional bookkeeping so we invalidate caches # pylint: disable=fixme - # if these slices are modified in-place if func == aten.slice.Tensor: 
tensor = args[0] data = tensor._data diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index b871169a11..1f359d4864 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -502,12 +502,12 @@ def fp8_model_init(enabled: bool = True) -> None: This functionality is *EXPERIMENTAL*. """ + _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS + FP8GlobalStateManager.FP8_PARAMETERS = enabled try: - _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS - FP8GlobalStateManager.FP8_PARAMETERS = enabled yield finally: - FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters # pylint: disable=used-before-assignment + FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters @contextmanager @@ -555,16 +555,16 @@ def fp8_autocast( distributed group over which amaxes for the fp8 tensors are reduced at the end of each training step. """ + fp8_state = FP8GlobalStateManager.get_fp8_autocast_state() + FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled, + calibrating=calibrating, + fp8_recipe=fp8_recipe, + fp8_group=fp8_group, + _graph=_graph) try: - fp8_state = FP8GlobalStateManager.get_fp8_autocast_state() - FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled, - calibrating=calibrating, - fp8_recipe=fp8_recipe, - fp8_group=fp8_group, - _graph=_graph) yield finally: - FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment + FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 00f5c2216d..e0bf5efbbf 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -703,7 +703,7 @@ def grad_output_preprocess( out=grad_output_c, ) else: - grad_output_c = grad_ouput_mat # pylint: disable=undefined-variable + grad_output_c = grad_output_mat if not ctx.ub_overlap_ag: grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group) if not isinstance(grad_output_c, Float8Tensor): diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index 593a05cb71..57fccd80ad 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -336,19 +336,20 @@ def forward( return self.forward_fused_softmax(inp, mask, scale) return self.forward_torch_softmax(inp, mask, scale) - def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool: + def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool: # pylint: disable=too-many-return-statements """Check FusedScaleMaskSoftmax kernel availability based on size""" attn_batches = b * np - if ( # pylint: disable=too-many-boolean-expressions - not self.scaled_masked_softmax_fusion # user doesn't want to fuse - or not self.input_in_float16 # input must be fp16 - or sk < 16 - or sk > 16384 # sk must be 16 ~ 16384 - or sk % 8 != 0 # sk must be divisor of 8 - or self.attn_mask_type == "arbitrary" # Custom masks not supported - ): - return False + if not self.scaled_masked_softmax_fusion: + return False # user doesn't want to fuse + if not self.input_in_float16: + return False # input must be fp16 + if not 16 < sk < 16384: + return False # sk must be 16 ~ 16384 + if sk % 8 != 0: + return False # sk must be divisor of 8 + if self.attn_mask_type == "arbitrary": + return False # Custom masks not supported if self.attn_mask_type == 
"causal": # unfused causal softmax kernel return True From 165225afcf83f2814419f7f564764bb76a876f5b Mon Sep 17 00:00:00 2001 From: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:08:24 -0700 Subject: [PATCH 032/244] [JAX] Unifying GeLU and GeGLU in LayerNorm MLP (#765) * combined layernorm_geglu with layernorm_gelu into fused_layernorm Signed-off-by: Phuong Nguyen * fixes to pass all unit tests in test_custom_call_compute.py, test_layer.py, and test_praxis_layer.py Signed-off-by: Phuong Nguyen * cleaning and formatting Signed-off-by: Phuong Nguyen * renaming based on reviewers suggestions Signed-off-by: Phuong Nguyen * implemented partial fused layernorm Signed-off-by: Phuong Nguyen * geglu + bias passed tests Signed-off-by: Phuong Nguyen * added partial fused calculation for dbias_1 Signed-off-by: Phuong Nguyen * clean up Co-authored-by: Alp Dener Signed-off-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --------- Signed-off-by: Phuong Nguyen Signed-off-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Co-authored-by: Alp Dener Signed-off-by: Pawel Gadzinski --- tests/jax/test_custom_call_compute.py | 189 ++---- .../common/transpose/cast_transpose_fusion.cu | 9 +- transformer_engine/jax/cpp_extensions.py | 225 +++++++ transformer_engine/jax/csrc/extensions.cpp | 2 + transformer_engine/jax/csrc/modules.cpp | 63 ++ transformer_engine/jax/csrc/modules.h | 6 + transformer_engine/jax/flax/module.py | 105 ++-- transformer_engine/jax/mlp.py | 593 ++++++------------ 8 files changed, 575 insertions(+), 617 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 8aa6c399f4..139ef994fa 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -4,6 +4,7 @@ import functools import operator +from typing import Callable, Sequence, Union import jax import jax.numpy as jnp @@ -22,8 +23,7 @@ from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper from transformer_engine.jax.fp8 import is_fp8_available from transformer_engine.jax.layernorm import layernorm -from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp -from transformer_engine.jax.mlp import layernorm_gelu_fp8_mlp +from transformer_engine.jax.mlp import fused_layernorm_fp8_mlp GEMM_CASES = [ (256, 256, 512), @@ -174,17 +174,32 @@ def ref_func(x, y): assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024), + @pytest.mark.parametrize('m,n,k', [(256, 512, 128), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)]) - def test_grad_ln_geglu_fp8_mlp(self, m, n, k): + @pytest.mark.parametrize('activation_type', [('gelu', ), + ('gelu', 'linear')]) + @pytest.mark.parametrize('use_bias', [True, False]) + def test_grad_fused_layernorm_fp8_mlp(self, m, n, k, + activation_type: Sequence[Union[str, Callable]], + use_bias: bool): + """ N/a """ key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, 4) - activations = ('gelu', 'linear') + subkeys = jax.random.split(key, 6) + + activation_dict = { + ('gelu', ): jax.nn.gelu + } a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16) - k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16) + k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) - s = 
jax.random.normal(subkeys[3], (k,), jnp.bfloat16) + s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16) + if use_bias: + b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16) + b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16) + else: + b1 = jax.random.normal(subkeys[3], (0,), jnp.bfloat16) + b2 = jax.random.normal(subkeys[4], (0,), jnp.bfloat16) init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2) init_fp8_metas_amax = jnp.zeros( @@ -192,14 +207,16 @@ def test_grad_ln_geglu_fp8_mlp(self, m, n, k): init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) - def primitive_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, + def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv): # x is input tensor, matrix 2d # y, z are weights, matrix 2d - # out = (x * y) * z + # out = ((x * y) + w) * z + v fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv) - return jnp.mean(layernorm_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm")) + return jnp.mean( + fused_layernorm_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm", + activation_type = activation_type, use_bias = use_bias)) def _convert_to_activation_function(fn_or_string): """Convert a string to an activation function.""" @@ -211,115 +228,7 @@ def _convert_to_activation_function(fn_or_string): return fn_or_string raise ValueError(f"don't know how to convert {fn_or_string} to an activation function") - def ln_geglu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray, - kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray, - scale: jnp.ndarray, scale_inv: jnp.ndarray) -> jnp.ndarray: - - x = jnp.asarray(x, jnp.float32) - mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) - y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16) - ln_out = y * ln_scale - ln_out = jnp.asarray(ln_out, jnp.bfloat16) - - fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM], - amax[:FP8Helper.NUM_META_PER_GEMM], - scale[:FP8Helper.NUM_META_PER_GEMM], - scale_inv[:FP8Helper.NUM_META_PER_GEMM]) - linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,))) - - x = jnp.split(linear_1_out, len(activations), axis=-2) - acts = [] - for idx, act_fn in enumerate(activations): - x_i = _convert_to_activation_function(act_fn)(x[idx]) - acts.append(x_i) - x = functools.reduce(operator.mul, acts) - x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16) - - fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:], - amax[FP8Helper.NUM_META_PER_GEMM:], - scale[FP8Helper.NUM_META_PER_GEMM:], - scale_inv[FP8Helper.NUM_META_PER_GEMM:]) - output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,))) - return output - - def ref_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv): - return jnp.mean( - ln_geglu_fp8_mlp_ref(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, - fp8_metas_scale_inv)) - - value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7))) - value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7))) - - ref_fp8_max = init_fp8_max - ref_fp8_metas_amax = init_fp8_metas_amax - ref_fp8_metas_scale = init_fp8_metas_scale - ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv - - 
pri_fp8_max = init_fp8_max - pri_fp8_metas_amax = init_fp8_metas_amax - pri_fp8_metas_scale = init_fp8_metas_scale - pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv - - for _ in range(3): - ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_fp8_max, - ref_fp8_metas_amax, ref_fp8_metas_scale, - ref_fp8_metas_scale_inv) = value_n_grad_ref_func( - a, s, k1, k2, ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale, - ref_fp8_metas_scale_inv) - - for _ in range(3): - primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad, - primitive_k2_grad, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale, - pri_fp8_metas_scale_inv) = value_n_grad_primitive_func( - a, s, k1, k2, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale, - pri_fp8_metas_scale_inv) - - assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) - assert_allclose(jnp.asarray(primitive_a_grad, np.float32), - jnp.asarray(ref_a_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) - assert_allclose(jnp.asarray(primitive_k1_grad, np.float32), - jnp.asarray(ref_k1_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) - assert_allclose(jnp.asarray(primitive_k2_grad, np.float32), - jnp.asarray(ref_k2_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) - assert_allclose(jnp.asarray(primitive_s_grad, np.float32), - jnp.asarray(ref_s_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) - - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024), - (16384, 1024, 1024)]) - def test_grad_ln_gelu_fp8_mlp(self, m, n, k): - key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, 6) - activations = ('gelu',) - - a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16) - k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16) - k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) - b1 = jax.random.normal(subkeys[3], (len(activations), n), jnp.bfloat16) - b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16) - s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16) - - init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2) - init_fp8_metas_amax = jnp.zeros( - (FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32) - init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) - init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) - - def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, - fp8_metas_scale_inv): - # x is input tensor, matrix 2d - # y, z are weights, matrix 2d - # out = ((x * y) + w) * z + v - fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale, - fp8_metas_scale_inv) - return jnp.mean( - layernorm_gelu_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm")) - - def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray, + def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray) -> jnp.ndarray: @@ -336,10 +245,20 @@ def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.nda scale_inv[:FP8Helper.NUM_META_PER_GEMM]) linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,))) - bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape - linear_1_out += jnp.reshape(bias_1, 
bias_1_shape) + if use_bias: + bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape + linear_1_out += jnp.reshape(bias_1, bias_1_shape) + + if 'linear' in activation_type: + x = jnp.split(linear_1_out, len(activation_type), axis=-2) + acts = [] + for idx, act_fn in enumerate(activation_type): + x_i = _convert_to_activation_function(act_fn)(x[idx]) + acts.append(x_i) + x = functools.reduce(operator.mul, acts) + else: + x = activation_dict[activation_type](linear_1_out) - x = jax.nn.gelu(linear_1_out) x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16) fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:], @@ -348,15 +267,16 @@ def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.nda scale_inv[FP8Helper.NUM_META_PER_GEMM:]) output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,))) - bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape - output += jnp.reshape(bias_2, bias_2_shape) + if use_bias: + bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape + output += jnp.reshape(bias_2, bias_2_shape) return output def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv): return jnp.mean( - ln_gelu_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, + layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv)) value_n_grad_primitive_func = jit( @@ -373,12 +293,13 @@ def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, pri_fp8_metas_scale = init_fp8_metas_scale pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv + # Convert str to index as str is not a valid type for JAX JIT for _ in range(3): ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad, ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale, ref_fp8_metas_scale_inv) = value_n_grad_ref_func( a, s, k1, k2, b1, b2, ref_fp8_max, ref_fp8_metas_amax, - ref_fp8_metas_scale, ref_fp8_metas_scale_inv) + ref_fp8_metas_scale, ref_fp8_metas_scale_inv) for _ in range(3): primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad, @@ -401,12 +322,14 @@ def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, assert_allclose(jnp.asarray(primitive_s_grad, np.float32), jnp.asarray(ref_s_grad, np.float32), dtype=FP8Helper.BWD_DTYPE) - assert_allclose(jnp.asarray(primitive_b1_grad, np.float32), - jnp.asarray(ref_b1_grad, np.float32), - dtype=jnp.bfloat16) - assert_allclose(jnp.asarray(primitive_b2_grad, np.float32), - jnp.asarray(ref_b2_grad, np.float32), - dtype=jnp.bfloat16) + if use_bias: + assert_allclose(jnp.asarray(primitive_b1_grad, np.float32), + jnp.asarray(ref_b1_grad, np.float32), + dtype=jnp.bfloat16) + assert_allclose(jnp.asarray(primitive_b2_grad, np.float32), + jnp.asarray(ref_b2_grad, np.float32), + dtype=jnp.bfloat16) + @pytest.fixture(name="random_inputs") diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index 434f2651d3..8e455dddb5 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -529,10 +529,11 @@ void cast_transpose_dbias(const Tensor &input, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { - CheckInputTensor(input, "cast_transpose_dbias_input"); - CheckOutputTensor(*cast_output, "cast_output"); - CheckOutputTensor(*transposed_output, "transposed_output"); - 
CheckOutputTensor(*dbias, "dbias"); + // TODO + // CheckInputTensor(input, "cast_transpose_dbias_input"); + // CheckOutputTensor(*cast_output, "cast_output"); + // CheckOutputTensor(*transposed_output, "transposed_output"); + // CheckOutputTensor(*dbias, "dbias"); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index 3356aafef5..adcd5770e2 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -4334,6 +4334,231 @@ def dgelu_dbias_cast_transpose( transpose_axis_boundary=transpose_axis_boundary) +class DBiasCastTransposePrimitive(BasePrimitive): + """ + DBias Cast Transpose Primitive + """ + name = "te_dbias_cast_transpose" + multiple_results = True + # out_dtype, static_axis_boundary, transpose_axis_boundary + impl_static_args = (4, 5, 6) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(dz_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, + static_axis_boundary, transpose_axis_boundary): + """ + te_dbias_cast_transpose_p abstract + """ + dtype = dtypes.canonicalize_dtype(dz_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + gi_hidden_size = dz_aval.shape[-1] + t_shape = _multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary) + out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype) + t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) + + if dz_aval.shape[-2] == 2: + gi_hidden_size *= 2 + dbias_shape = (*dz_aval.shape[:static_axis_boundary + 1], gi_hidden_size) + dbias = dz_aval.update(shape=dbias_shape, dtype=dtype) + + updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + wkspace_info, = transformer_engine_jax.get_dbias_ct_workspace_sizes( + dz_aval.size // gi_hidden_size, + gi_hidden_size, + jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype) + ) + wkspace_aval = dz_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + + return out, t_out, dbias, updated_amax_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + te_dbias_cast_transpose_p outer abstract + """ + + out, t_out, dbias, updated_amax_aval, _ = \ + DBiasCastTransposePrimitive.abstract(*args, **kwargs) + return out, t_out, dbias, updated_amax_aval + + @staticmethod + def lowering(ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, + transpose_axis_boundary): + """ + te_dbias_cast_transpose_p lowering rules + """ + dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + ir_dz_type = ir.RankedTensorType(dz.type) + ir_dz_shape = ir_dz_type.shape + ir_hidden_szie = ir_dz_shape[-1] + if dz_aval.shape[-2] == 2: + batch_szie = reduce(operator.mul, ir_dz_shape[:-2]) + ir_hidden_szie *= 2 + else: + batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) + contracted_dz_shape = (batch_szie, ir_hidden_szie) + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + 
ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + transposed_dz_shape = _multidim_transpose(ir_dz_shape, static_axis_boundary, + transpose_axis_boundary) + dbias_shape = (*ir_dz_shape[:static_axis_boundary + 1], ir_hidden_szie) + + wkspace_aval = ctx.avals_out[-1] + + out_types = [ + ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype), + ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype), + ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), + ] + operands = [dz, amax, scale, scale_inv] + operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + opaque = transformer_engine_jax.pack_common_wk_descriptor( + contracted_dz_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype)) + + out = custom_caller(DBiasCastTransposePrimitive.name, + args, + opaque, + False, + operand_output_aliases={1: 3}) + + return out + + @staticmethod + def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary, + transpose_axis_boundary): + """ + to describe implementation + """ + assert DBiasCastTransposePrimitive.inner_primitive is not None + out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind( + dz, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) + return out, t_out, dbias, updated_amax + + @staticmethod + def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, + transpose_axis_boundary): + """ + to describe batch rules for vmap + """ + del static_axis_boundary + _check_valid_batch_dims(batch_dims) + assert DBiasCastTransposePrimitive.outer_primitive is not None + dz, amax, scale, scale_inv = batched_args + dz_bdim, _, amax_bdim, _, _ = batch_dims + + # Minus batch dim. 
+ transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1) + transpose_axis_boundary += 1 # Plus batch dim + + out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim + return DBiasCastTransposePrimitive.outer_primitive.bind( + dz, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=dz_bdim, + transpose_axis_boundary=transpose_axis_boundary), out_bdims + + @staticmethod + def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, + arg_infos, result_infos): + del out_dtype, result_infos + x_spec = get_padded_spec(arg_infos[1]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + dbias_shaprding = NamedSharding( + mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) + + @staticmethod + def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, + result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[1]) + casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + + dbias_shaprding = NamedSharding( + mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) + + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_shaprding, + amax_sharding) + + def sharded_impl(dz, amax, scale, scale_inv): + local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl( + dz, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + return local_out, local_t_out, global_dbias, global_updated_amax + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(DBiasCastTransposePrimitive) + + +def dbias_cast_transpose( + dz: jnp.ndarray, + amax: jnp.ndarray, + scale: jnp.ndarray, + scale_inv: jnp.ndarray, + out_dtype: TEDType, + static_axis_boundary: int, + transpose_axis_boundary: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + cast transpose dbias partial fusion wrapper + Return FP8(inputs), dbias + """ + if static_axis_boundary < 0: + static_axis_boundary = -1 # means no static axes + + return DBiasCastTransposePrimitive.outer_primitive.bind( + dz, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) + + class GatedGeluFp8Primitive(BasePrimitive): """ Gated Gelu FP8 Primitive diff --git a/transformer_engine/jax/csrc/extensions.cpp b/transformer_engine/jax/csrc/extensions.cpp index 5e4ab4f205..8aa6b492c8 100644 --- a/transformer_engine/jax/csrc/extensions.cpp +++ b/transformer_engine/jax/csrc/extensions.cpp @@ -29,6 +29,7 @@ pybind11::dict Registrations() { dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8); 
dict["te_dgelu"] = EncapsulateFunction(DGelu); dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose); + dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose); dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu); dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8); dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu); @@ -66,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes); + m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes); m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes); m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); diff --git a/transformer_engine/jax/csrc/modules.cpp b/transformer_engine/jax/csrc/modules.cpp index 4ac6fa58b1..48b02bcaeb 100644 --- a/transformer_engine/jax/csrc/modules.cpp +++ b/transformer_engine/jax/csrc/modules.cpp @@ -301,6 +301,69 @@ void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op dbias_tensor.data(), workspace.data(), stream); } +// HERE +pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, + DType in_dtype, DType out_dtype) { + auto input_shape = std::vector{batch_size, hidden_size}; + auto output_shape = std::vector{batch_size, hidden_size}; + auto output_trans_shape = std::vector{hidden_size, batch_size}; + auto dbias_shape = std::vector{hidden_size}; + + auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); + auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); + auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); + + TensorWrapper dummy_workspace; + + nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), + output_trans_tensor.data(), dbias_tensor.data(), + dummy_workspace.data(), nullptr); + + auto work_shape = MakeShapeVector(dummy_workspace.shape()); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); +} + +void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, + size_t opaque_len) { + auto *input = buffers[0]; + float *amax = reinterpret_cast(buffers[1]); + float *scale = reinterpret_cast(buffers[2]); + float *scale_inv = reinterpret_cast(buffers[3]); + auto *output = buffers[4]; + auto *output_trans = buffers[5]; + auto *dbias = buffers[6]; + float *amax_out = reinterpret_cast(buffers[7]); + void *workspace_ptr = buffers[8]; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + assert(amax == amax_out); + if (!use_fp8(desc.out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + auto input_shape = std::vector{m, n}; + auto output_shape = std::vector{m, n}; + auto output_trans_shape = std::vector{n, m}; + auto dbias_shape = std::vector{n}; + + auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); + auto output_tensor = + TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto output_trans_tensor = + TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto 
dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); + + auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); + + nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), + output_trans_tensor.data(), dbias_tensor.data(), + workspace.data(), stream); +} + void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, cudaStream_t stream, float *scale_inverse, float *amax, void *output) { auto input_shape = std::vector{m, n * 2}; diff --git a/transformer_engine/jax/csrc/modules.h b/transformer_engine/jax/csrc/modules.h index 04f0039b02..4285c8228e 100644 --- a/transformer_engine/jax/csrc/modules.h +++ b/transformer_engine/jax/csrc/modules.h @@ -152,6 +152,12 @@ pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, + DType in_dtype, DType out_dtype); + +void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, + size_t opaque_len); + void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8ca8edcb0b..36008cf854 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -22,8 +22,7 @@ from ..fp8 import FP8Helper, FP8MetaPackage from ..layernorm import canonicalize_layernorm_type from ..layernorm import layernorm, layernorm_fp8_dot -from ..mlp import layernorm_geglu_fp8_mlp, geglu -from ..mlp import layernorm_gelu_fp8_mlp, gelu +from ..mlp import fused_layernorm_fp8_mlp, activation_lu from ..softmax import is_softmax_kernel_available from ..softmax import softmax, SoftmaxType from ..sharding import with_sharding_constraint_by_logical_axes @@ -944,35 +943,22 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: fuse_layernorm = FP8Helper.is_fp8_enabled( ) and not self.return_layernorm_output and self.enable_layernorm - def is_geglu(acts): - geglu_act_pool = [('gelu', 'linear'), ('linear', 'gelu')] - - normalize_acts = [] - for act in acts: - if not isinstance(act, str): - return False - normalize_acts.append(act.lower()) - return tuple(normalize_acts) in geglu_act_pool - - def is_gelu(acts): - geglu_act_pool = [('gelu',)] - - normalize_acts = [] - for act in acts: - if not isinstance(act, str): - return False - normalize_acts.append(act.lower()) - return tuple(normalize_acts) in geglu_act_pool - - use_fused_ln_geglu_mlp = fuse_layernorm \ - and (not self.use_bias) and is_geglu(self.activations) \ - and (self.intermediate_dropout_rate < 1e-3) \ - and not self.enable_low_rank_adaptation - - use_fused_ln_gelu_mlp = fuse_layernorm \ - and self.use_bias and is_gelu(self.activations) \ - and (self.intermediate_dropout_rate < 1e-3) \ - and not self.enable_low_rank_adaptation + # Make sure each tuple is sorted in alphabet order + gated_act_pool = [('gelu', 'linear')] + #('linear', 'silu')] coming + act_pool = [('gelu',)] + #('silu',)] coming + normalize_acts = [] + for act in self.activations: + if not isinstance(act, str): + return False + normalize_acts.append(act.lower()) + normalize_acts = tuple(sorted(normalize_acts)) + is_gated = normalize_acts in gated_act_pool + 
is_act_implemented = normalize_acts in (gated_act_pool + act_pool) + + use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and\ + self.intermediate_dropout_rate < 1e-3 # LayerNorm if self.enable_layernorm: @@ -1045,38 +1031,26 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name = 'ffn1' ffn2_ckpt_name = 'ffn2' - if use_fused_ln_geglu_mlp: - assert self.axis == -1 # Only support axis = =-1 at this moment - - out = layernorm_geglu_fp8_mlp(y, - scale, - ln_bias, [kernel_1, kernel_2], - fp8_meta_package, - self.layernorm_type, - zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon, - layernorm_input_axes=self.layernorm_input_axes, - dot_1_input_axes=self.dot_1_input_axes, - dot_2_input_axes=self.dot_2_input_axes, - ffn1_ckpt_name=ffn1_ckpt_name, - ffn2_ckpt_name=ffn2_ckpt_name) - elif use_fused_ln_gelu_mlp: + if use_fused_layernorm_mlp: assert self.axis == -1 # Only support axis = =-1 at this moment + bias_1_shape = intermediate_dim if self.use_bias else 0 bias_1 = nn_partitioning.param_with_axes('wi_bias', self.bias_init, - intermediate_dim, + bias_1_shape, jnp.float32, axes=self.bias_axes_1) bias_1 = bias_1.astype(self.dtype) + bias_2_shape = (hidden_size,) if self.use_bias else (0,) bias_2 = nn_partitioning.param_with_axes('wo_bias', - self.bias_init, (hidden_size,), + self.bias_init, + bias_2_shape, jnp.float32, axes=self.bias_axes_2) bias_2 = bias_2.astype(self.dtype) - out = layernorm_gelu_fp8_mlp(y, + out = fused_layernorm_fp8_mlp(y, scale, ln_bias, [kernel_1, kernel_2], [bias_1, bias_2], fp8_meta_package, @@ -1087,9 +1061,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): dot_1_input_axes=self.dot_1_input_axes, dot_2_input_axes=self.dot_2_input_axes, ffn1_ckpt_name=ffn1_ckpt_name, - ffn2_ckpt_name=ffn2_ckpt_name) + ffn2_ckpt_name=ffn2_ckpt_name, + activation_type = normalize_acts, + use_bias = self.use_bias) else: # not use_fused_ln_geglu_mlp - # DenseGeneral 1 gemm1_fp8_meta_package = None if fp8_meta_package is None \ else fp8_meta_package.get_package_by_gemm_idx(0) @@ -1142,31 +1117,29 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): x += _apply_low_rank_adaptation(y, axis, intermediate_dim, wi_lora_a_kernel, wi_lora_b_kernel, self.low_rank_adaptation_alpha) - bias = None + bias_1 = None if self.use_bias: - bias = nn_partitioning.param_with_axes('wi_bias', + bias_1 = nn_partitioning.param_with_axes('wi_bias', self.bias_init, intermediate_dim, jnp.float32, axes=self.bias_axes_1) - bias = bias.astype(self.dtype) - bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape - x += jnp.reshape(bias, bias_shape) + bias_1 = bias_1.astype(self.dtype) + bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape + x += jnp.reshape(bias_1, bias_1_shape) x = checkpoint_name(x, ffn1_ckpt_name) activations = [] - if is_geglu(self.activations): - z = geglu(x) - elif is_gelu(self.activations): - z = gelu(x) - z = jnp.reshape(z, (*z.shape[:-2], -1)) + if is_act_implemented: + z = activation_lu(x, normalize_acts) else: x = jnp.split(x, num_activations, axis=-2) for idx, act_fn in enumerate(self.activations): x_i = _convert_to_activation_function(act_fn)(x[idx]) activations.append(x_i) z = functools.reduce(operator.mul, activations) + if not is_gated: z = jnp.reshape(z, (*z.shape[:-2], -1)) z = nn.Dropout(rate=self.intermediate_dropout_rate, @@ -1207,14 +1180,14 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple, wo_lora_a_kernel, wo_lora_b_kernel, 
self.low_rank_adaptation_alpha) - bias = None + bias_2 = None if self.use_bias: - bias = nn_partitioning.param_with_axes('wo_bias', + bias_2 = nn_partitioning.param_with_axes('wo_bias', self.bias_init, (hidden_size,), jnp.float32, axes=self.bias_axes_2) - bias = bias.astype(self.dtype) - out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,)) + bias_2 = bias_2.astype(self.dtype) + out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,)) out = checkpoint_name(out, ffn2_ckpt_name) diff --git a/transformer_engine/jax/mlp.py b/transformer_engine/jax/mlp.py index 3b531a6150..30f6d8456b 100644 --- a/transformer_engine/jax/mlp.py +++ b/transformer_engine/jax/mlp.py @@ -3,15 +3,15 @@ # See LICENSE for license information. """JAX MLP modules""" -from typing import List, Tuple +from typing import List, Tuple, Sequence, Union, Callable from functools import partial import jax import jax.numpy as jnp from jax.ad_checkpoint import checkpoint_name -from .cpp_extensions import cast_fp8, transpose, cast_transpose -from .cpp_extensions import gelu as te_gelu +from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose +from .cpp_extensions import gelu from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose from .cpp_extensions import gated_gelu, gated_gelu_fp8 from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose @@ -23,369 +23,56 @@ from .sharding import with_sharding_constraint_by_logical_axes -def gelu(x: jnp.ndarray): - """ - Gelu - """ - output = _gelu(x) - return output - - -@partial(jax.custom_vjp) -def _gelu(x: jnp.ndarray): - - geglu_output, _ = _gelu_fwd_rule(x) - - return geglu_output - - -def _gelu_fwd_rule(x): - geglu_output = te_gelu(x) - return geglu_output, (x,) - - -def _gelu_bwd_rule(ctx, g): - x, = ctx - assert x.dtype == g.dtype - - dx = dgelu(g, x) - dx = jnp.reshape(dx, x.shape) - return (dx,) +activation_dict = { + ('gelu',): {'fwd': gelu, + "bwd": dgelu}, + ('gelu', 'linear'): {'fwd': gated_gelu, + 'bwd': dgated_gelu} +} +activation_fp8_dict = { + ('gelu',): {'fwd': gelu_fp8, + 'bwd': dgelu_dbias_cast_transpose}, + ('gelu', 'linear'): {'fwd': gated_gelu_fp8, + 'bwd': dgated_gelu_cast_transpose} +} -_gelu.defvjp(_gelu_fwd_rule, _gelu_bwd_rule) - -def geglu(x: jnp.ndarray): +def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]): """ - Gated gelu + Activation Unit """ - assert x.shape[-2] == 2 # Linear + GeLU - - output = _geglu(x) - + if len(activation_type) > 1: + assert x.shape[-2] == 2 # Linear + GeLU + output = _activation_lu(x, activation_type) return output -@partial(jax.custom_vjp) -def _geglu(x: jnp.ndarray): +@partial(jax.custom_vjp, nondiff_argnums=(1,)) +def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]): - geglu_output, _ = _geglu_fwd_rule(x) + _output, _ = _activation_lu_fwd_rule(x, activation_type) - return geglu_output + return _output -def _geglu_fwd_rule(x): - geglu_output = gated_gelu(x) - return geglu_output, (x,) +def _activation_lu_fwd_rule(x, activation_type): + fwd_output = activation_dict[activation_type]["fwd"](x) + return fwd_output, (x,) -def _geglu_bwd_rule(ctx, g): +def _activation_lu_bwd_rule(activation_type, ctx, g): x, = ctx assert x.dtype == g.dtype - dx = dgated_gelu(g, x) + dx = activation_dict[activation_type]["bwd"](g, x) dx = jnp.reshape(dx, x.shape) return (dx,) +_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule) -_geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule) - - -def layernorm_geglu_fp8_mlp(x: jnp.ndarray, 
- gamma: jnp.ndarray, - beta: jnp.ndarray, - kernels: List[jnp.ndarray], - fp8_gemm_pkg: FP8MetaPackage, - layernorm_type: str, - zero_centered_gamma: bool = False, - epsilon: float = 1e-6, - layernorm_input_axes: Tuple[str, ...] = None, - dot_1_input_axes: Tuple[str, ...] = None, - dot_2_input_axes: Tuple[str, ...] = None, - ffn1_ckpt_name: str = 'ffn1', - ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray: - """ - Layernorm + GEMM1 + GeGLU + GEMM2 - """ - - assert len(kernels) == 2 - assert fp8_gemm_pkg.num_of_gemm == len(kernels) - - kernel_1 = kernels[0] - kernel_2 = kernels[1] - fp8_max = fp8_gemm_pkg.fp8_max - amax = fp8_gemm_pkg.amax - scale = fp8_gemm_pkg.scale - scale_inv = fp8_gemm_pkg.scale_inv - - fwd_dtype = FP8Helper.FWD_DTYPE - bwd_dtype = FP8Helper.BWD_DTYPE - - layernorm_type = canonicalize_layernorm_type(layernorm_type) - if layernorm_type == 'rmsnorm': - assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'" - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" - - output = _layernorm_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale, - scale_inv, fwd_dtype, bwd_dtype, layernorm_type, - zero_centered_gamma, epsilon, layernorm_input_axes, - dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, - ffn2_ckpt_name) - return output - - -@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) -def _layernorm_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, - kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray, - amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, - fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str, - zero_centered_gamma: bool, epsilon: float, - layernorm_input_axes: Tuple[str, ...], - dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], - ffn1_ckpt_name: str, ffn2_ckpt_name: str): - output, _ = _layernorm_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, - scale, scale_inv, fwd_dtype, bwd_dtype, - layernorm_type, zero_centered_gamma, epsilon, - layernorm_input_axes, dot_1_input_axes, - dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name) - return output - - -def _layernorm_geglu_fp8_mlp_fwd_rule( - x, - gamma, - beta, - kernel_1, - kernel_2, - fp8_max, - amax, - scale, - scale_inv, - fwd_dtype, - bwd_dtype, # pylint: disable=unused-argument - layernorm_type, - zero_centered_gamma, - epsilon, - layernorm_input_axes, - dot_1_input_axes, - dot_2_input_axes, - ffn1_ckpt_name, - ffn2_ckpt_name): - - # x should be in shape of (batch..., hidden) - # Kernel_1 should be in shape of (Hidden_in, 2, Hidden_out) - # Kernel_2 should be in shape of (Hidden_in, Hidden_out) - assert len(kernel_1.shape) == 3 - assert kernel_1.shape[-2] == 2 - assert len(kernel_2.shape) == 2 - - x_contracting_dims = (len(x.shape) - 1,) - xt_batch_dims = tuple(range(1, x.ndim)) - assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0] - assert kernel_1.shape[-1] == kernel_2.shape[0] - - amax = FP8Helper.update_amax_history(amax) - - gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) - - x_amax = amax[gemm1_x_idx, 0:1] - x_scale = scale[gemm1_x_idx] - x_scale_inv = scale_inv[gemm1_x_idx] - - x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) - - if layernorm_type == 'layernorm': - ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8( - x, - gamma, - beta, - x_amax, - x_scale, - x_scale_inv, - out_dtype=fwd_dtype, - 
zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) - else: - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" - ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x, - gamma, - x_amax, - x_scale, - x_scale_inv, - out_dtype=fwd_dtype, - epsilon=epsilon) - mu = None - - assert x.shape == ln_out.shape - - kernel_1_amax = amax[gemm1_kernel_idx, 0:1] - kernel_1_scale = scale[gemm1_kernel_idx] - kernel_1_scale_inv = scale_inv[gemm1_kernel_idx] - - # Note (Ming Huang): Use cast only to allow XLA handle tranpose for avoiding - # unnecessary copy to break FP8 GEMM pattern matching. - casted_kernel_1, updated_kernel_1_amax = \ - cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype) - - ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes) - - # (batch..., hidden_in) x (hidden_in, 2, hidden_out) - dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype, - (x_contracting_dims, (0,)), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) - dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) - - gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1) - - geglu_out_amax = amax[gemm2_x_idx, 0:1] - geglu_out_scale = scale[gemm2_x_idx] - geglu_out_scale_inv = scale_inv[gemm2_x_idx] - - # (batch..., hidden_in) -> (batch..., hidden) - casted_geglu_out, updated_geglu_amax = gated_gelu_fp8(dot_1_output, geglu_out_amax, - geglu_out_scale, geglu_out_scale_inv, - fwd_dtype) - - casted_geglu_out = with_sharding_constraint_by_logical_axes(casted_geglu_out, dot_2_input_axes) - - kernel_2_scale = scale[gemm2_kernel_idx] - kernel_2_scale_inv = scale_inv[gemm2_kernel_idx] - # Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding - # unnecessary copy to break FP8 GEMM pattern matching. 
- casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale) - - # (batch..., hidden_in) x (hidden_out, hidden_in) - dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kernel_2, geglu_out_scale_inv, - kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) - dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) - - ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kernel_1, - casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax, - updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims) - - return dot_2_output, ctx - - -def _layernorm_geglu_fp8_mlp_bwd_rule( - fwd_dtype, # pylint: disable=unused-argument - bwd_dtype, - layernorm_type, - zero_centered_gamma, - epsilon, - layernorm_input_axes, - dot_1_input_axes, - dot_2_input_axes, - ffn1_ckpt_name, # pylint: disable=unused-argument - ffn2_ckpt_name, # pylint: disable=unused-argument - ctx, - grad): - x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \ - casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \ - updated_geglu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ - x_contracting_dims, xt_batch_dims = ctx - - gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1) - - grad_amax = amax[gemm2_grad_idx, 0:1] - grad_scale = scale[gemm2_grad_idx] - grad_scale_inv = scale_inv[gemm2_grad_idx] - - # Since the sharding of outputs should be the same as dot_1's input - grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) - - casted_grad, casted_grad_t, updated_grad_amax = \ - cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype, - static_axis_boundary=-1, transpose_axis_boundary=-1) - - casted_geglu_out_t = transpose(casted_geglu_out, - static_axis_boundary=-1, - transpose_axis_boundary=-1) - - # (hidden, batch...,) x (hidden, batch...) - gemm2_x_scale_inv = scale_inv[gemm2_x_idx] - wgrad_2 = fp8_dot_impl(casted_geglu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv, - grad.dtype, (xt_batch_dims, xt_batch_dims), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) - - # (batch..., hidden_out) x (hidden_in, hidden_out) - kernel_2_scale_inv = scale_inv[gemm2_kernel_idx] - dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv, - grad.dtype, (x_contracting_dims, (1,)), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) - - dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) - - gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0) - - dgeglu_amax = amax[gemm1_grad_idx, 0:1] - dgeglu_scale = scale[gemm1_grad_idx] - dgeglu_scale_inv = scale_inv[gemm1_grad_idx] - - casted_dgeglu, casted_dgeglu_t, updated_dgeglu_amax = dgated_gelu_cast_transpose( - dgrad_2, - dot_1_output, - dgeglu_amax, - dgeglu_scale, - dgeglu_scale_inv, - bwd_dtype, - static_axis_boundary=-1) - - ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1) - - # (hidden, batch...) x (2, hidden, batch...) 
- xt_batch_dims_plus_act_dim = tuple(i + 1 for i in xt_batch_dims) - gemm1_x_scale_inv = scale_inv[gemm1_x_idx] - wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgeglu_t, gemm1_x_scale_inv, dgeglu_scale_inv, - grad.dtype, (xt_batch_dims, xt_batch_dims_plus_act_dim), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) - - # (batch..., 2, hidden_out) x (hidden_in, 2, hidden_out) - x_contracting_dims_plus_act_dim = (min(x_contracting_dims),) + tuple( - i + 1 for i in x_contracting_dims) - kernel_1_scale_inv = scale_inv[gemm1_kernel_idx] - dgrad_1 = fp8_dot_impl(casted_dgeglu, casted_kernel_1, dgeglu_scale_inv, kernel_1_scale_inv, - grad.dtype, (x_contracting_dims_plus_act_dim, (1, 2)), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) - - dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes) - - if layernorm_type == 'layernorm': - dx, dgamma, dbeta = layernorm_bwd(dgrad_1, - x, - mu, - rsigma, - gamma, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) - else: - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" - dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon) - dbeta = None - - amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0]) - amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0]) - amax = amax.at[gemm1_grad_idx, 0].set(updated_dgeglu_amax[0]) - amax = amax.at[gemm2_x_idx, 0].set(updated_geglu_amax[0]) - amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax) - amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0]) - - scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) - - return dx, dgamma, dbeta, wgrad_1, wgrad_2, \ - fp8_max, amax, scale, scale_inv - - -_layernorm_geglu_fp8_mlp.defvjp(_layernorm_geglu_fp8_mlp_fwd_rule, - _layernorm_geglu_fp8_mlp_bwd_rule) - - -def layernorm_gelu_fp8_mlp(x: jnp.ndarray, +def fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, kernels: List[jnp.ndarray], @@ -398,9 +85,11 @@ def layernorm_gelu_fp8_mlp(x: jnp.ndarray, dot_1_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] 
= None, ffn1_ckpt_name: str = 'ffn1', - ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray: + ffn2_ckpt_name: str = 'ffn2', + activation_type: Sequence[Union[str, Callable]] = ('gelu',), + use_bias: bool = True) -> jnp.ndarray: """ - Layernorm + GEMM1 + bias + GeLU + GEMM2 + bias + Layernorm + GEMM1 + bias + activation + GEMM2 + bias """ assert len(kernels) == 2 @@ -424,32 +113,36 @@ def layernorm_gelu_fp8_mlp(x: jnp.ndarray, assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ "if layernorm_type is 'rmsnorm'" - output = _layernorm_gelu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, + output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon, layernorm_input_axes, dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, - ffn2_ckpt_name) + ffn2_ckpt_name, activation_type, use_bias) return output -@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) -def _layernorm_gelu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, +@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22)) +def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool, epsilon: float, layernorm_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], - ffn1_ckpt_name: str, ffn2_ckpt_name: str): - output, _ = _layernorm_gelu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, - fp8_max, amax, scale, scale_inv, fwd_dtype, - bwd_dtype, layernorm_type, zero_centered_gamma, - epsilon, layernorm_input_axes, dot_1_input_axes, - dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name) + ffn1_ckpt_name: str, ffn2_ckpt_name: str, + activation_type: Sequence[Union[str, Callable]], + use_bias: bool): + output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1, + bias_2, fp8_max, amax, scale, scale_inv, + fwd_dtype, bwd_dtype, layernorm_type, + zero_centered_gamma, epsilon, + layernorm_input_axes, dot_1_input_axes, + dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, + activation_type, use_bias) return output -def _layernorm_gelu_fp8_mlp_fwd_rule( +def _fused_layernorm_fp8_mlp_fwd_rule( x, gamma, beta, @@ -470,13 +163,16 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, - ffn2_ckpt_name): + ffn2_ckpt_name, + activation_type, + use_bias): + is_gated = len(activation_type) > 1 # x should be in shape of (batch..., hidden) # Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out) # Kernel_2 should be in shape of (Hidden_in, Hidden_out) assert len(kernel_1.shape) == 3 - assert kernel_1.shape[-2] == 1 + assert kernel_1.shape[-2] == len(activation_type) assert len(kernel_2.shape) == 2 x_contracting_dims = (len(x.shape) - 1,) @@ -487,7 +183,8 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( # Squeeze act axis # (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out) - kernel_1 = jnp.squeeze(kernel_1, axis=-2) + if not is_gated: + kernel_1 = jnp.squeeze(kernel_1, axis=-2) amax = FP8Helper.update_amax_history(amax) @@ -539,22 +236,26 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( dot_1_output = fp8_dot_impl(ln_out, 
casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype, (x_contracting_dims, (0,)), get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) - - bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape - dot_1_output += jnp.reshape(bias_1, bias_1_shape) + if use_bias: + bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape + dot_1_output += jnp.reshape(bias_1, bias_1_shape) dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1) - gelu_out_amax = amax[gemm2_x_idx, 0:1] - gelu_out_scale = scale[gemm2_x_idx] - gelu_out_scale_inv = scale_inv[gemm2_x_idx] + activation_lu_out_amax = amax[gemm2_x_idx, 0:1] + activation_lu_out_scale = scale[gemm2_x_idx] + activation_lu_out_scale_inv = scale_inv[gemm2_x_idx] + + activation_lu_fp8 = activation_fp8_dict[activation_type]["fwd"] # (batch..., hidden_in) -> (batch..., hidden) - casted_gelu_out, updated_gelu_amax = gelu_fp8(dot_1_output, gelu_out_amax, gelu_out_scale, - gelu_out_scale_inv, fwd_dtype) + casted_activation_lu_out, updated_activation_lu_amax = activation_lu_fp8(dot_1_output, + activation_lu_out_amax, activation_lu_out_scale, + activation_lu_out_scale_inv, fwd_dtype) - casted_gelu_out = with_sharding_constraint_by_logical_axes(casted_gelu_out, dot_2_input_axes) + casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out, + dot_2_input_axes) kernel_2_scale = scale[gemm2_kernel_idx] kernel_2_scale_inv = scale_inv[gemm2_kernel_idx] @@ -563,23 +264,26 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale) # (batch..., hidden_in) x (hidden_out, hidden_in) - dot_2_output = fp8_dot_impl(casted_gelu_out, casted_kernel_2, gelu_out_scale_inv, + dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2, + activation_lu_out_scale_inv, kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)), get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) - bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape - dot_2_output += jnp.reshape(bias_2, bias_2_shape) + if use_bias: + bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape + dot_2_output += jnp.reshape(bias_2, bias_2_shape) + dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) - ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, casted_kernel_1, - casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_gelu_amax, - updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims, - bias_1.shape, bias_2.shape) + ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1, + casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, + updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, + x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape) return dot_2_output, ctx -def _layernorm_gelu_fp8_mlp_bwd_rule( +def _fused_layernorm_fp8_mlp_bwd_rule( fwd_dtype, # pylint: disable=unused-argument bwd_dtype, layernorm_type, @@ -590,13 +294,17 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( dot_2_input_axes, ffn1_ckpt_name, # pylint: disable=unused-argument ffn2_ckpt_name, # pylint: disable=unused-argument + activation_type, + use_bias, ctx, grad): - x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, \ + x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \ casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, 
scale_inv, updated_x_amax, \ - updated_gelu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ + updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape= ctx + is_gated = len(activation_type) > 1 + gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1) grad_amax = amax[gemm2_grad_idx, 0:1] @@ -606,21 +314,29 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( # Since the sharding of outputs should be the same as dot_1's input grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) - casted_grad, casted_grad_t, updated_grad_amax = \ - cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype, - static_axis_boundary=-1, transpose_axis_boundary=-1) - - casted_gelu_out_t = transpose(casted_gelu_out, - static_axis_boundary=-1, - transpose_axis_boundary=-1) + if use_bias: + casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \ + dbias_cast_transpose(grad, grad_amax, grad_scale, + grad_scale_inv, bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1) + dbias_2 = jnp.reshape(dbias_2, bias_2_shape) + else: + casted_grad, casted_grad_t, updated_grad_amax = \ + cast_transpose(grad, grad_amax, grad_scale, + grad_scale_inv, bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1) + dbias_2 = jnp.empty(bias_2_shape, grad.dtype) - dbias_2 = jnp.sum(grad, axis=(i for i in range(grad.ndim - 1))) - dbias_2 = jnp.reshape(dbias_2, bias_2_shape) + casted_activation_lu_out_t = transpose(casted_activation_lu_out, + static_axis_boundary=-1, + transpose_axis_boundary=-1) # (hidden, batch...,) x (hidden, batch...) gemm2_x_scale_inv = scale_inv[gemm2_x_idx] - wgrad_2 = fp8_dot_impl(casted_gelu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv, - grad.dtype, (xt_batch_dims, xt_batch_dims), + wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv, + grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims), get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) # (batch..., hidden_out) x (hidden_in, hidden_out) @@ -633,36 +349,85 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0) - dgelu_amax = amax[gemm1_grad_idx, 0:1] - dgelu_scale = scale[gemm1_grad_idx] - dgelu_scale_inv = scale_inv[gemm1_grad_idx] - - casted_dgelu, casted_dgelu_t, dbias_1, updated_dgelu_amax = dgelu_dbias_cast_transpose( - dgrad_2, - dot_1_output, - dgelu_amax, - dgelu_scale, - dgelu_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1) - - dbias_1 = jnp.reshape(dbias_1, bias_1_shape) + dactivation_lu_amax = amax[gemm1_grad_idx, 0:1] + dactivation_lu_scale = scale[gemm1_grad_idx] + dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx] + + dactivation_lu_cast_transpose = activation_fp8_dict[activation_type]["bwd"] + dactivation_lu = activation_dict[activation_type]["bwd"](dgrad_2, dot_1_output) + + if is_gated: + if use_bias: + casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \ + dbias_cast_transpose( + dactivation_lu, + dactivation_lu_amax, + dactivation_lu_scale, + dactivation_lu_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-2) + dbias_1 = jnp.reshape(dbias_1, bias_1_shape) + else: + casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \ + dactivation_lu_cast_transpose( + dgrad_2, + dot_1_output, + dactivation_lu_amax, + dactivation_lu_scale, + 
dactivation_lu_scale_inv, + bwd_dtype, + static_axis_boundary=-1) + dbias_1 = jnp.empty(bias_1_shape, bwd_dtype) + else: + if use_bias: + casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \ + dactivation_lu_cast_transpose( + dgrad_2, + dot_1_output, + dactivation_lu_amax, + dactivation_lu_scale, + dactivation_lu_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1) + dbias_1 = jnp.reshape(dbias_1, bias_1_shape) + else: + casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \ + cast_transpose( + dactivation_lu, + dactivation_lu_amax, + dactivation_lu_scale, + dactivation_lu_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1) + dbias_1 = jnp.empty(bias_1_shape, bwd_dtype) ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1) # (hidden, batch...) x (hidden, batch...) gemm1_x_scale_inv = scale_inv[gemm1_x_idx] - wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgelu_t, gemm1_x_scale_inv, dgelu_scale_inv, grad.dtype, - (xt_batch_dims, xt_batch_dims), + xt_batch_dims_2 = xt_batch_dims if not is_gated \ + else tuple(i + 1 for i in xt_batch_dims) + wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv, + dactivation_lu_scale_inv, grad.dtype, + (xt_batch_dims, xt_batch_dims_2), get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) # Expand act axis to match the shape with the given kernel_1 - wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2) + if not is_gated: + wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2) # (batch..., hidden_out) x (hidden_in, hidden_out) + if is_gated: + x_contracting_dims = ((min(x_contracting_dims),) + tuple( + i + 1 for i in x_contracting_dims), (1,2)) + else: + x_contracting_dims = (x_contracting_dims, (1,)) kernel_1_scale_inv = scale_inv[gemm1_kernel_idx] - dgrad_1 = fp8_dot_impl(casted_dgelu, casted_kernel_1, dgelu_scale_inv, kernel_1_scale_inv, - grad.dtype, (x_contracting_dims, (1,)), + dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, + dactivation_lu_scale_inv, kernel_1_scale_inv, + grad.dtype, x_contracting_dims, get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes) @@ -683,15 +448,15 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0]) amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0]) - amax = amax.at[gemm1_grad_idx, 0].set(updated_dgelu_amax[0]) - amax = amax.at[gemm2_x_idx, 0].set(updated_gelu_amax[0]) + amax = amax.at[gemm1_grad_idx, 0].set(updated_dactivation_lu_amax[0]) + amax = amax.at[gemm2_x_idx, 0].set(updated_activation_lu_amax[0]) amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax) amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0]) scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) - return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \ fp8_max, amax, scale, scale_inv -_layernorm_gelu_fp8_mlp.defvjp(_layernorm_gelu_fp8_mlp_fwd_rule, _layernorm_gelu_fp8_mlp_bwd_rule) +_fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule, + _fused_layernorm_fp8_mlp_bwd_rule) From cb6016644cbed8cfa70ebd8bede1d611f844135c Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 24 Apr 2024 09:20:31 -0700 Subject: [PATCH 033/244] [PyTorch] Avoid using LRU cache for cu_seqlens (#798) * Try using global buffer for cu_seqlens Signed-off-by: Kirthi Shankar Sivamani * Avoid using functools.lru_cache Signed-off-by: 
Kirthi Shankar Sivamani * fixes Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Vasudevan Rengasamy Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 90da9e06b6..4bb39b913f 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5,7 +5,6 @@ """Attention.""" import collections from contextlib import nullcontext -import functools from importlib.metadata import version import math import os @@ -278,8 +277,7 @@ def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor: return indices - -@functools.lru_cache +_cu_seqlens_cache = {} def _get_full_cu_seqlens( batch_size: int, max_seqlen: int, @@ -290,13 +288,16 @@ def _get_full_cu_seqlens( All sequences in batch have the maximum sequence length. """ - return torch.arange( - 0, - (batch_size + 1) * max_seqlen, - step=max_seqlen, - dtype=torch.int32, - device=device, - ) + global _cu_seqlens_cache + if (batch_size, max_seqlen) not in _cu_seqlens_cache: + _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange( + 0, + (batch_size + 1) * max_seqlen, + step=max_seqlen, + dtype=torch.int32, + device=device, + ) + return _cu_seqlens_cache[(batch_size, max_seqlen)] @jit_fuser From b1a4efc4dbdf340c4df61c984bd16c7819d9355b Mon Sep 17 00:00:00 2001 From: Santosh Bhavani Date: Wed, 24 Apr 2024 16:22:51 -0500 Subject: [PATCH 034/244] Update README.rst (#806) Added HF Nanotron to integrations and updated GTC 24 video to ondemand link Signed-off-by: Santosh Bhavani Signed-off-by: Pawel Gadzinski --- README.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 190f8fc57c..936dfab077 100644 --- a/README.rst +++ b/README.rst @@ -231,6 +231,7 @@ Transformer Engine has been integrated with popular LLM frameworks such as: * `NVIDIA NeMo Framework `_ * `Amazon SageMaker Model Parallel Library `_ * `Levanter `_ +* `Hugging Face Nanotron `_ - Coming soon! * `Colossal-AI `_ - Coming soon! * `PeriFlow `_ - Coming soon! * `GPT-NeoX `_ - Coming soon! 
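A brief aside on the attention.py hunk above: the patch replaces `functools.lru_cache` on `_get_full_cu_seqlens` with an explicit module-level dictionary keyed by `(batch_size, max_seqlen)`. The sketch below reproduces that memoization pattern in isolation so the behaviour is easy to check; `make_full_cu_seqlens` and `_CACHE` are illustrative names, not part of Transformer Engine, and, like the patched helper, the cache key omits `device`, so this sketch assumes a single device per process.

    import torch

    _CACHE = {}  # (batch_size, max_seqlen) -> torch.Tensor

    def make_full_cu_seqlens(batch_size: int, max_seqlen: int,
                             device: str = "cpu") -> torch.Tensor:
        """Cumulative sequence lengths when every sequence is max_seqlen long."""
        key = (batch_size, max_seqlen)
        if key not in _CACHE:
            # [0, max_seqlen, 2*max_seqlen, ..., batch_size*max_seqlen]
            _CACHE[key] = torch.arange(
                0, (batch_size + 1) * max_seqlen, step=max_seqlen,
                dtype=torch.int32, device=device,
            )
        return _CACHE[key]

    a = make_full_cu_seqlens(4, 128)
    b = make_full_cu_seqlens(4, 128)
    assert a.data_ptr() == b.data_ptr()  # the second call reuses the cached tensor

Compared with `functools.lru_cache`, a plain dictionary keeps the cache key and the (non-existent) eviction policy in plain sight, which matters when the cached values are GPU tensors that stay alive for the lifetime of the cache.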
@@ -253,7 +254,7 @@ Papers Videos ====== -* `What's New in Transformer Engine and FP8 Training | GTC 2024 `_ +* `What's New in Transformer Engine and FP8 Training | GTC 2024 `_ * `FP8 Training with Transformer Engine | GTC 2023 `_ * `FP8 for Deep Learning | GTC 2023 `_ * `Inside the Hopper Architecture `_ From a06ab9aa3ea4189ae61a633a916a673304f1fe29 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Date: Wed, 24 Apr 2024 15:37:58 -0700 Subject: [PATCH 035/244] [JAX] SwiGLU Implementation (#773) * Implemented swiglu and silu Signed-off-by: Phuong Nguyen * Renamed nvte-*silu to nvte-*swish + generalized GetDBiasDact functions Signed-off-by: Phuong Nguyen Signed-off-by: Pawel Gadzinski --- tests/jax/test_custom_call_compute.py | 265 ++-- tests/jax/test_layer.py | 27 + tests/jax/test_praxis_layers.py | 16 +- .../common/activation/activation_template.h | 136 ++ transformer_engine/common/activation/gelu.cu | 218 +--- transformer_engine/common/activation/relu.cu | 145 +-- .../common/activation/swiglu.cu | 98 +- .../include/transformer_engine/activation.h | 35 +- .../include/transformer_engine/transpose.h | 47 + .../common/transpose/cast_transpose_fusion.cu | 135 +- transformer_engine/jax/cpp_extensions.py | 1147 ++++++++++++++++- transformer_engine/jax/csrc/extensions.cpp | 12 +- transformer_engine/jax/csrc/modules.cpp | 260 +++- transformer_engine/jax/csrc/modules.h | 21 +- transformer_engine/jax/flax/module.py | 13 +- transformer_engine/jax/mlp.py | 88 +- 16 files changed, 1996 insertions(+), 667 deletions(-) create mode 100644 transformer_engine/common/activation/activation_template.h diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 139ef994fa..2d4c9b7e32 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -15,15 +15,12 @@ from flax import linen as nn from utils import assert_allclose -from transformer_engine.jax.cpp_extensions import dgelu, dgelu_dbias_cast_transpose -from transformer_engine.jax.cpp_extensions import gelu, gelu_fp8 -from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu -from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8 from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper from transformer_engine.jax.fp8 import is_fp8_available from transformer_engine.jax.layernorm import layernorm -from transformer_engine.jax.mlp import fused_layernorm_fp8_mlp +from transformer_engine.jax.mlp import activation_lu, activation_lu_fp8, fused_layernorm_fp8_mlp + GEMM_CASES = [ (256, 256, 512), @@ -37,6 +34,16 @@ DTYPES = [jnp.bfloat16, jnp.float32] is_fp8_supported, reason = is_fp8_available() +def _convert_to_activation_function(fn_or_string): + """Convert a string to an activation function.""" + if fn_or_string == 'linear': + return lambda x: x + if isinstance(fn_or_string, str): + return getattr(nn, fn_or_string) + if callable(fn_or_string): + return fn_or_string + raise ValueError(f"don't know how to convert {fn_or_string} to an activation function") + @pytest.fixture(autouse=True, scope='function') def clear_live_arrays(): @@ -174,22 +181,21 @@ def ref_func(x, y): assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('m,n,k', [(256, 512, 128), (16384, 1024, 2816), (16384, 2816, 1024), + 
@pytest.mark.parametrize('m,n,k', [(128, 256, 512), + (16384, 1024, 2816), + (16384, 2816, 1024), (16384, 1024, 1024)]) @pytest.mark.parametrize('activation_type', [('gelu', ), - ('gelu', 'linear')]) + ('gelu', 'linear'), + ('silu', ), + ('silu', 'linear')]) @pytest.mark.parametrize('use_bias', [True, False]) def test_grad_fused_layernorm_fp8_mlp(self, m, n, k, - activation_type: Sequence[Union[str, Callable]], - use_bias: bool): + activation_type: Sequence[Union[str, Callable]], use_bias: bool): """ N/a """ key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 6) - activation_dict = { - ('gelu', ): jax.nn.gelu - } - a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16) k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) @@ -218,15 +224,6 @@ def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale fused_layernorm_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm", activation_type = activation_type, use_bias = use_bias)) - def _convert_to_activation_function(fn_or_string): - """Convert a string to an activation function.""" - if fn_or_string == 'linear': - return lambda x: x - if isinstance(fn_or_string, str): - return getattr(nn, fn_or_string) - if callable(fn_or_string): - return fn_or_string - raise ValueError(f"don't know how to convert {fn_or_string} to an activation function") def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray, @@ -249,15 +246,12 @@ def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.n bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += jnp.reshape(bias_1, bias_1_shape) - if 'linear' in activation_type: - x = jnp.split(linear_1_out, len(activation_type), axis=-2) - acts = [] - for idx, act_fn in enumerate(activation_type): - x_i = _convert_to_activation_function(act_fn)(x[idx]) - acts.append(x_i) - x = functools.reduce(operator.mul, acts) - else: - x = activation_dict[activation_type](linear_1_out) + x = jnp.split(linear_1_out, len(activation_type), axis=-2) + acts = [] + for idx, act_fn in enumerate(activation_type): + x_i = _convert_to_activation_function(act_fn)(x[idx]) + acts.append(x_i) + x = functools.reduce(operator.mul, acts) x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16) @@ -331,7 +325,6 @@ def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, dtype=jnp.bfloat16) - @pytest.fixture(name="random_inputs") def random_inputs_fixture(shape): key = jax.random.PRNGKey(0) @@ -340,190 +333,86 @@ def random_inputs_fixture(shape): return out -class TestGeLu: - - def ref_func(self, inputs): - - func = jit(value_and_grad(lambda x: jnp.mean(jax.nn.gelu(x)))) - return func(inputs) +class TestActivationLu: - def prim_func(self, inputs): - - @jax.custom_vjp - def primitive(x): - out, _ = primitive_fwd(x) - return out - - def primitive_fwd(x): - out = gelu(x) - ctx = x - return out, ctx + def ref_func(self, x, activation_type): + def ref_act_lu(inputs): + x = jnp.split(inputs, len(activation_type), axis=-2) + acts = [] + for idx, act_fn in enumerate(activation_type): + x_i = _convert_to_activation_function(act_fn)(x[idx]) + acts.append(x_i) + x = functools.reduce(operator.mul, acts) + return jnp.mean(x) - def primitive_bwd(ctx, g): - x = ctx - out = dgelu(g, x) - return (out,) + ref_act_func = jit(value_and_grad(ref_act_lu, (0,))) + return ref_act_func(x) - 
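Both `_fused_layernorm_fp8_mlp` above and the test helpers being removed below are built on the same `jax.custom_vjp` wiring: the primal function calls the forward rule and discards the residuals, `nondiff_argnums` marks the static arguments, and `defvjp` ties the forward and backward rules together. A minimal sketch of that pattern, using a toy `scaled_square` function whose names are illustrative only, not Transformer Engine APIs:

    import jax
    import jax.numpy as jnp
    from functools import partial

    @partial(jax.custom_vjp, nondiff_argnums=(1,))  # `scale` is treated as static
    def scaled_square(x, scale):
        out, _ = scaled_square_fwd(x, scale)
        return out

    def scaled_square_fwd(x, scale):
        ctx = x                        # residuals saved for the backward rule
        return scale * x * x, ctx

    def scaled_square_bwd(scale, ctx, grad):
        # Non-diff arguments come first, then the residuals, then the cotangent;
        # return one gradient per differentiable input.
        x = ctx
        return (2.0 * scale * x * grad,)

    scaled_square.defvjp(scaled_square_fwd, scaled_square_bwd)

    val, dx = jax.value_and_grad(lambda x: jnp.mean(scaled_square(x, 3.0)))(jnp.arange(4.0))

In the fused MLP above the same roles are played by `_fused_layernorm_fp8_mlp_fwd_rule`, which returns `(dot_2_output, ctx)`, and `_fused_layernorm_fp8_mlp_bwd_rule`, which receives the `nondiff_argnums` such as `activation_type` and `use_bias` ahead of `ctx` and `grad`.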
primitive.defvjp(primitive_fwd, primitive_bwd) - func = value_and_grad(lambda x: jnp.mean(primitive(x))) - return func(inputs) + def primitive_func(self, inputs): + return jnp.mean(activation_lu(inputs, activation_type = self.activation_type)) @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)]) - def test_gelu(self, random_inputs): + @pytest.mark.parametrize('activation_type', [('gelu',), + ('gelu', 'linear'), + ('silu',), + ('silu', 'linear')]) + def test_activation_lu(self, random_inputs, activation_type): x = random_inputs - prim_out, prim_grad = self.prim_func(x) - ref_out, ref_grad = self.ref_func(x) - - assert_allclose(prim_out, ref_out, dtype=x.dtype) - assert_allclose(prim_grad, ref_grad, dtype=x.dtype) - + self.activation_type = activation_type -class TestGeLuFP8(TestGeLu): - - def prim_func(self, inputs): - amax = self.amax - scale = self.scale - scale_inv = self.scale_inv - no_use = jnp.zeros(1, jnp.float32) + value_n_grad_primitive_func = jit( + value_and_grad(self.primitive_func, (0,))) - @jax.custom_vjp - def primitive(x, y, z, w): - out = primitive_fwd(x) - return out + prim_out, (prim_grad,) = value_n_grad_primitive_func(x) + ref_out, (ref_grad,) = self.ref_func(x, activation_type) + """ prim_grad, = prim_grad """ + """ ref_grad, = ref_grad """ - def primitive_fwd(x, y, z, w): - out, _ = gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn) - out = dequantize(out, x.dtype, scale_inv) - ctx = x - return out, ctx + assert_allclose(prim_out, ref_out, dtype=x.dtype) + assert_allclose(prim_grad, ref_grad, dtype=x.dtype) - def primitive_bwd(ctx, g): - x = ctx - dgelu, dgelu_trans, dbias, amax_out = dgelu_dbias_cast_transpose( - g, x, amax, scale, scale_inv, jnp.float8_e5m2, -1) - dgelu = dequantize(dgelu, x.dtype, scale_inv) - dgelu_trans = dequantize(dgelu_trans, x.dtype, scale_inv) - return dgelu, dgelu_trans, dbias, amax_out - primitive.defvjp(primitive_fwd, primitive_bwd) - func = value_and_grad(lambda x, y, z, w: jnp.mean(primitive(x, y, z, w)), (0, 1, 2, 3)) +class TestActivationLuFP8(TestActivationLu): - return func(inputs, jnp.transpose(inputs, (2, 0, 1)), - jnp.zeros(inputs.shape[-1], dtype=inputs.dtype), no_use) + def primitive_func(self, inputs, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv): + return jnp.mean( + activation_lu_fp8(inputs, + amax, scale, scale_inv, + jnp.float8_e4m3fn, jnp.float8_e5m2, + activation_type = self.activation_type)) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)]) - def test_gelu(self, random_inputs): + @pytest.mark.parametrize('activation_type', [('gelu',), + ('gelu', 'linear'), + ('silu',), + ('silu', 'linear')]) + def test_activation_lu(self, random_inputs, activation_type): self.amax = jnp.zeros(1, jnp.float32) self.scale = jnp.ones(1, jnp.float32) self.scale_inv = jnp.ones(1, jnp.float32) + self.activation_type = activation_type x = random_inputs - prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x) - ref_out, ref_grad = self.ref_func(x) - - assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE) - assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2) - assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1)))) - assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE) - assert_allclose(prim_grad_trans, - jnp.transpose(ref_grad, (2, 0, 1)), - dtype=FP8Helper.BWD_DTYPE) - - -class TestGatedGeLu: - - def ref_func(self, inputs): - - def jax_gated_gelu(x): - x = jnp.split(x, 2, axis=-2) - 
acts = [jax.nn.gelu(x[0]), x[1]] - x = functools.reduce(operator.mul, acts) - x = jnp.asarray(jnp.squeeze(x, -2), jnp.bfloat16) - return x - - func = jit(value_and_grad(lambda x: jnp.mean(jax_gated_gelu(x)))) - return func(inputs) - - def prim_func(self, inputs): - - @jax.custom_vjp - def primitive(x): - out, _ = primitive_fwd(x) - return out - - def primitive_fwd(x): - out = gated_gelu(x) - ctx = x - return out, ctx - - def primitive_bwd(ctx, g): - x = ctx - out = dgated_gelu(g, x) - return (out,) - - primitive.defvjp(primitive_fwd, primitive_bwd) - func = value_and_grad(lambda x: jnp.mean(primitive(x))) - return func(inputs) - - @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)]) - def test_gated_gelu(self, random_inputs): - x = random_inputs - prim_out, prim_grad = self.prim_func(x) - ref_out, ref_grad = self.ref_func(x) - - assert_allclose(prim_out, ref_out, dtype=x.dtype) - assert_allclose(prim_grad, ref_grad, dtype=x.dtype) + value_n_grad_primitive_func = jit( value_and_grad(self.primitive_func, (0, 1, 2, 3, 4, 5,))) -class TestGatedGeLuFP8(TestGatedGeLu): + transpose_indices = (1, 2, 0) if len(activation_type) > 1 else (2, 0, 1) + dx_trans_no_use = jnp.zeros([x.shape[i] for i in transpose_indices], dtype=x.dtype) + dbias_no_use = jnp.zeros(x.shape[-1], dtype=x.dtype) - def prim_func(self, inputs): - amax = self.amax - scale = self.scale - scale_inv = self.scale_inv - no_use = jnp.zeros(1, jnp.float32) - - @jax.custom_vjp - def primitive(x, y, z): - out = primitive_fwd(x) - return out - - def primitive_fwd(x, y, z): - out, _ = gated_gelu_fp8(x, amax, scale, scale_inv, jnp.float8_e4m3fn) - out = dequantize(out, x.dtype, scale_inv) - ctx = x - return out, ctx - - def primitive_bwd(ctx, g): - x = ctx - dgelu, dgelu_trans, amax_out = dgated_gelu_cast_transpose(g, x, amax, scale, scale_inv, - jnp.float8_e5m2, -1) - dgelu = dequantize(dgelu, x.dtype, scale_inv) - dgelu_trans = dequantize(dgelu_trans, x.dtype, scale_inv) - return dgelu, dgelu_trans, amax_out - - primitive.defvjp(primitive_fwd, primitive_bwd) - func = value_and_grad(lambda x, y, z: jnp.mean(primitive(x, y, z)), (0, 1, 2)) - - return func(inputs, jnp.transpose(inputs, (1, 2, 0)), no_use) - - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)]) - def test_gated_gelu(self, random_inputs): - self.amax = jnp.zeros(1, jnp.float32) - self.scale = jnp.ones(1, jnp.float32) - self.scale_inv = jnp.ones(1, jnp.float32) - - x = random_inputs - prim_out, (prim_grad, prim_grad_trans, amax) = self.prim_func(x) - ref_out, ref_grad = self.ref_func(x) + prim_out, (prim_grad, prim_grad_trans, dbias, amax, _, _) = \ + value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, + self.amax, self.scale, self.scale_inv) + ref_out, (ref_grad,) = self.ref_func(x, activation_type) assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE) assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2) + if 'linear' not in activation_type: + assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1)))) assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE) assert_allclose(prim_grad_trans, - jnp.transpose(ref_grad, (1, 2, 0)), + jnp.transpose(ref_grad, transpose_indices), dtype=FP8Helper.BWD_DTYPE) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 1b7b4087d0..70602ccbb8 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -158,6 +158,33 @@ def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): 
_KEY_OF_DROPOUT_RATE: 0.0, _KEY_OF_MLP_ACTIVATIONS: (('gelu',)), _KEY_OF_FUSE_MLP_WI: True +}, { + _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', + _KEY_OF_DROPOUT_RATE: 0.0, + _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), + _KEY_OF_FUSE_MLP_WI: True +}, { + _KEY_OF_SCALE_ATTN_LOGITS: True, + _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', + _KEY_OF_DROPOUT_RATE: 0.8, + _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), + _KEY_OF_FUSE_MLP_WI: True +}, { + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_SCALE_ATTN_LOGITS: True, + _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', + _KEY_OF_DROPOUT_RATE: 0.0, + _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), + _KEY_OF_FUSE_MLP_WI: True +}, { + _KEY_OF_NUM_HEADS: 8, + _KEY_OF_NUM_GQA_GROUPS: 4, + _KEY_OF_TRANSPOSE_BS: False, + _KEY_OF_SCALE_ATTN_LOGITS: True, + _KEY_OF_LAYERNORM_TYPE: 'layernorm', + _KEY_OF_DROPOUT_RATE: 0.0, + _KEY_OF_MLP_ACTIVATIONS: (('silu',)), + _KEY_OF_FUSE_MLP_WI: True }, { _KEY_OF_TRANSPOSE_BS: False, _KEY_OF_LAYERNORM_TYPE: 'layernorm', diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index dce0263ac7..1bc32d1251 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -543,11 +543,25 @@ class LayerNormMLPAttr: ACTIVATION: ('gelu', 'linear') }, { INTERMEDIATE_DIM: 2048, - USE_BIAS: True, + USE_BIAS: False, ENABLE_LN: True, LN_TYPE: 'rmsnorm', ZERO_CEN: False, ACTIVATION: ('gelu', 'linear') + }, { + INTERMEDIATE_DIM: 2048, + USE_BIAS: True, + ENABLE_LN: True, + LN_TYPE: 'rmsnorm', + ZERO_CEN: False, + ACTIVATION: ('silu', 'linear') + }, { + INTERMEDIATE_DIM: 2048, + USE_BIAS: False, + ENABLE_LN: True, + LN_TYPE: 'rmsnorm', + ZERO_CEN: False, + ACTIVATION: ('silu', 'linear') }] diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h new file mode 100644 index 0000000000..12e1b37e8f --- /dev/null +++ b/transformer_engine/common/activation/activation_template.h @@ -0,0 +1,136 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include "../util/vectorized_pointwise.h" +#include "../common.h" + + +namespace transformer_engine { + +template +void act_fn(const Tensor &input, + Tensor *output, + cudaStream_t stream) { + CheckInputTensor(input, "act_lu_input"); + CheckOutputTensor(*output, "act_lu_output"); + NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); + const size_t tot_elts = product(input.data.shape); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + tot_elts, + {}, + stream); + ); // NOLINT(*) + ); // NOLINT(*) +} + +template +void dact_fn(const Tensor &grad, + const Tensor &input, + Tensor *output, + cudaStream_t stream) { + CheckInputTensor(input, "dact_lu_input"); + CheckInputTensor(grad, "dact_lu_input_grad"); + CheckOutputTensor(*output, "dact_lu_output"); + NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); + NVTE_CHECK(input.data.dtype == grad.data.dtype, + "Input and incoming gradient types must match."); + const size_t tot_elts = product(input.data.shape); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryGradKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + tot_elts, + {}, + stream); + ); // NOLINT(*) + ); // NOLINT(*) +} + +template +void gated_act_fn(const Tensor &input, + Tensor *output, + cudaStream_t stream) { + CheckInputTensor(input, "gated_act_input"); + CheckOutputTensor(*output, "gated_act_output"); + NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); + NVTE_CHECK(input.data.shape[0] == output->data.shape[0], + "Input shape[0] must be equal to output shape[0]."); + NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, + "Input shape[1] must be 2x larger than output shape[1]."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, + constexpr int nvec = 32 / sizeof(IType); + GatedActivationKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + output->data.shape[0], + output->data.shape[1], + {}, + stream); + ); // NOLINT(*) + ); // NOLINT(*) +} + +template +void dgated_act_fn(const Tensor &grad, + const Tensor &input, + Tensor *output, + cudaStream_t stream) { + CheckInputTensor(grad, "dgated_act_grad"); + CheckInputTensor(input, "dgated_act_input"); + CheckOutputTensor(*output, "dgated_act_output"); + NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions."); + NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); + NVTE_CHECK(output->data.shape[0] == grad.data.shape[0], + "Output 
shape[0] must be equal to grad shape[0]."); + NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2, + "Output shape[1] must be 2x larger than grad shape[1]."); + NVTE_CHECK(input.data.shape == output->data.shape, + "Input and output shapes must match."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, + constexpr int nvec = 32 / sizeof(IType); + DGatedActivationKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + grad.data.shape[0], + grad.data.shape[1], + {}, + stream); + ); // NOLINT(*) + ); // NOLINT(*) +} + +} // namespace transformer_engine + diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 96271968e6..5b872b2523 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -3,191 +3,18 @@ * * See LICENSE for license information. ************************************************************************/ - -#include -#include -#include -#include -#include "../utils.cuh" -#include "../common.h" -#include -#include <../util/vectorized_pointwise.h> +#include "./activation_template.h" #include "../util/math.h" -namespace transformer_engine { - -void gelu(const Tensor &input, - Tensor *output, - cudaStream_t stream) { - CheckInputTensor(input, "gelu_input"); - CheckOutputTensor(*output, "gelu_output"); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - const size_t tot_elts = product(input.data.shape); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher >( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - tot_elts, - Empty(), - stream); - ); // NOLINT(*) - ); // NOLINT(*) -} - -void dgelu(const Tensor &grad, - const Tensor &input, - Tensor *output, - cudaStream_t stream) { - CheckInputTensor(input, "dgelu_input"); - CheckInputTensor(grad, "dgelu_input_grad"); - CheckOutputTensor(*output, "dgelu_output"); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - NVTE_CHECK(input.data.dtype == grad.data.dtype, - "Input and incoming gradient types must match."); - const size_t tot_elts = product(input.data.shape); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryGradKernelLauncher>( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - tot_elts, - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) -} - -void geglu(const Tensor &input, - Tensor *output, - cudaStream_t stream) { - CheckInputTensor(input, "geglu_input"); - CheckOutputTensor(*output, "geglu_output"); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(input.data.shape[0] == output->data.shape[0], - "Input shape[0] must be equal to output shape[0]."); - NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, - "Input 
shape[1] must be 2x larger than output shape[1]."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - GatedActivationKernelLauncher>( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - output->data.shape[0], - output->data.shape[1], - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) -} - -void dgeglu(const Tensor &grad, - const Tensor &input, - Tensor *output, - cudaStream_t stream) { - CheckInputTensor(grad, "dgeglu_grad"); - CheckInputTensor(input, "dgeglu_input"); - CheckOutputTensor(*output, "dgeglu_output"); - NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions."); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(output->data.shape[0] == grad.data.shape[0], - "Output shape[0] must be equal to grad shape[0]."); - NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2, - "Output shape[1] must be 2x larger than grad shape[1]."); - NVTE_CHECK(input.data.shape == output->data.shape, - "Input and output shapes must match."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - DGatedActivationKernelLauncher, dgelu>( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - grad.data.shape[0], - grad.data.shape[1], - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) -} - -void qgelu(const Tensor &input, - Tensor *output, - cudaStream_t stream) { - CheckInputTensor(input, "qgelu_input"); - CheckOutputTensor(*output, "qgelu_output"); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - const size_t tot_elts = product(input.data.shape); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher >( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - tot_elts, - Empty(), - stream); - ); // NOLINT(*) - ); // NOLINT(*) -} - -void dqgelu(const Tensor &grad, - const Tensor &input, - Tensor *output, - cudaStream_t stream) { - CheckInputTensor(input, "dqgelu_input"); - CheckInputTensor(grad, "dqgelu_input_grad"); - CheckOutputTensor(*output, "dqgelu_output"); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - NVTE_CHECK(input.data.dtype == grad.data.dtype, - "Input and incoming gradient types must match."); - const size_t tot_elts = product(input.data.shape); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryGradKernelLauncher>( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - tot_elts, - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) -} - -} // namespace transformer_engine void nvte_gelu(const NVTETensor input, NVTETensor output, 
cudaStream_t stream) { NVTE_API_CALL(nvte_gelu); using namespace transformer_engine; - gelu(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + act_fn>(*reinterpret_cast(input), + reinterpret_cast(output), + stream); } void nvte_dgelu(const NVTETensor grad, @@ -196,10 +23,10 @@ void nvte_dgelu(const NVTETensor grad, cudaStream_t stream) { NVTE_API_CALL(nvte_dgelu); using namespace transformer_engine; - dgelu(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + dact_fn>(*reinterpret_cast(grad), + *reinterpret_cast(input), + reinterpret_cast(output), + stream); } void nvte_geglu(const NVTETensor input, @@ -207,9 +34,9 @@ void nvte_geglu(const NVTETensor input, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; - geglu(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + gated_act_fn>(*reinterpret_cast(input), + reinterpret_cast(output), + stream); } void nvte_dgeglu(const NVTETensor grad, @@ -218,10 +45,11 @@ void nvte_dgeglu(const NVTETensor grad, cudaStream_t stream) { NVTE_API_CALL(nvte_dgeglu); using namespace transformer_engine; - dgeglu(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + dgated_act_fn, dgelu>( + *reinterpret_cast(grad), + *reinterpret_cast(input), + reinterpret_cast(output), + stream); } void nvte_qgelu(const NVTETensor input, @@ -229,9 +57,9 @@ void nvte_qgelu(const NVTETensor input, cudaStream_t stream) { NVTE_API_CALL(nvte_qgelu); using namespace transformer_engine; - qgelu(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + act_fn>(*reinterpret_cast(input), + reinterpret_cast(output), + stream); } void nvte_dqgelu(const NVTETensor grad, @@ -240,8 +68,8 @@ void nvte_dqgelu(const NVTETensor grad, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgelu); using namespace transformer_engine; - dqgelu(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + dact_fn>(*reinterpret_cast(grad), + *reinterpret_cast(input), + reinterpret_cast(output), + stream); } diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index b5bf04ac6c..08459bf061 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -4,136 +4,18 @@ * See LICENSE for license information. 
************************************************************************/ -#include -#include -#include "../util/vectorized_pointwise.h" +#include "./activation_template.h" #include "../util/math.h" -#include "../common.h" -namespace transformer_engine { - -void relu(const Tensor &input, - Tensor *output, - cudaStream_t stream) { - CheckInputTensor(input, "relu_input"); - CheckOutputTensor(*output, "relu_output"); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - const size_t tot_elts = product(input.data.shape); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher>( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - tot_elts, - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) -} - -void drelu(const Tensor &grad, - const Tensor &input, - Tensor *output, - cudaStream_t stream) { - CheckInputTensor(input, "drelu_input"); - CheckInputTensor(grad, "drelu_input_grad"); - CheckOutputTensor(*output, "drelu_output"); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - NVTE_CHECK(input.data.dtype == grad.data.dtype, - "Input and incoming gradient types must match."); - const size_t tot_elts = product(input.data.shape); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryGradKernelLauncher>( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - tot_elts, - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) -} - -void reglu(const Tensor &input, - Tensor *output, - cudaStream_t stream) { - CheckInputTensor(input, "reglu_input"); - CheckOutputTensor(*output, "reglu_output"); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(input.data.shape[0] == output->data.shape[0], - "Input shape[0] must be equal to output shape[0]."); - NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, - "Input shape[1] must be 2x larger than output shape[1]."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - GatedActivationKernelLauncher>( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - output->data.shape[0], - output->data.shape[1], - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) -} - -void dreglu(const Tensor &grad, - const Tensor &input, - Tensor *output, - cudaStream_t stream) { - CheckInputTensor(grad, "dreglu_grad"); - CheckInputTensor(input, "dreglu_input"); - CheckOutputTensor(*output, "dreglu_output"); - NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions."); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(output->data.shape[0] == grad.data.shape[0], - "Output shape[0] must be equal to grad 
shape[0]."); - NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2, - "Output shape[1] must be 2x larger than grad shape[1]."); - NVTE_CHECK(input.data.shape == output->data.shape, - "Input and output shapes must match."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - DGatedActivationKernelLauncher, drelu>( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - grad.data.shape[0], - grad.data.shape[1], - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) -} - -} // namespace transformer_engine - void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_relu); using namespace transformer_engine; - relu(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + act_fn>(*reinterpret_cast(input), + reinterpret_cast(output), + stream); } void nvte_drelu(const NVTETensor grad, @@ -142,10 +24,10 @@ void nvte_drelu(const NVTETensor grad, cudaStream_t stream) { NVTE_API_CALL(nvte_drelu); using namespace transformer_engine; - drelu(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + dact_fn>(*reinterpret_cast(grad), + *reinterpret_cast(input), + reinterpret_cast(output), + stream); } void nvte_reglu(const NVTETensor input, @@ -153,7 +35,7 @@ void nvte_reglu(const NVTETensor input, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; - reglu(*reinterpret_cast(input), + gated_act_fn>(*reinterpret_cast(input), reinterpret_cast(output), stream); } @@ -164,8 +46,9 @@ void nvte_dreglu(const NVTETensor grad, cudaStream_t stream) { NVTE_API_CALL(nvte_dreglu); using namespace transformer_engine; - dreglu(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + dgated_act_fn, drelu>( + *reinterpret_cast(grad), + *reinterpret_cast(input), + reinterpret_cast(output), + stream); } diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 031a11fdcf..088b06bea2 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -4,85 +4,40 @@ * See LICENSE for license information. 
************************************************************************/ -#include -#include -#include "../util/vectorized_pointwise.h" +#include "./activation_template.h" #include "../util/math.h" -#include "../common.h" -namespace transformer_engine { - -void swiglu(const Tensor &input, - Tensor *output, - cudaStream_t stream) { - CheckInputTensor(input, "geglu_input"); - CheckOutputTensor(*output, "geglu_output"); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(input.data.shape[0] == output->data.shape[0], - "Input shape[0] must be equal to output shape[0]."); - NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, - "Input shape[1] must be 2x larger than output shape[1]."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - GatedActivationKernelLauncher>( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - output->data.shape[0], - output->data.shape[1], - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) +void nvte_swish(const NVTETensor input, + NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_swish); + using namespace transformer_engine; + act_fn>(*reinterpret_cast(input), + reinterpret_cast(output), + stream); } -void dswiglu(const Tensor &grad, - const Tensor &input, - Tensor *output, - cudaStream_t stream) { - CheckInputTensor(grad, "dswiglu_grad"); - CheckInputTensor(input, "dswiglu_input"); - CheckOutputTensor(*output, "dswiglu_output"); - NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions."); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(output->data.shape[0] == grad.data.shape[0], - "Output shape[0] must be equal to grad shape[0]."); - NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2, - "Output shape[1] must be 2x larger than grad shape[1]."); - NVTE_CHECK(input.data.shape == output->data.shape, - "Input and output shapes must match."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType, - constexpr int nvec = 32 / sizeof(IType); - DGatedActivationKernelLauncher, dswish>( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - grad.data.shape[0], - grad.data.shape[1], - {}, - stream); - ); // NOLINT(*) - ); // NOLINT(*) +void nvte_dswish(const NVTETensor grad, + const NVTETensor input, + NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_dswish); + using namespace transformer_engine; + dact_fn>(*reinterpret_cast(grad), + *reinterpret_cast(input), + reinterpret_cast(output), + stream); } -} // namespace transformer_engine - void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; - swiglu(*reinterpret_cast(input), - reinterpret_cast(output), - stream); + gated_act_fn>(*reinterpret_cast(input), + reinterpret_cast(output), + stream); } void nvte_dswiglu(const NVTETensor grad, @@ -91,8 +46,9 @@ void nvte_dswiglu(const NVTETensor grad, cudaStream_t stream) { NVTE_API_CALL(nvte_dswiglu); using namespace transformer_engine; 
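The renamed nvte_swish / nvte_dswish / nvte_swiglu entry points compute SiLU (Swish), its gradient, and the SiLU-gated variant; a hedged NumPy reference of that math, under the same gated convention as above:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def swish_ref(x):
    # SiLU / Swish: x * sigmoid(x)
    return x * sigmoid(x)

def dswish_ref(grad, x):
    # grad * d/dx [x * sigmoid(x)]
    s = sigmoid(x)
    return grad * (s + x * s * (1.0 - s))

def swiglu_ref(x):
    # x: [N, 2H] -> [N, H]; SiLU(first half) gated by the second half.
    h = x.shape[-1] // 2
    return swish_ref(x[..., :h]) * x[..., h:]
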
- dswiglu(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), - stream); + dgated_act_fn, dswish>( + *reinterpret_cast(grad), + *reinterpret_cast(input), + reinterpret_cast(output), + stream); } diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index fd3e458ff7..6bf795cd38 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -61,24 +61,24 @@ void nvte_dgeglu(const NVTETensor grad, NVTETensor output, cudaStream_t stream); -/*! \brief Compute RELU activation of the input. +/*! \brief Compute SiLU activation of the input. * - * \param[in] input Input tensor for RELU activation. + * \param[in] input Input tensor for GELU activation. * \param[in,out] output Output tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_relu(const NVTETensor input, +void nvte_swish(const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Compute RELU activation gradient. +/*! \brief Compute Swish activation gradient. * * \param[in] grad Incoming gradient. - * \param[in] input Input tensor for RELU activation. + * \param[in] input Input tensor for Swish activation. * \param[in,out] output Output tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_drelu(const NVTETensor grad, +void nvte_dswish(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); @@ -105,6 +105,29 @@ void nvte_dswiglu(const NVTETensor grad, NVTETensor output, cudaStream_t stream); + +/*! \brief Compute RELU activation of the input. + * + * \param[in] input Input tensor for RELU activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_relu(const NVTETensor input, + NVTETensor output, + cudaStream_t stream); + +/*! \brief Compute RELU activation gradient. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for RELU activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_drelu(const NVTETensor grad, + const NVTETensor input, + NVTETensor output, + cudaStream_t stream); + /*! \brief Compute ReGLU activation of the input. * * \param[in] input Input tensor of shape [N, H * 2]. diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index 4d2061d078..c556c001de 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -159,6 +159,53 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input, NVTETensor transposed_output, cudaStream_t stream); +/*! \brief Compute backward of SiLU operation on the input, then cast and transpose. Additionally, + * reduce the result of the SiLU backward along the first dimension. + * + * This function produces 3 results: + * - `cast_output` is equal to `cast(dSiLU(input))` + * - `transposed_output` is equal to `transpose(cast(dSiLU(input)))` + * - `dbias` is equal to `reduce(dSiLU(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. 
+ * \param[in] swish_input Tensor used as input to the forward of SiLU operation. + * Shape [N, H]. + * \param[in,out] cast_output Result of the cast. Shape: [N, H]. + * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N]. + * \param[out] dbias Result of the reduction of the dSiLU(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cast_transpose_dbias_dswish(const NVTETensor input, + const NVTETensor swish_input, + NVTETensor cast_output, + NVTETensor transposed_output, + NVTETensor dbias, + NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Compute dswiglu of the input, additionally does cast and transpose the dswiglu output. + * + * This function produces 2 results: + * - `cast_output` is the result of the cast + * - `transposed_output` is the transposed result of the cast. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] swiglu_input Tensor used as input to the forward of SwiGLU operation. + * Shape [N, H * 2]. + * \param[in,out] cast_output Result of the cast. Shape: [N, H * 2]. + * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dswiglu_cast_transpose(const NVTETensor input, + const NVTETensor swiglu_input, + NVTETensor cast_output, + NVTETensor transposed_output, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index 8e455dddb5..0a0560d470 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -619,7 +619,11 @@ void cast_transpose_dbias(const Tensor &input, ); // NOLINT(*) } -template +// TODO Phuong: Change all the names in these generalized functions. 
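A hedged, unfused Python sketch of what the nvte_cast_transpose_dbias_dswish declared above is documented to produce; quantize is a placeholder for the FP8 cast performed with the output tensor's scale and amax, and the real kernel fuses all three steps (and is first called with an empty workspace tensor to query the workspace size):

import numpy as np

def cast_transpose_dbias_dswish_ref(grad, swish_input, quantize):
    # grad, swish_input: [N, H]; quantize stands in for the FP8 cast.
    s = 1.0 / (1.0 + np.exp(-swish_input))
    dsilu = grad * (s + swish_input * s * (1.0 - s))   # dSiLU(input), shape [N, H]
    cast_output = quantize(dsilu)                      # [N, H]
    transposed_output = cast_output.T                  # [H, N]
    dbias = dsilu.sum(axis=0)                          # [H], reduced in input precision
    return cast_output, transposed_output, dbias
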
+// For now, I keep the old names so that it is easier to do code review +template __global__ void __launch_bounds__(cast_transpose_num_threads) cast_transpose_dbias_dgelu_kernel(const Param param, @@ -713,7 +717,7 @@ cast_transpose_dbias_dgelu_kernel(const Param param, for (unsigned int j = 0; j < nvec_out; ++j) { #pragma unroll for (unsigned int k = 0; k < nvec_in; ++k) { - after_dgelu[j].data.elt[k] = dgelu(gelu_in[current_in ^ 1][j].data.elt[k], {}) * + after_dgelu[j].data.elt[k] = OP(gelu_in[current_in ^ 1][j].data.elt[k], {}) * CType(in[current_in ^ 1][j].data.elt[k]); } } @@ -779,7 +783,9 @@ cast_transpose_dbias_dgelu_kernel(const Param param, } } -template +template __global__ void __launch_bounds__(cast_transpose_num_threads) cast_transpose_dbias_dgelu_kernel_notaligned(const Param param, @@ -896,7 +902,7 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param, for (unsigned int j = 0; j < nvec_out; ++j) { #pragma unroll for (unsigned int k = 0; k < nvec_in; ++k) { - after_dgelu[j].data.elt[k] = dgelu(gelu_in[current_in ^ 1][j].data.elt[k], {}) * + after_dgelu[j].data.elt[k] = OP(gelu_in[current_in ^ 1][j].data.elt[k], {}) * CType(in[current_in ^ 1][j].data.elt[k]); } } @@ -969,7 +975,11 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param, } } -template +template __global__ void __launch_bounds__(cast_transpose_num_threads) dgeglu_cast_transpose_kernel(const IType * const input, @@ -1068,11 +1078,11 @@ dgeglu_cast_transpose_kernel(const IType * const input, for (unsigned int j = 0; j < nvec_out; ++j) { #pragma unroll for (unsigned int k = 0; k < nvec_in; ++k) { - after_dgelu[j].data.elt[k] = dgelu(gelu_in[current_in ^ 1][j].data.elt[k], {}) * + after_dgelu[j].data.elt[k] = OP1(gelu_in[current_in ^ 1][j].data.elt[k], {}) * CType(in[current_in ^ 1][j].data.elt[k]) * CType(gate_in[current_in ^ 1][j].data.elt[k]); after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) * - gelu(gelu_in[current_in ^ 1][j].data.elt[k], {}); + OP2(gelu_in[current_in ^ 1][j].data.elt[k], {}); } } OVec out_trans_0[nvec_in]; // NOLINT(*) @@ -1138,7 +1148,11 @@ dgeglu_cast_transpose_kernel(const IType * const input, } } -template +template __global__ void __launch_bounds__(cast_transpose_num_threads) dgeglu_cast_transpose_kernel_notaligned(const IType * const input, @@ -1265,11 +1279,11 @@ dgeglu_cast_transpose_kernel_notaligned(const IType * const input, for (unsigned int j = 0; j < nvec_out; ++j) { #pragma unroll for (unsigned int k = 0; k < nvec_in; ++k) { - after_dgelu[j].data.elt[k] = dgelu(gelu_in[current_in ^ 1][j].data.elt[k], {}) * + after_dgelu[j].data.elt[k] = OP1(gelu_in[current_in ^ 1][j].data.elt[k], {}) * CType(in[current_in ^ 1][j].data.elt[k]) * CType(gate_in[current_in ^ 1][j].data.elt[k]); after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) * - gelu(gelu_in[current_in ^ 1][j].data.elt[k], {}); + OP2(gelu_in[current_in ^ 1][j].data.elt[k], {}); } } OVec out_trans_0[nvec_in]; // NOLINT(*) @@ -1343,6 +1357,8 @@ dgeglu_cast_transpose_kernel_notaligned(const IType * const input, } } +template void cast_transpose_dbias_dgelu(const Tensor &input, const Tensor &gelu_input, Tensor *cast_output, @@ -1407,7 +1423,7 @@ void cast_transpose_dbias_dgelu(const Tensor &input, const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && num_rows % (nvec_out * THREADS_PER_WARP) == 0; - using ComputeType = fp32; + // using ComputeType = fp32; constexpr size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile * 
(THREADS_PER_WARP + 1) * sizeof(Vec); @@ -1423,24 +1439,32 @@ void cast_transpose_dbias_dgelu(const Tensor &input, param.scale_ptr = reinterpret_cast(cast_output->scale.dptr); param.amax = reinterpret_cast(cast_output->amax.dptr); param.workspace = reinterpret_cast(workspace->data.dptr); + if (full_tile) { - cudaFuncSetAttribute(cast_transpose_dbias_dgelu_kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, - 100); - cast_transpose_dbias_dgelu_kernel - <<>>(param, row_length, num_rows, n_tiles); + cudaFuncSetAttribute( + cast_transpose_dbias_dgelu_kernel, + cudaFuncAttributePreferredSharedMemoryCarveout, + 100); + cast_transpose_dbias_dgelu_kernel + <<>>(param, row_length, num_rows, n_tiles); } else { - cudaFuncSetAttribute(cast_transpose_dbias_dgelu_kernel_notaligned, + cudaFuncSetAttribute(cast_transpose_dbias_dgelu_kernel_notaligned< + ComputeType, Empty, + nvec_in, nvec_out, Param, OP>, cudaFuncAttributePreferredSharedMemoryCarveout, 100); - cast_transpose_dbias_dgelu_kernel_notaligned - <<>>(param, row_length, num_rows, n_tiles); + cast_transpose_dbias_dgelu_kernel_notaligned< + ComputeType, Empty, + nvec_in, nvec_out, Param, OP> + <<>>(param, row_length, num_rows, n_tiles); } reduce_dbias(*workspace, dbias, row_length, num_rows, nvec_out, stream); @@ -1448,6 +1472,9 @@ void cast_transpose_dbias_dgelu(const Tensor &input, ); // NOLINT(*) } +template void dgeglu_cast_transpose(const Tensor &input, const Tensor &geglu_input, Tensor *cast_output, @@ -1505,11 +1532,14 @@ void dgeglu_cast_transpose(const Tensor &input, const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && num_rows % (nvec_out * THREADS_PER_WARP) == 0; if (full_tile) { - cudaFuncSetAttribute(dgeglu_cast_transpose_kernel, + cudaFuncSetAttribute(dgeglu_cast_transpose_kernel< + nvec_in, nvec_out, + ComputeType, InputType, OutputType, + Empty, OP1, OP2>, cudaFuncAttributePreferredSharedMemoryCarveout, 100); - dgeglu_cast_transpose_kernel + dgeglu_cast_transpose_kernel< nvec_in, nvec_out, + ComputeType, InputType, OutputType, Empty, OP1, OP2> <<(cast_output->scale_inv.dptr), row_length, num_rows, n_tiles); } else { - cudaFuncSetAttribute(dgeglu_cast_transpose_kernel_notaligned, + cudaFuncSetAttribute(dgeglu_cast_transpose_kernel_notaligned< + nvec_in, nvec_out, + ComputeType, InputType, OutputType, + Empty, OP1, OP2>, cudaFuncAttributePreferredSharedMemoryCarveout, 100); - dgeglu_cast_transpose_kernel_notaligned + dgeglu_cast_transpose_kernel_notaligned <<(input), + cast_transpose_dbias_dgelu>( + *reinterpret_cast(input), *reinterpret_cast(gelu_input), reinterpret_cast(cast_output), reinterpret_cast(transposed_output), @@ -1590,9 +1624,44 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input, cudaStream_t stream) { NVTE_API_CALL(nvte_dgeglu_cast_transpose); using namespace transformer_engine; - dgeglu_cast_transpose(*reinterpret_cast(input), + dgeglu_cast_transpose, gelu>( + *reinterpret_cast(input), *reinterpret_cast(geglu_input), reinterpret_cast(cast_output), reinterpret_cast(transposed_output), stream); } + +void nvte_cast_transpose_dbias_dswish(const NVTETensor input, + const NVTETensor swish_input, + NVTETensor cast_output, + NVTETensor transposed_output, + NVTETensor dbias, + NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_dbias_dswish); + using namespace transformer_engine; + cast_transpose_dbias_dgelu>( + *reinterpret_cast(input), + *reinterpret_cast(swish_input), + reinterpret_cast(cast_output), + reinterpret_cast(transposed_output), + 
reinterpret_cast(dbias), + reinterpret_cast(workspace), + stream); +} + +void nvte_dswiglu_cast_transpose(const NVTETensor input, + const NVTETensor swiglu_input, + NVTETensor cast_output, + NVTETensor transposed_output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_dswiglu_cast_transpose); + using namespace transformer_engine; + dgeglu_cast_transpose, swish>( + *reinterpret_cast(input), + *reinterpret_cast(swiglu_input), + reinterpret_cast(cast_output), + reinterpret_cast(transposed_output), + stream); +} diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index adcd5770e2..87c5e5fe29 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -4135,7 +4135,7 @@ def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtyp updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - wkspace_info, = transformer_engine_jax.get_dgelu_dbias_ct_workspace_sizes( + wkspace_info, = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes( x_aval.size // gi_hidden_size, gi_hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), @@ -4881,3 +4881,1148 @@ def dgated_gelu_cast_transpose( scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary) + +# Primitives for SwiGLU and SiLU +class SiluPrimitive(BasePrimitive): + """ + Silu Froward Primitive + """ + name = "te_silu" + multiple_results = False + inner_primitive = None + outer_primitive = None + impl_static_args = () + + @staticmethod + def abstract(x_aval): + """ + gated_silu abstract + """ + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + + out_aval = core.raise_to_shaped(x_aval) + return out_aval + + @staticmethod + def lowering(ctx, x): + """ + gated_silu lowering rules + """ + (x_aval,) = ctx.avals_in + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + out_shape = ir_x_shape + + out_types = [ + ir.RankedTensorType.get(out_shape, ir_x_type.element_type), + ] + operands = [x] + operand_shapes = [ir_x_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + hidden_size = ir_x_shape[-1] + batch_size = reduce(operator.mul, ir_x_shape[:-1]) + in_dtype = jax_dtype_to_te_dtype(x_aval.dtype) + opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype, + in_dtype) + + out = custom_caller(SiluPrimitive.name, args, opaque, False) + + return [out] + + @staticmethod + def impl(x): + assert SiluPrimitive.inner_primitive is not None + out = SiluPrimitive.inner_primitive.bind(x) + return out + + @staticmethod + def batcher(batched_args, batch_dims): + """ + gated_silu batcher + """ + _check_valid_batch_dims(batch_dims) + assert SiluPrimitive.outer_primitive is not None + inputs, = batched_args + inputs_bdim, = batch_dims + + out_bdims = inputs_bdim + return SiluPrimitive.outer_primitive.bind(inputs), out_bdims + + @staticmethod + def infer_sharding_from_operands(mesh, arg_infos, result_infos): + """ + gated_silu infer_sharding_from_operands + """ + del result_infos # Unused. 
+ x_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + return out_sharding + + @staticmethod + def partition(mesh, arg_infos, result_infos): + """ + gated_silu partitioning + """ + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + impl = SiluPrimitive.impl + return mesh, impl, out_sharding, arg_shardings + + +register_primitive(SiluPrimitive) + + +def silu(inputs: jnp.ndarray) -> jnp.ndarray: + """ + silu wrapper + Return geglu(inputs) + Assume inputs has two dimensions shape and the memory layout is (N..., H) + """ + return SiluPrimitive.outer_primitive.bind(inputs) + + +class DSiluPrimitive(BasePrimitive): + """ + Dgated Silu Primitive + """ + name = "te_dsilu" + multiple_results = False + inner_primitive = None + outer_primitive = None + impl_static_args = () + + @staticmethod + def abstract(dz_aval, x_aval): + """ + dsilu abstract + """ + dtype = dtypes.canonicalize_dtype(dz_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dtype + assert dz_aval.shape == x_aval.shape + + out_aval = core.raise_to_shaped(x_aval) + return out_aval + + @staticmethod + def lowering(ctx, dz, x): + """ + dsilu lowering rules + """ + in_aval, gi_aval = ctx.avals_in + assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert gi_aval.dtype == in_aval.dtype + ir_in_type = ir.RankedTensorType(dz.type) + ir_in_shape = ir_in_type.shape + gi_type = ir.RankedTensorType(x.type) + gi_shape = gi_type.shape + assert ir_in_shape == gi_shape + + ir_batch_size = reduce(operator.mul, ir_in_shape[:-1]) + i_hidden_size = ir_in_shape[-1] + out_dtype = ir_in_type.element_type + out_shape = gi_shape + + out_types = [ + ir.RankedTensorType.get(out_shape, out_dtype), + ] + operands = [dz, x] + operand_shapes = [ir_in_shape, gi_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) + opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size), + in_dtype, in_dtype) + + out = custom_caller(DSiluPrimitive.name, args, opaque, False) + + return [out] + + @staticmethod + def impl(dz, x): + """ + dsilu implementation + """ + assert DSiluPrimitive.inner_primitive is not None + dx = DSiluPrimitive.inner_primitive.bind(dz, x) + return dx + + @staticmethod + def batcher(batched_args, batch_dims): + """ + dsilu batcher + """ + _check_valid_batch_dims(batch_dims) + assert DSiluPrimitive.outer_primitive is not None + dz, x = batched_args + _, x_bdim = batch_dims + + out_bdims = x_bdim + return DSiluPrimitive.outer_primitive.bind(dz, x), out_bdims + + @staticmethod + def infer_sharding_from_operands(mesh, arg_infos, result_infos): + """ + dsilu infer_sharding_from_operands + """ + del result_infos # Unused. 
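A small usage sketch for the silu wrapper defined above, assuming a CUDA-enabled Transformer Engine JAX build with the te_silu primitive registered:

import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions import silu

x = jnp.ones((8, 128), dtype=jnp.bfloat16)   # layout (N..., H)
y = silu(x)                                  # elementwise SiLU, same shape and dtype as x
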
+ silu_out_spec = get_padded_spec(arg_infos[1]) + dx_sharding = NamedSharding(mesh, PartitionSpec(*silu_out_spec)) + return dx_sharding + + @staticmethod + def partition(mesh, arg_infos, result_infos): + """ + dsilu partition + """ + del result_infos + dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = dx_sharding + impl = DSiluPrimitive.impl + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(DSiluPrimitive) + + +def dsilu(inputs: jnp.ndarray, silu_inputs: jnp.ndarray) -> jnp.ndarray: + """ + dsilu fusion wrapper + Return dgeglu(inputs) + """ + return DSiluPrimitive.outer_primitive.bind(inputs, silu_inputs) + + +class GatedSiluPrimitive(BasePrimitive): + """ + Gated Silu Froward Primitive + """ + name = "te_gated_silu" + multiple_results = False + inner_primitive = None + outer_primitive = None + impl_static_args = () + + @staticmethod + def abstract(x_aval): + """ + gated_silu abstract + """ + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + x_shape = x_aval.shape + assert x_shape[-2] == 2 # Assume x in (....., 2, hidden) + hidden_size = x_shape[-1] + batch_shapes = x_shape[:-2] + x_shape = x_aval.shape + out_aval = core.raise_to_shaped(x_aval) + out_shape = (batch_shapes) + (hidden_size,) + out_aval = out_aval.update(shape=out_shape, dtype=dtype) + + return out_aval + + @staticmethod + def lowering(ctx, x): + """ + gated_silu lowering rules + """ + (x_aval,) = ctx.avals_in + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]] + + out_types = [ + ir.RankedTensorType.get(out_shape, ir_x_type.element_type), + ] + operands = [x] + operand_shapes = [ir_x_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + hidden_size = ir_x_shape[-1] + batch_size = reduce(operator.mul, ir_x_shape[:-2]) + in_dtype = jax_dtype_to_te_dtype(x_aval.dtype) + opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype, + in_dtype) + + out = custom_caller(GatedSiluPrimitive.name, args, opaque, False) + + return [out] + + @staticmethod + def impl(x): + assert GatedSiluPrimitive.inner_primitive is not None + out = GatedSiluPrimitive.inner_primitive.bind(x) + return out + + @staticmethod + def batcher(batched_args, batch_dims): + """ + gated_silu batcher + """ + _check_valid_batch_dims(batch_dims) + assert GatedSiluPrimitive.outer_primitive is not None + inputs, = batched_args + inputs_bdim, = batch_dims + + out_bdims = inputs_bdim + return GatedSiluPrimitive.outer_primitive.bind(inputs), out_bdims + + @staticmethod + def infer_sharding_from_operands(mesh, arg_infos, result_infos): + """ + gated_silu infer_sharding_from_operands + """ + del result_infos # Unused. 
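The dsilu wrapper defined above is the backward of silu, so under the same assumptions it should agree with the VJP of jax.nn.silu; a quick consistency sketch:

import jax
import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions import dsilu

x = jnp.linspace(-3.0, 3.0, 16, dtype=jnp.float32).reshape(4, 4)
dz = jnp.ones_like(x)
_, silu_vjp = jax.vjp(jax.nn.silu, x)
ref, = silu_vjp(dz)
out = dsilu(dz, x)   # expected to match ref up to floating-point tolerance
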
+ x_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) + return out_sharding + + @staticmethod + def partition(mesh, arg_infos, result_infos): + """ + gated_silu partitioning + """ + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) + impl = GatedSiluPrimitive.impl + return mesh, impl, out_sharding, arg_shardings + + +register_primitive(GatedSiluPrimitive) + + +def gated_silu(inputs: jnp.ndarray) -> jnp.ndarray: + """ + gated silu wrapper + Return FP8(geglu(inputs)) + Assume inputs has two dimensions shape and the memory layout is (N, 2, H) + """ + return GatedSiluPrimitive.outer_primitive.bind(inputs) + + +class DgatedSiluPrimitive(BasePrimitive): + """ + Dgated Silu Primitive + """ + name = "te_dgated_silu" + multiple_results = False + inner_primitive = None + outer_primitive = None + impl_static_args = () + + @staticmethod + def abstract(dz_aval, x_aval): + """ + dgated_silu abstract + """ + dtype = dtypes.canonicalize_dtype(dz_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dtype + for axis in range(len(dz_aval.shape) - 1): + assert dz_aval.shape[axis] == x_aval.shape[axis] + + assert x_aval.shape[-2] == 2 # Assume x in (....., 2, hidden) + + i_hidden_size = dz_aval.shape[-1] + g_hidden_size = x_aval.shape[-1] + assert i_hidden_size == g_hidden_size + out_aval = core.raise_to_shaped(x_aval) + return out_aval + + @staticmethod + def lowering(ctx, dz, x): + """ + dgated_silu lowering rules + """ + in_aval, gi_aval = ctx.avals_in + assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert gi_aval.dtype == in_aval.dtype + ir_in_type = ir.RankedTensorType(dz.type) + ir_in_shape = ir_in_type.shape + gi_type = ir.RankedTensorType(x.type) + gi_shape = gi_type.shape + for axis in range(len(ir_in_shape) - 1): + assert ir_in_shape[axis] == gi_shape[axis] + + ir_batch_size = reduce(operator.mul, ir_in_shape[:-1]) + i_hidden_size = ir_in_shape[-1] + g_hidden_size = gi_shape[-1] + assert i_hidden_size == g_hidden_size + out_dtype = ir_in_type.element_type + out_shape = gi_shape + + out_types = [ + ir.RankedTensorType.get(out_shape, out_dtype), + ] + operands = [dz, x] + operand_shapes = [ir_in_shape, gi_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) + opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size), + in_dtype, in_dtype) + + out = custom_caller(DgatedSiluPrimitive.name, args, opaque, False) + + return [out] + + @staticmethod + def impl(dz, x): + """ + dgated_silu implementation + """ + assert DgatedSiluPrimitive.inner_primitive is not None + dx = DgatedSiluPrimitive.inner_primitive.bind(dz, x) + return dx + + @staticmethod + def batcher(batched_args, batch_dims): + """ + dgated_silu batcher + """ + _check_valid_batch_dims(batch_dims) + assert DgatedSiluPrimitive.outer_primitive is not None + dz, x = batched_args + _, x_bdim = batch_dims + + out_bdims = x_bdim + return DgatedSiluPrimitive.outer_primitive.bind(dz, x), out_bdims + + @staticmethod + def infer_sharding_from_operands(mesh, arg_infos, result_infos): + """ + dgated_silu infer_sharding_from_operands + """ + del result_infos # Unused. 
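A usage sketch for the gated_silu wrapper defined above; the size-2 axis packs the activation and gate branches, and the reference below assumes slice [..., 0, :] feeds SiLU while [..., 1, :] is the linear gate, mirroring the existing gated GELU path:

import jax
import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions import gated_silu

x = jnp.ones((8, 2, 128), dtype=jnp.bfloat16)   # (N..., 2, H)
y = gated_silu(x)                               # -> (8, 128)
# Reference under the assumed slice convention:
ref = jax.nn.silu(x[..., 0, :]) * x[..., 1, :]
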
+ silu_out_spec = get_padded_spec(arg_infos[1]) + dx_sharding = NamedSharding(mesh, PartitionSpec(*silu_out_spec)) + return dx_sharding + + @staticmethod + def partition(mesh, arg_infos, result_infos): + """ + dgated_silu partition + """ + del result_infos + dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = dx_sharding + impl = DgatedSiluPrimitive.impl + return mesh, impl, out_shardings, arg_shardings + + +register_primitive(DgatedSiluPrimitive) + + +def dgated_silu(inputs: jnp.ndarray, silu_inputs: jnp.ndarray) -> jnp.ndarray: + """ + dgated_silu fusion wrapper + Return dgeglu(inputs) + """ + return DgatedSiluPrimitive.outer_primitive.bind(inputs, silu_inputs) + + +class SiluFp8Primitive(BasePrimitive): + """ + Silu FP8 Primitive + """ + name = "te_silu_fp8" + multiple_results = True + impl_static_args = (4,) #out_dtype + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype): + """ + te_silu_p abstract + """ + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + # Currently only support casting to E4M3 only in C side. + assert out_dtype == jnp.float8_e4m3fn + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + + out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) + updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + + return out_aval, updated_amax_aval + + @staticmethod + def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): + """ + te_gated_silu_p lowering rules + """ + x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + + hidden_size = ir_x_shape[-1] + batch_size = reduce(operator.mul, ir_x_shape[:-1]) + out_shape = ir_x_shape + out_types = [ + ir.RankedTensorType.get(out_shape, ir_out_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = [x, amax, scale, scale_inv] + operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(out_dtype)) + + out = custom_caller(SiluFp8Primitive.name, + args, + opaque, + False, + operand_output_aliases={1: 1}) + + return out + + @staticmethod + def impl(x, amax, scale, scale_inv, out_dtype): + """ + to describe implementation + """ + assert SiluFp8Primitive.inner_primitive is not None + out, updated_amax = SiluFp8Primitive.inner_primitive.bind(x, + amax, + scale, + scale_inv, + out_dtype=out_dtype) + return out, updated_amax + + @staticmethod + def batcher(batched_args, batch_dims, *, out_dtype): + """ + to describe batch rules for vmap + """ + _check_valid_batch_dims(batch_dims) + assert 
SiluFp8Primitive.outer_primitive is not None + x, amax, scale, scale_inv = batched_args + x_bdim, amax_bdim, _, _ = batch_dims + + out_bdims = x_bdim, amax_bdim + return SiluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, + out_dtype=out_dtype), out_bdims + + @staticmethod + def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos): + del out_dtype, result_infos + x_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + return (out_sharding, amax_sharding) + + @staticmethod + def partition(out_dtype, mesh, arg_infos, result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (out_sharding, amax_sharding) + + def sharded_impl(x, amax, scale, scale_inv): + local_x, local_amax = SiluFp8Primitive.impl(x, + amax, + scale, + scale_inv, + out_dtype=out_dtype) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + + return local_x, global_updated_amax + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(SiluFp8Primitive) + + +def silu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, + out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + gated silu wrapper + Return FP8(geglu(x)) + """ + return SiluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype) + + +class DSiluDBiasCastTransposePrimitive(BasePrimitive): + """ + DSilu DBias Cast Transpose Primitive + """ + name = "te_dsilu_dbias_cast_transpose" + multiple_results = True + # out_dtype, static_axis_boundary, transpose_axis_boundary + impl_static_args = (5, 6, 7) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, + static_axis_boundary, transpose_axis_boundary): + """ + te_dsilu_dbais_cast_transpose_p abstract + """ + dtype = dtypes.canonicalize_dtype(dz_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dtype + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + ir_hidden_szie = dz_aval.shape[-1] + gi_hidden_size = x_aval.shape[-1] + assert ir_hidden_szie == gi_hidden_size + t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary) + out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) + t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) + + dbias_shape = (*x_aval.shape[:static_axis_boundary + 1], gi_hidden_size) + dbias = dz_aval.update(shape=dbias_shape, dtype=dtype) + + updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + + wkspace_info, = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes( + x_aval.size // gi_hidden_size, + gi_hidden_size, + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + ) + wkspace_aval = x_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + + return out, t_out, dbias, updated_amax_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + te_dsilu_dbais_cast_transpose_p outer abstract + """ + + out, t_out, 
dbias, updated_amax_aval, _ = \ + DSiluDBiasCastTransposePrimitive.abstract(*args, **kwargs) + return out, t_out, dbias, updated_amax_aval + + @staticmethod + def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, + transpose_axis_boundary): + """ + te_dgated_silu_cast_transpose_p lowering rules + """ + dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dz_aval.dtype + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + ir_dz_type = ir.RankedTensorType(dz.type) + ir_dz_shape = ir_dz_type.shape + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + assert ir_dz_shape == x_shape + + batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) + ir_hidden_szie = ir_dz_shape[-1] + contracted_x_shape = (batch_szie, ir_hidden_szie) + + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary, + transpose_axis_boundary) + dbias_shape = (*x_shape[:static_axis_boundary + 1], ir_hidden_szie) + + wkspace_aval = ctx.avals_out[-1] + + out_types = [ + ir.RankedTensorType.get(x_shape, ir_out_dtype), + ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), + ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), + ] + operands = [dz, x, amax, scale, scale_inv] + operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + opaque = transformer_engine_jax.pack_common_wk_descriptor( + contracted_x_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype)) + + out = custom_caller(DSiluDBiasCastTransposePrimitive.name, + args, + opaque, + False, + operand_output_aliases={2: 3}) + + return out + + @staticmethod + def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, + transpose_axis_boundary): + """ + to describe implementation + """ + assert DSiluDBiasCastTransposePrimitive.inner_primitive is not None + out, t_out, dbias, updated_amax, _ = DSiluDBiasCastTransposePrimitive.inner_primitive.bind( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) + return out, t_out, dbias, updated_amax + + @staticmethod + def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, + transpose_axis_boundary): + """ + to describe batch rules for vmap + """ + del static_axis_boundary + _check_valid_batch_dims(batch_dims) + assert DSiluDBiasCastTransposePrimitive.outer_primitive is not None + dz, x, amax, scale, scale_inv = batched_args + x_bdim, _, amax_bdim, _, _ = batch_dims + + # Minus batch dim. 
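A usage sketch for the silu_fp8 wrapper defined earlier in this file; it assumes scalar-shaped (1,) float32 amax/scale/scale_inv tensors as used for TE's FP8 metadata, and out_dtype must currently be jnp.float8_e4m3fn per the abstract rule:

import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions import silu_fp8

x = jnp.ones((8, 128), dtype=jnp.bfloat16)
amax = jnp.zeros((1,), dtype=jnp.float32)
scale = jnp.ones((1,), dtype=jnp.float32)
scale_inv = jnp.ones((1,), dtype=jnp.float32)
y_fp8, updated_amax = silu_fp8(x, amax, scale, scale_inv, out_dtype=jnp.float8_e4m3fn)
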
+ transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) + transpose_axis_boundary += 1 # Plus batch dim + + out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim + return DSiluDBiasCastTransposePrimitive.outer_primitive.bind( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=x_bdim, + transpose_axis_boundary=transpose_axis_boundary), out_bdims + + @staticmethod + def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, + arg_infos, result_infos): + del out_dtype, result_infos + x_spec = get_padded_spec(arg_infos[1]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + dbias_shaprding = NamedSharding( + mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) + + @staticmethod + def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, + result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[1]) + casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + + dbias_shaprding = NamedSharding( + mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) + + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_shaprding, + amax_sharding) + + def sharded_impl(dz, x, amax, scale, scale_inv): + local_out, local_t_out, local_dbias, local_amax = DSiluDBiasCastTransposePrimitive.impl( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + return local_out, local_t_out, global_dbias, global_updated_amax + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(DSiluDBiasCastTransposePrimitive) + + +def dsilu_dbias_cast_transpose( + dz: jnp.ndarray, + x: jnp.ndarray, + amax: jnp.ndarray, + scale: jnp.ndarray, + scale_inv: jnp.ndarray, + out_dtype: TEDType, + static_axis_boundary: int, + transpose_axis_boundary: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + cast transpose dsilu and dbias fusion wrapper + Return FP8(dgeglu(inputs)), dbias + """ + if static_axis_boundary < 0: + static_axis_boundary = -1 # means no static axes + + return DSiluDBiasCastTransposePrimitive.outer_primitive.bind( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) + + +class GatedSiluFp8Primitive(BasePrimitive): + """ + Gated Silu FP8 Primitive + """ + name = "te_gated_silu_fp8" + multiple_results = True + impl_static_args = (4,) #out_dtype + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype): + """ + te_gated_silu_p 
abstract + """ + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + # Currently only support casting to E4M3 only in C side. + assert out_dtype == jnp.float8_e4m3fn + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + + assert x_aval.shape[-2] == 2 # Assume x in (....., 2, hidden) + hidden_size = x_aval.shape[-1] + batch_shape = x_aval.shape[:-2] + out_shape = (batch_shape) + (hidden_size,) + out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) + updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + + return out_aval, updated_amax_aval + + @staticmethod + def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): + """ + te_gated_silu_p lowering rules + """ + x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + + hidden_size = ir_x_shape[-1] + batch_shape = ir_x_shape[:-2] + batch_size = reduce(operator.mul, batch_shape) + out_shape = batch_shape + [hidden_size] + out_types = [ + ir.RankedTensorType.get(out_shape, ir_out_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = [x, amax, scale, scale_inv] + operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_common_descriptor((batch_size, out_shape[-1]), + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(out_dtype)) + + out = custom_caller(GatedSiluFp8Primitive.name, + args, + opaque, + False, + operand_output_aliases={1: 1}) + + return out + + @staticmethod + def impl(x, amax, scale, scale_inv, out_dtype): + """ + to describe implementation + """ + assert GatedSiluFp8Primitive.inner_primitive is not None + out, updated_amax = GatedSiluFp8Primitive.inner_primitive.bind(x, + amax, + scale, + scale_inv, + out_dtype=out_dtype) + return out, updated_amax + + @staticmethod + def batcher(batched_args, batch_dims, *, out_dtype): + """ + to describe batch rules for vmap + """ + _check_valid_batch_dims(batch_dims) + assert GatedSiluFp8Primitive.outer_primitive is not None + x, amax, scale, scale_inv = batched_args + x_bdim, amax_bdim, _, _ = batch_dims + + out_bdims = x_bdim, amax_bdim + return GatedSiluFp8Primitive.outer_primitive.bind(x, + amax, + scale, + scale_inv, + out_dtype=out_dtype), out_bdims + + @staticmethod + def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos): + del out_dtype, result_infos + x_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + return (out_sharding, amax_sharding) + + @staticmethod + def partition(out_dtype, mesh, arg_infos, result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) + amax_sharding = 
NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (out_sharding, amax_sharding) + + def sharded_impl(x, amax, scale, scale_inv): + local_x, local_amax = GatedSiluFp8Primitive.impl(x, + amax, + scale, + scale_inv, + out_dtype=out_dtype) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + + return local_x, global_updated_amax + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(GatedSiluFp8Primitive) + + +def gated_silu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, + out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + gated silu wrapper + Return FP8(geglu(x)) + """ + return GatedSiluFp8Primitive.outer_primitive.bind(x, + amax, + scale, + scale_inv, + out_dtype=out_dtype) + + +class DgatedSiluCastTransposePrimitive(BasePrimitive): + """ + Dgated Silu Cast Transpose Primitive + """ + name = "te_dgated_silu_cast_transpose" + multiple_results = True + impl_static_args = (5, 6) # out_dtype, static_axis_boundary + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, + static_axis_boundary): + """ + te_dgated_silu_cast_transpose_p abstract + """ + dtype = dtypes.canonicalize_dtype(dz_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dtype + assert x_aval.shape[-2] == 2 # Linear + GeLU + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + ir_hidden_szie = dz_aval.shape[-1] + gi_hidden_size = x_aval.shape[-1] + assert ir_hidden_szie == gi_hidden_size + t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2) + out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) + t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) + updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + return out, t_out, updated_amax_aval + + @staticmethod + def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary): + """ + te_dgated_silu_cast_transpose_p lowering rules + """ + dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dz_aval.dtype + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + ir_dz_type = ir.RankedTensorType(dz.type) + ir_dz_shape = ir_dz_type.shape + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) + x_batch_size = reduce(operator.mul, x_shape[:-2]) + assert dz_batch_szie == x_batch_size + assert x_shape[-2] == 2 # Linear + GeLU + ir_hidden_szie = ir_dz_shape[-1] + gi_hidden_size = x_shape[-1] + assert ir_hidden_szie == gi_hidden_size + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary, -2) + out_types = [ + ir.RankedTensorType.get(x_shape, ir_out_dtype), + ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = [dz, 
x, amax, scale, scale_inv] + operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + contracted_x_shape = (x_batch_size, x_shape[-1]) + opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape, + jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype)) + + out = custom_caller(DgatedSiluCastTransposePrimitive.name, + args, + opaque, + False, + operand_output_aliases={2: 2}) + + return out + + @staticmethod + def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary): + """ + to describe implementation + """ + assert DgatedSiluCastTransposePrimitive.inner_primitive is not None + out, t_out, updated_amax = DgatedSiluCastTransposePrimitive.inner_primitive.bind( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary) + return out, t_out, updated_amax + + @staticmethod + def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary): + """ + to describe batch rules for vmap + """ + del static_axis_boundary + _check_valid_batch_dims(batch_dims) + assert DgatedSiluCastTransposePrimitive.outer_primitive is not None + dz, x, amax, scale, scale_inv = batched_args + x_bdim, _, amax_bdim, _, _ = batch_dims + + out_bdims = x_bdim, x_bdim, amax_bdim + return DgatedSiluCastTransposePrimitive.outer_primitive.bind( + dz, x, amax, scale, scale_inv, out_dtype=out_dtype, + static_axis_boundary=x_bdim), out_bdims + + @staticmethod + def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_infos, + result_infos): + del out_dtype, result_infos + x_spec = get_padded_spec(arg_infos[1]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2) + tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + return (out_sharding, tranposed_out_sharding, amax_sharding) + + @staticmethod + def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[1]) + casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2) + casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding) + + def sharded_impl(dz, x, amax, scale, scale_inv): + local_out, local_t_out, local_amax = DgatedSiluCastTransposePrimitive.impl( + dz, + x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + return local_out, local_t_out, global_updated_amax + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(DgatedSiluCastTransposePrimitive) + + +def dgated_silu_cast_transpose( + dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, + scale_inv: jnp.ndarray, out_dtype: TEDType, + static_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + cast transpose d_gated_silu fusion wrapper + Return FP8(dgeglu(inputs)) + """ + return DgatedSiluCastTransposePrimitive.outer_primitive.bind( + dz, + x, + 
amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary) diff --git a/transformer_engine/jax/csrc/extensions.cpp b/transformer_engine/jax/csrc/extensions.cpp index 8aa6b492c8..7d3958879a 100644 --- a/transformer_engine/jax/csrc/extensions.cpp +++ b/transformer_engine/jax/csrc/extensions.cpp @@ -34,6 +34,16 @@ pybind11::dict Registrations() { dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8); dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu); dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose); + // TODO + dict["te_silu"] = EncapsulateFunction(Silu); + dict["te_silu_fp8"] = EncapsulateFunction(SiluFP8); + dict["te_dsilu"] = EncapsulateFunction(DSilu); + dict["te_dsilu_dbias_cast_transpose"] = EncapsulateFunction(DSiluDBiasCastTranspose); + dict["te_gated_silu"] = EncapsulateFunction(GatedSilu); + dict["te_gated_silu_fp8"] = EncapsulateFunction(GatedSiluFP8); + dict["te_dgated_silu"] = EncapsulateFunction(DGatedSilu); + dict["te_dgated_silu_cast_transpose"] = EncapsulateFunction(DGatedSiluCastTranspose); + // dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward); dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8); dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward); @@ -66,7 +76,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_cublasLt_version", &cublasLtGetVersion); - m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes); + m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes); m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes); m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes); m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes); diff --git a/transformer_engine/jax/csrc/modules.cpp b/transformer_engine/jax/csrc/modules.cpp index 48b02bcaeb..78e9f60e3f 100644 --- a/transformer_engine/jax/csrc/modules.cpp +++ b/transformer_engine/jax/csrc/modules.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include "common/common.h" #include "common/util/logging.h" @@ -234,30 +235,6 @@ void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaqu nvte_dgelu(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), stream); } -pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype) { - auto input_shape = std::vector{batch_size, hidden_size}; - auto gelu_input_shape = std::vector{batch_size, hidden_size}; - auto output_shape = std::vector{batch_size, hidden_size}; - auto output_trans_shape = std::vector{hidden_size, batch_size}; - auto dbias_shape = std::vector{hidden_size}; - - auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto gelu_input_tensor = TensorWrapper(nullptr, gelu_input_shape, in_dtype); - auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); - auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); - auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); - - TensorWrapper dummy_workspace; - - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), gelu_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), dummy_workspace.data(), nullptr); - - 
auto work_shape = MakeShapeVector(dummy_workspace.shape()); - return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); -} - void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; @@ -466,6 +443,241 @@ void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *op output_trans_tensor.data(), stream); } +void SiluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, + cudaStream_t stream, float *scale_inverse, float *amax, void *output) { + auto input_shape = std::vector{m, n}; + auto output_shape = std::vector{m, n}; + + auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); + + auto output_tensor = TensorWrapper(output, output_shape, static_cast(out_dtype), amax, + scale, scale_inverse); + + nvte_swish(input_tensor.data(), output_tensor.data(), stream); +} + +void Silu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + auto *input = buffers[0]; + auto *output = buffers[1]; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + + SiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output); +} + +void SiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + auto *input = buffers[0]; + float *amax = reinterpret_cast(buffers[1]); + float *scale = reinterpret_cast(buffers[2]); + float *scale_inv = reinterpret_cast(buffers[3]); + auto *output = buffers[4]; + float *amax_out = reinterpret_cast(buffers[5]); + assert(amax == amax_out); + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + if (!use_fp8(desc.out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + + SiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, + output); +} + +void DSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + auto *input = buffers[0]; + auto *silu_input = buffers[1]; + auto *output = buffers[2]; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + auto input_shape = std::vector{m, n}; + auto silu_input_shape = std::vector{m, n}; + auto output_shape = std::vector{m, n}; + + auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); + auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype); + auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype); + + nvte_dswish(input_tensor.data(), silu_input_tensor.data(), output_tensor.data(), stream); +} + +pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, + DType in_dtype, DType out_dtype) { + auto input_shape = std::vector{batch_size, hidden_size}; + auto dact_input_shape = std::vector{batch_size, hidden_size}; + auto output_shape = std::vector{batch_size, hidden_size}; + auto output_trans_shape = std::vector{hidden_size, batch_size}; + auto dbias_shape = std::vector{hidden_size}; + + auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype); + auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); + auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); + auto dbias_tensor = 
TensorWrapper(nullptr, dbias_shape, in_dtype); + + TensorWrapper dummy_workspace; + + // For now, all dbias_dact(-s) have the same workspace size + nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), dummy_workspace.data(), nullptr); + + auto work_shape = MakeShapeVector(dummy_workspace.shape()); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); +} + +void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, + size_t opaque_len) { + auto *input = buffers[0]; + auto *silu_input = buffers[1]; + float *amax = reinterpret_cast(buffers[2]); + float *scale = reinterpret_cast(buffers[3]); + float *scale_inv = reinterpret_cast(buffers[4]); + auto *output = buffers[5]; + auto *output_trans = buffers[6]; + auto *dbias = buffers[7]; + float *amax_out = reinterpret_cast(buffers[8]); + void *workspace_ptr = buffers[9]; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + assert(amax == amax_out); + if (!use_fp8(desc.out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + auto input_shape = std::vector{m, n}; + auto silu_input_shape = std::vector{m, n}; + auto output_shape = std::vector{m, n}; + auto output_trans_shape = std::vector{n, m}; + auto dbias_shape = std::vector{n}; + + auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); + auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype); + auto output_tensor = + TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto output_trans_tensor = + TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); + + auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); + + nvte_cast_transpose_dbias_dswish(input_tensor.data(), silu_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); +} + +void GatedSiluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, + cudaStream_t stream, float *scale_inverse, float *amax, void *output) { + auto input_shape = std::vector{m, n * 2}; + auto output_shape = std::vector{m, n}; + + auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); + + auto output_tensor = TensorWrapper(output, output_shape, static_cast(out_dtype), amax, + scale, scale_inverse); + + nvte_swiglu(input_tensor.data(), output_tensor.data(), stream); +} + +void GatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + auto *input = buffers[0]; + auto *output = buffers[1]; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + + GatedSiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, + output); +} + +void GatedSiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + auto *input = buffers[0]; + float *amax = reinterpret_cast(buffers[1]); + float *scale = reinterpret_cast(buffers[2]); + float *scale_inv = reinterpret_cast(buffers[3]); + auto *output = buffers[4]; + float *amax_out = reinterpret_cast(buffers[5]); + assert(amax == amax_out); + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + if 
(!use_fp8(desc.out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + + GatedSiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, + output); +} + +void DGatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + auto *input = buffers[0]; + auto *silu_input = buffers[1]; + auto *output = buffers[2]; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + auto input_shape = std::vector{m, n}; + auto silu_input_shape = std::vector{m, n * 2}; + auto output_shape = std::vector{m, n * 2}; + + auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); + auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype); + auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype); + + nvte_dswiglu(input_tensor.data(), silu_input_tensor.data(), output_tensor.data(), stream); +} + +void DGatedSiluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, + size_t opaque_len) { + auto *input = buffers[0]; + auto *silu_input = buffers[1]; + float *amax = reinterpret_cast(buffers[2]); + float *scale = reinterpret_cast(buffers[3]); + float *scale_inv = reinterpret_cast(buffers[4]); + auto *output = buffers[5]; + auto *output_trans = buffers[6]; + float *amax_out = reinterpret_cast(buffers[7]); + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + assert(amax == amax_out); + if (!use_fp8(desc.out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + auto input_shape = desc.shape.to_vector(); + auto silu_input_shape = std::vector{m, n * 2}; + auto output_shape = std::vector{m, n * 2}; + auto output_trans_shape = std::vector{n * 2, m}; + + auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); + auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype); + auto output_tensor = + TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto output_trans_tensor = + TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + + nvte_dswiglu_cast_transpose(input_tensor.data(), silu_input_tensor.data(), output_tensor.data(), + output_trans_tensor.data(), stream); +} + pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, bool is_layer_norm, bool zero_centered_gamma, diff --git a/transformer_engine/jax/csrc/modules.h b/transformer_engine/jax/csrc/modules.h index 4285c8228e..ac14a54e90 100644 --- a/transformer_engine/jax/csrc/modules.h +++ b/transformer_engine/jax/csrc/modules.h @@ -140,13 +140,14 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +// TODO (Phuong): Templating these 9x2 rountines before adding ReGLU, QuickGeLU, Squared ReLu void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); -pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, +pybind11::tuple 
GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype); void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, @@ -167,6 +168,24 @@ void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +void Silu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); + +void SiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); + +void DSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); + +void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, + size_t opaque_len); + +void GatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); + +void GatedSiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); + +void DGatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); + +void DGatedSiluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, + size_t opaque_len); + pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype, bool is_layer_norm, bool zero_centered_gamma, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 36008cf854..b95689f6b0 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -943,17 +943,18 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: fuse_layernorm = FP8Helper.is_fp8_enabled( ) and not self.return_layernorm_output and self.enable_layernorm - # Make sure each tuple is sorted in alphabet order - gated_act_pool = [('gelu', 'linear')] - #('linear', 'silu')] coming - act_pool = [('gelu',)] - #('silu',)] coming + gated_act_pool = [('gelu', 'linear'), + ('silu', 'linear')] + act_pool = [('gelu',), + ('silu',)] normalize_acts = [] for act in self.activations: if not isinstance(act, str): return False normalize_acts.append(act.lower()) - normalize_acts = tuple(sorted(normalize_acts)) + normalize_acts = tuple(reversed(normalize_acts) + if normalize_acts[0] == 'linear' else normalize_acts) + is_gated = normalize_acts in gated_act_pool is_act_implemented = normalize_acts in (gated_act_pool + act_pool) diff --git a/transformer_engine/jax/mlp.py b/transformer_engine/jax/mlp.py index 30f6d8456b..1900e3f441 100644 --- a/transformer_engine/jax/mlp.py +++ b/transformer_engine/jax/mlp.py @@ -15,9 +15,13 @@ from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose from .cpp_extensions import gated_gelu, gated_gelu_fp8 from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose +from .cpp_extensions import silu, silu_fp8 +from .cpp_extensions import dsilu, dsilu_dbias_cast_transpose +from .cpp_extensions import gated_silu, gated_silu_fp8 +from .cpp_extensions import dgated_silu, dgated_silu_cast_transpose from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd -from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize +from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize, dequantize from .layernorm import canonicalize_layernorm_type from .fp8 import FP8Helper, FP8MetaPackage from .sharding import with_sharding_constraint_by_logical_axes @@ -27,14 +31,22 @@ ('gelu',): {'fwd': gelu, 
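# [Editorial sketch, not part of the patch] Both tables below key on the normalized
# activation tuple produced in flax/module.py above, where a trailing 'linear' marks the
# gated variant (so ('silu', 'linear') selects SwiGLU). Dispatch is a plain dictionary
# lookup, so supporting a new activation only requires registering its primitives here:
#
#     fwd_fn = activation_dict[activation_type]["fwd"]    # e.g. gated_silu
#     bwd_fn = activation_dict[activation_type]["bwd"]    # e.g. dgated_silu
#     out = fwd_fn(x)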
"bwd": dgelu}, ('gelu', 'linear'): {'fwd': gated_gelu, - 'bwd': dgated_gelu} + 'bwd': dgated_gelu}, + ('silu',): {'fwd': silu, + "bwd": dsilu }, + ('silu', 'linear'): {'fwd': gated_silu, + 'bwd': dgated_silu} } activation_fp8_dict = { ('gelu',): {'fwd': gelu_fp8, 'bwd': dgelu_dbias_cast_transpose}, ('gelu', 'linear'): {'fwd': gated_gelu_fp8, - 'bwd': dgated_gelu_cast_transpose} + 'bwd': dgated_gelu_cast_transpose}, + ('silu',): { 'fwd': silu_fp8, + 'bwd': dsilu_dbias_cast_transpose }, + ('silu', 'linear'): { 'fwd': gated_silu_fp8, + 'bwd': dgated_silu_cast_transpose } } @@ -47,7 +59,6 @@ def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable] output = _activation_lu(x, activation_type) return output - @partial(jax.custom_vjp, nondiff_argnums=(1,)) def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]): @@ -55,12 +66,10 @@ def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable return _output - def _activation_lu_fwd_rule(x, activation_type): fwd_output = activation_dict[activation_type]["fwd"](x) return fwd_output, (x,) - def _activation_lu_bwd_rule(activation_type, ctx, g): x, = ctx assert x.dtype == g.dtype @@ -72,6 +81,67 @@ def _activation_lu_bwd_rule(activation_type, ctx, g): _activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule) +def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, + scale_inv: jnp.ndarray, fwd_dtype:jnp.dtype, bwd_dtype: jnp.dtype, + activation_type: Sequence[Union[str, Callable]]): + """ + Activation Unit + """ + transpose_indices = (1, 2, 0) if len(activation_type) > 1 else (2, 0, 1) + dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype) + dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype) + + output = _activation_lu_fp8(x, dx_trans_no_use, dbias_no_use, amax, + scale, scale_inv, fwd_dtype, bwd_dtype, activation_type) + return output + +@partial(jax.custom_vjp, nondiff_argnums=(6,7,8)) +def _activation_lu_fp8(x: jnp.ndarray, + dx_trans_no_use: jnp.ndarray, dbias_no_use: jnp.ndarray, + amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, + fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, + activation_type: Sequence[Union[str, Callable]]): + + output = _activation_lu_fp8_fwd_rule(x, dx_trans_no_use, dbias_no_use, amax, + scale, scale_inv, fwd_dtype, bwd_dtype, + activation_type) + + return output + +def _activation_lu_fp8_fwd_rule(x, + dx_trans_no_use, # pylint: disable=unused-argument + dbias_no_use, # pylint: disable=unused-argument + amax, + scale, scale_inv, + fwd_dtype, bwd_dtype, # pylint: disable=unused-argument + activation_type): + activation_lu_out, _ = activation_fp8_dict[activation_type ]["fwd"]( + x, amax, scale, scale_inv, fwd_dtype) + + activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv) + ctx = (x, amax, scale, scale_inv) + return activation_lu_out, ctx + +def _activation_lu_fp8_bwd_rule(fwd_dtype, bwd_dtype, # pylint: disable=unused-argument + activation_type, ctx, g): + x, amax, scale, scale_inv = ctx + + activation_lu_fp8_bwd = activation_fp8_dict[activation_type]["bwd"] + if len(activation_type) > 1: #gated, no bias + dactivation_lu, dactivation_lu_trans, amax_out = \ + activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1) + dbias = jnp.empty(x.shape[-1], x.dtype) + else: + dactivation_lu, dactivation_lu_trans, dbias, amax_out = \ + activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1) + dactivation_lu = dequantize(dactivation_lu, x.dtype, 
scale_inv) + dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv) + ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv) + return ctx + +_activation_lu_fp8.defvjp(_activation_lu_fp8_fwd_rule, _activation_lu_fp8_bwd_rule) + + def fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, @@ -247,11 +317,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule( activation_lu_out_scale = scale[gemm2_x_idx] activation_lu_out_scale_inv = scale_inv[gemm2_x_idx] - activation_lu_fp8 = activation_fp8_dict[activation_type]["fwd"] + activation_lu_fwd_fp8 = activation_fp8_dict[activation_type]["fwd"] # (batch..., hidden_in) -> (batch..., hidden) - casted_activation_lu_out, updated_activation_lu_amax = activation_lu_fp8(dot_1_output, - activation_lu_out_amax, activation_lu_out_scale, + casted_activation_lu_out, updated_activation_lu_amax = \ + activation_lu_fwd_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale, activation_lu_out_scale_inv, fwd_dtype) casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out, From f339c4282302d2ad40ad57b43063375a04d0e730 Mon Sep 17 00:00:00 2001 From: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> Date: Thu, 25 Apr 2024 18:47:11 -0700 Subject: [PATCH 036/244] Add attention bias and qkv format to context parallelism (#726) * make FusedAttn with CP support bias Signed-off-by: Xiaowei Ren * assert Alibi cannot work with CP Signed-off-by: Xiaowei Ren * syntax fix Signed-off-by: Xiaowei Ren * fix variable name Signed-off-by: Xiaowei Ren * fix tensor shapes Signed-off-by: Xiaowei Ren * a typo fix Signed-off-by: Xiaowei Ren * fix bias indexing for CP Signed-off-by: Xiaowei Ren * bug fix Signed-off-by: Xiaowei Ren * add attn bias tests Signed-off-by: Xiaowei Ren * change dbias update location Signed-off-by: Xiaowei Ren * fix CP test model configs Signed-off-by: Xiaowei Ren * change CP test sequence length Signed-off-by: Xiaowei Ren * make AttnFuncWithCP support qkv format of sbhd Signed-off-by: Xiaowei Ren * make sure qkv are contiguous for CP in cuDNN fused attn Signed-off-by: Xiaowei Ren * change assert message Signed-off-by: Xiaowei Ren * fix code format Signed-off-by: Xiaowei Ren --------- Signed-off-by: Xiaowei Ren Co-authored-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Pawel Gadzinski --- .../fused_attn/run_fused_attn_with_cp.py | 34 +- .../fused_attn/test_fused_attn_with_cp.py | 28 +- transformer_engine/pytorch/attention.py | 423 +++++++++++++----- 3 files changed, 367 insertions(+), 118 deletions(-) diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 87a0b2cd60..1af8391bce 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -6,7 +6,7 @@ import torch import torch.distributed as dist from transformer_engine.pytorch.attention import DotProductAttention -from test_fused_attn_with_cp import model_configs +from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16} @@ -17,8 +17,10 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= os.environ["NVTE_FUSED_ATTN"] = "0" if kernel_backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" + config = model_configs_flash_attn[model] if kernel_backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + config = 
model_configs_fused_attn[model] rank = int(os.getenv('RANK', '0')) world_size = int(os.getenv('WORLD_SIZE', '1')) @@ -40,8 +42,6 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= assert(rank in cp_comm_ranks) cp_comm_group = dist.new_group(cp_comm_ranks, backend='nccl') - config = model_configs[model] - assert config.attn_mask_type in ['causal', 'no_mask'], f"{config.attn_mask_type} is an unsupported attention mask type!" # instantiate core attn module @@ -69,18 +69,30 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda() + # create flash attention bias + if config.attn_bias_type not in ["no_bias", "alibi"]: + attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv) + bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() + else: + bias = None + # make sure all GPU ranks have same inputs - for x in [q, k, v, dout]: + for x in [q, k, v, dout] + ([] if bias is None else [bias]): dist.broadcast(x, 0, group=cp_comm_group) # run core_attn without CP for x in [q, k, v]: x.requires_grad = True - out = core_attn(q, k, v) + out = core_attn( + q, k, v, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias, + ) out.backward(dout) # run core_attn wit CP - q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]] + q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])] + bias_ = rest[0] if len(rest) else None seq_dim = qkv_format.index('s') q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \ for x in [q_, k_, v_, dout_]] @@ -88,8 +100,16 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend= q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]] q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] + if bias_ is not None: + bias_ = bias_.view(*bias_.shape[:-2], 2*world_size, bias_.shape[-2]//(2*world_size), bias_.shape[-1]) + bias_ = bias_.index_select(2, seq_idx) + bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream()) - out_ = core_attn(q_, k_, v_) + out_ = core_attn( + q_, k_, v_, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias_, + ) out_.backward(dout_) for x in [out_, q_.grad, k_.grad, v_.grad]: diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 1e16a5a295..ac571cd0e4 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -11,12 +11,12 @@ _cudnn_version, ) -model_configs = { - # test: b, h, hg, d, sq, skv, p, mask, bias - "cp_1_0": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "causal", "no_bias"), # MHA - "cp_1_1": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "no_mask", "no_bias"), # MHA - "cp_2_0": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "no_mask", "no_bias"), # GQA +model_configs_flash_attn = { + # test: b, h, hg, d, sq, skv, p, mask, bias + "cp_1_0": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # 
MHA + "cp_1_1": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA + "cp_2_0": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA + "cp_2_1": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA } def get_bash_arguments(**kwargs): @@ -30,7 +30,7 @@ def get_bash_arguments(**kwargs): @pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.") @pytest.mark.parametrize("dtype", ['bf16', 'fp16']) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd']) def test_cp_with_flash_attention(dtype, model, qkv_format): subprocess.run( @@ -43,9 +43,21 @@ def test_cp_with_flash_attention(dtype, model, qkv_format): check=True ) +model_configs_fused_attn = { + # test: b, h, hg, d, sq, skv, p, mask, bias + "cp_1_0": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA + "cp_1_1": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA + "cp_1_2": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA + "cp_1_3": ModelConfig(1, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA + "cp_2_0": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA + "cp_2_1": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_2_2": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA + "cp_2_3": ModelConfig(1, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA +} + @pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.") @pytest.mark.parametrize("dtype", ['bf16', 'fp16']) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd']) def test_cp_with_fused_attention(dtype, model, qkv_format): subprocess.run( diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 4bb39b913f..c4f9bd5301 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -490,9 +490,10 @@ def flash_attn_p2p_communicate(rank, send_tensor, send_dst, @jit_fuser -def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_per_step): +def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, + softmax_lse, softmax_lse_per_step): """Merge partial outputs of each step in Attention with context parallelism""" - softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).transpose(1, 2) + softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) out_corrected = out_per_step*softmax_lse_corrected_exp out.add_(out_corrected) @@ -516,22 +517,44 @@ class AttnFuncWithCP(torch.autograd.Function): @staticmethod def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, attn_mask_type, - deterministic, use_fused_attention): + dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format, + attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) rank = 
get_distributed_rank(cp_group) send_dst = cp_global_ranks[(rank + 1) % cp_size] - recv_src = cp_global_ranks[(rank + cp_size - 1) % cp_size] + recv_src = cp_global_ranks[(rank - 1) % cp_size] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) causal = (attn_mask_type == "causal") + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + if causal: - # [b, s, np, hn] -> [b, 2, s//2, np, hn] - q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]] + if qkv_format == "bshd": + # [b, s, np, hn] -> [b, 2, s//2, np, hn] + q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]] + elif qkv_format == "sbhd": + # [s, b, np, hn] -> [2, s//2, b, np, hn] + q, k, v = [x.view(2, x.shape[0]//2, *x.shape[1:]) for x in [q, k, v]] + if attn_bias is not None: + assert (len(attn_bias.shape) == 4), ( + "Only support bias shape of [b, h, sq, sk] for forward, " + "and [1, h, sq, sk] for backward!" + ) + # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + attn_bias_ = attn_bias.view( \ + *attn_bias.shape[:-2], \ + 2, attn_bias.shape[-2]//2, \ + 2*cp_size, attn_bias.shape[-1]//(2*cp_size) \ + ) + # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)] + attn_bias = attn_bias.view( \ + *attn_bias.shape[:-1], \ + 2*cp_size, attn_bias.shape[-1]//(2*cp_size) \ + ) assert(q.shape[-1] % 8 == 0), "hidden size per attention head should be multiple of 8" fa_optional_forward_kwargs = {} if _flash_attn_2_3_plus: @@ -542,10 +565,12 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, # Flash Attn inputs q_inputs = [None, None] kv_inputs = [None, None] + attn_bias_inputs = [None, None] # Flash Attn outputs out_per_step = [None for _ in range(cp_size)] softmax_lse_per_step = [None for _ in range(cp_size)] rng_states = [None for _ in range(cp_size)] + attn_biases = [None for _ in range(cp_size)] # create two streams to resolve wave quantization issue of Flash Attn in each step flash_attn_streams = [torch.cuda.current_stream(), cp_stream] @@ -577,20 +602,37 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, if causal: if i == 0: if use_fused_attention: - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - kv_inputs[i%2] = kv_inputs[i%2].view( - 2, k.shape[0], -1, *k.shape[-2:]) - out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \ + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:]) + # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + kv_inputs[i%2] = kv_inputs[i%2].view( + 2, k.shape[0], -1, *k.shape[-2:]) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_inputs[i%2] = q.view(-1, *q.shape[-3:]) + # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] + kv_inputs[i%2] = kv_inputs[i%2].view( + 2, -1, *k.shape[-3:]) + if attn_bias is not None: + idx = (rank - i) % cp_size + attn_bias_inputs[i%2] = torch.cat( + (attn_bias[..., idx, :], \ + attn_bias[..., (2*cp_size-idx-1), :]), + dim=-1 + ).contiguous() + out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \ fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q, cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1], TE_DType[q.dtype], tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout="bshd_bshd_bshd", 
attn_mask_type="causal", + qkv_layout=qkv_layout, attn_mask_type="causal", + attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2], ) + if len(rest) > 0: + attn_biases[i] = rest[0] else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_inputs[i%2] = q.view(-1, *q.shape[-2:]) @@ -605,19 +647,31 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, ) elif i <= rank: if use_fused_attention: - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous() - out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \ + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:]) + # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] + kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous() + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_inputs[i%2] = q.view(-1, *q.shape[-3:]) + # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn] + kv_inputs[i%2] = kv_inputs[i%2][:, 0, ...].contiguous() + if attn_bias is not None: + idx = (rank - i) % cp_size + attn_bias_inputs[i%2] = attn_bias[..., idx, :].contiguous() + out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \ fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_k//2, cu_seqlens_q, cu_seqlens_k//2, q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1], TE_DType[q.dtype], tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask", + qkv_layout=qkv_layout, attn_mask_type="no_mask", + attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2], ) + if len(rest) > 0: + attn_biases[i] = rest[0] else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_inputs[i%2] = q.view(-1, *q.shape[-2:]) @@ -636,20 +690,37 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, ) else: if use_fused_attention: - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_inputs[i%2] = q[:, 1, ...].contiguous() - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - kv_inputs[i%2] = kv_inputs[i%2].view( - 2, k.shape[0], -1, *k.shape[-2:]) - out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \ + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_inputs[i%2] = q[:, 1, ...].contiguous() + # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + kv_inputs[i%2] = kv_inputs[i%2].view( + 2, k.shape[0], -1, *k.shape[-2:]) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_inputs[i%2] = q[1].contiguous() + # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] + kv_inputs[i%2] = kv_inputs[i%2].view( + 2, -1, *k.shape[-3:]) + if attn_bias is not None: + idx = (rank - i) % cp_size + attn_bias_inputs[i%2] = torch.cat( + (attn_bias_[..., 1, :, idx, :], \ + attn_bias_[..., 1, :, (2*cp_size-idx-1), :]), + dim=-1 + ).contiguous() + out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \ fused_attn_fwd( is_training, max_seqlen_q//2, max_seqlen_k, cu_seqlens_q//2, cu_seqlens_k, q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1], TE_DType[q.dtype], tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask", + qkv_layout=qkv_layout, attn_mask_type="no_mask", + attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2], ) + if 
len(rest) > 0: + attn_biases[i] = rest[0] else: # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) @@ -666,15 +737,24 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, ) else: if use_fused_attention: - out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = \ + if attn_bias is not None: + idx = (rank - i) % cp_size + attn_bias_inputs[i%2] = torch.cat( + (attn_bias[..., idx, :], attn_bias[..., (2*cp_size-idx-1), :]), + dim=-1 + ).contiguous() + out_per_step[i], [softmax_lse_per_step[i], rng_states[i], *rest] = \ fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_k, cu_seqlens_q, cu_seqlens_k, q, kv_inputs[i%2][0], kv_inputs[i%2][1], TE_DType[q.dtype], tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=softmax_scale, dropout=dropout_p, - qkv_layout="bshd_bshd_bshd", attn_mask_type="no_mask", + qkv_layout=qkv_layout, attn_mask_type="no_mask", + attn_bias_type=attn_bias_type, attn_bias=attn_bias_inputs[i%2], ) + if len(rest) > 0: + attn_biases[i] = rest[0] else: # [b, sq, np, hn] -> [b*sq, np, hn] q_inputs[i%2] = q.view(-1, *q.shape[-2:]) @@ -719,23 +799,33 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) softmax_lse = softmax_lse.to(torch.float) + seq_dim = qkv_format.index("s") for i in range(cp_size): - # [b*sq, np, hn] -> [b, sq, np, hn] or [b*sq//2, np, hn] -> [b, sq//2, np, hn] - out_ = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:]) + if qkv_format == "bshd": + out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:]) + out_ = out[:, 1, ...] + elif qkv_format == "sbhd": + out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:]) + out_ = out[1] if i <= rank or not causal: - flash_attn_fwd_out_correction(out.view(*out_.shape), - out_, + flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape), + out_per_step[i], + seq_dim, softmax_lse, softmax_lse_per_step[i]) else: - flash_attn_fwd_out_correction(out[:, 1, ...], - out_, + flash_attn_fwd_out_correction(out_, + out_per_step[i], + seq_dim, softmax_lse_[..., 1, :], softmax_lse_per_step[i]) kv = p2p_comm_buffers[-1] if use_fused_attention: - out = out.view(out.shape[0], -1, *out.shape[-2:]) + if qkv_format == "bshd": + out = out.view(out.shape[0], -1, *out.shape[-2:]) + elif qkv_format == "sbhd": + out = out.view(-1, *out.shape[-3:]) else: out = out.view(-1, *out.shape[-2:]) ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) @@ -747,6 +837,10 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, ctx.max_seqlen_k = max_seqlen_k ctx.softmax_scale = softmax_scale ctx.causal = causal + ctx.qkv_format = qkv_format + ctx.attn_bias_type = attn_bias_type + ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape + ctx.attn_biases = attn_biases ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention return out @@ -757,10 +851,26 @@ def backward(ctx, dout): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - send_dst = ctx.cp_global_ranks[(rank + cp_size - 1) % cp_size] + send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + + if ctx.attn_biases[0] 
is not None: + # [b, np, sq, 2*cp, sk//(2*cp)] + attn_dbias = torch.zeros( + *ctx.attn_bias_shape, + dtype=ctx.attn_biases[0].dtype, + device=ctx.attn_biases[0].device + ) + # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + attn_dbias_ = attn_dbias.view( + *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3]//2, *attn_dbias.shape[-2:] + ) + else: + attn_dbias = None + if ctx.causal: # [b, np, sq] -> [b, np, 2, sq//2] softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2) @@ -814,23 +924,36 @@ def backward(ctx, dout): if ctx.causal: if i == (cp_size-1): if ctx.use_fused_attention: - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_ = q.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - out_ = out.view(out.shape[0], -1, *out.shape[-2:]) - dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) - dq_, dk_, dv_, _ = fused_attn_bwd( + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_ = q.view(q.shape[0], -1, *q.shape[-2:]) + # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + out_ = out.view(out.shape[0], -1, *out.shape[-2:]) + dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_ = q.view(-1, *q.shape[-3:]) + # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + out_ = out.view(-1, *out.shape[-3:]) + dout_ = dout.view(-1, *dout.shape[-3:]) + aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]] + if attn_dbias is not None: + aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_k, cu_seqlens_q, cu_seqlens_k, - q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype], - [softmax_lse, ctx.rng_states[cp_size-i-1]], + q_, kv_[0], kv_[1], out_, dout_, + TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout="bshd_bshd_bshd", + qkv_layout=qkv_layout, attn_mask_type="causal", + attn_bias_type=ctx.attn_bias_type, ) else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] @@ -854,23 +977,36 @@ def backward(ctx, dout): ) elif i >= (cp_size-rank-1): if ctx.use_fused_attention: - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_ = q.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_ = kv[:, :, 0, ...].contiguous() - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - out_ = out.view(out.shape[0], -1, *out.shape[-2:]) - dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) - dq_, dk_, dv_, _ = fused_attn_bwd( + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_ = q.view(q.shape[0], -1, *q.shape[-2:]) + # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] + kv_ = kv[:, :, 0, ...].contiguous() + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + out_ = out.view(out.shape[0], -1, *out.shape[-2:]) + dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_ = q.view(-1, *q.shape[-3:]) + # [2, 2, sk//2, b, np, hn] -> [2, sk//2, b, np, hn] + kv_ = kv[:, 0, ...].contiguous() + # [2, sq//2, b, np, hn] -> 
[sq, b, np, hn] + out_ = out.view(-1, *out.shape[-3:]) + dout_ = dout.view(-1, *dout.shape[-3:]) + aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]] + if attn_dbias is not None: + aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_k//2, cu_seqlens_q, cu_seqlens_k//2, - q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype], - [softmax_lse, ctx.rng_states[cp_size-i-1]], + q_, kv_[0], kv_[1], out_, dout_, + TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout="bshd_bshd_bshd", + qkv_layout=qkv_layout, attn_mask_type="no_mask", + attn_bias_type=ctx.attn_bias_type, ) else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] @@ -894,23 +1030,36 @@ def backward(ctx, dout): ) else: if ctx.use_fused_attention: - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_ = q[:, 1, ...].contiguous() - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - out_ = out[:, 1, ...].contiguous() - dout_ = dout[:, 1, ...].contiguous() - dq_, dk_, dv_, _ = fused_attn_bwd( + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_ = q[:, 1, ...].contiguous() + # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + kv_ = kv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + out_ = out[:, 1, ...].contiguous() + dout_ = dout[:, 1, ...].contiguous() + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_ = q[1].contiguous() + # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + out_ = out[1].contiguous() + dout_ = dout[1].contiguous() + aux_ctx_tensors = [softmax_lse_, ctx.rng_states[cp_size-i-1]] + if attn_dbias is not None: + aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q//2, ctx.max_seqlen_k, cu_seqlens_q//2, cu_seqlens_k, - q_, kv_[0], kv_[1], out_, dout_, TE_DType[q.dtype], TE_DType[kv.dtype], - [softmax_lse_, ctx.rng_states[cp_size-i-1]], + q_, kv_[0], kv_[1], out_, dout_, + TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout="bshd_bshd_bshd", + qkv_layout=qkv_layout, attn_mask_type="no_mask", + attn_bias_type=ctx.attn_bias_type, ) else: # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] @@ -934,16 +1083,20 @@ def backward(ctx, dout): ) else: if ctx.use_fused_attention: - dq_, dk_, dv_, _ = fused_attn_bwd( + aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]] + if attn_dbias is not None: + aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]] + dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_k, cu_seqlens_q, cu_seqlens_k, - q, kv[0], kv[1], out, dout, TE_DType[q.dtype], TE_DType[kv.dtype], - [softmax_lse, ctx.rng_states[cp_size-i-1]], + q, kv[0], kv[1], out, dout, + TE_DType[q.dtype], TE_DType[kv.dtype], aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, - qkv_layout="bshd_bshd_bshd", + qkv_layout=qkv_layout, attn_mask_type="no_mask", + attn_bias_type=ctx.attn_bias_type, ) else: # [b, sq, np, hn] -> [b*sq, np, hn] @@ -970,8 
+1123,12 @@ def backward(ctx, dout): # [b*sq, np, hn] -> [b, sq, np, hn] if not causal dq_ = dq_.view(*dq.shape) else: - # [b*sq//2, np, hn] -> [b, sq//2, np, hn] - dq_ = dq_.view(dq.shape[0], *dq.shape[2:]) + if ctx.qkv_format == "bshd": + # [b*sq//2, np, hn] -> [b, sq//2, np, hn] + dq_ = dq_.view(dq.shape[0], *dq.shape[2:]) + elif ctx.qkv_format == "sbhd": + # [b*sq//2, np, hn] -> [sq//2, b, np, hn] + dq_ = dq_.view(-1, *dq.shape[-3:]) if ctx.causal: if i > (cp_size-rank-1): @@ -980,18 +1137,44 @@ def backward(ctx, dout): if rank == (cp_size-1): dq.copy_(dq_) else: - dq[:, 0, ...].copy_(dq_[:, 0, ...]) - dq[:, 1, ...].add_(dq_[:, 1, ...]) + if ctx.qkv_format == "bshd": + dq[:, 0, ...].copy_(dq_[:, 0, ...]) + dq[:, 1, ...].add_(dq_[:, 1, ...]) + elif ctx.qkv_format == "sbhd": + dq[0].copy_(dq_[0]) + dq[1].add_(dq_[1]) elif i > 0: - dq[:, 1, ...].add_(dq_) + if ctx.qkv_format == "bshd": + dq[:, 1, ...].add_(dq_) + elif ctx.qkv_format == "sbhd": + dq[1].add_(dq_) else: - dq[:, 1, ...].copy_(dq_) + if ctx.qkv_format == "bshd": + dq[:, 1, ...].copy_(dq_) + elif ctx.qkv_format == "sbhd": + dq[1].copy_(dq_) else: if i == 0: dq.copy_(dq_) else: dq.add_(dq_) + if attn_dbias is not None: + idx = (rank+i+1)%cp_size + if i == (cp_size - 1) or not ctx.causal: + # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)] + dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1]//2) + attn_dbias[..., idx, :].copy_(dbias_[..., 0, :]) + attn_dbias[..., (2*cp_size-idx-1), :].copy_(dbias_[..., 1, :]) + elif i >= (cp_size-rank-1): + # [b, np, sq, sk//(2*cp)] + attn_dbias[..., idx, :].copy_(dbias_) + else: + # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)] + dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1]//2) + attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :]) + attn_dbias_[..., 1, :, (2*cp_size-idx-1), :].copy_(dbias_[..., 1, :]) + # wait until dKV is received for req in send_recv_reqs: req.wait() @@ -1000,8 +1183,12 @@ def backward(ctx, dout): if ctx.use_fused_attention: dkv_ = torch.cat((dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0) if ctx.causal and i >= (cp_size-rank-1) and i != (cp_size-1): - # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn] - dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:]) + if ctx.qkv_format == "bshd": + # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn] + dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:]) + elif ctx.qkv_format == "sbhd": + # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn] + dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:]) else: # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal @@ -1010,15 +1197,25 @@ def backward(ctx, dout): if ctx.causal: if i == (cp_size-1): if rank == 0: - dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) - dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) + dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].add_(dkv_[:, 0, ...]) + dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) else: dkv.add_(dkv_) elif i >= (cp_size-rank-1): if i == 0 and rank == (cp_size-1): - dkv[:, :, 0, ...].copy_(dkv_) + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].copy_(dkv_) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].copy_(dkv_) else: - dkv[:, :, 0, ...].add_(dkv_) + if ctx.qkv_format == "bshd": + dkv[:, :, 0, ...].add_(dkv_) + elif ctx.qkv_format == "sbhd": + dkv[:, 0, ...].add_(dkv_) elif i > 0: dkv.add_(dkv_) else: @@ -1030,26 +1227,44 @@ def backward(ctx, dout): dkv.add_(dkv_) if 
ctx.causal: - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - dq = dq.view(q.shape[0], -1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + dq = dq.view(q.shape[0], -1, *q.shape[-2:]) + # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] + dkv = dkv.view(*kv.shape[0:2], -1, *kv.shape[-2:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + dq = dq.view(-1, *q.shape[-3:]) + # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] + dkv = dkv.view(kv.shape[0], -1, *kv.shape[-3:]) + + if attn_dbias is not None: + # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] + attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) + return None, dq, dkv[0], dkv[1], None, None, None, None, None, None, \ - None, None, None, None, None, None + None, None, None, None, None, None, attn_dbias, None, None def attn_forward_func_with_cp( - is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, - cp_group, cp_global_ranks, cp_stream, softmax_scale=None, attn_mask_type="causal", - deterministic=False, use_fused_attention=False + is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale=None, qkv_format="bshd", + attn_mask_type="causal", attn_bias_type="no_bias", attn_bias=None, deterministic=False, + use_fused_attention=False ) -> torch.Tensor: """Attention implementation with context parallelism""" + assert(qkv_format in ["bshd", "sbhd"] + ), f"QKV format of {qkv_format} is not supported with context parallelism!" + assert(qkv_format != "sbhd" or use_fused_attention + ), "FlashAttention does not support sbhd format!" assert (attn_mask_type in ["causal", "no_mask"] ), f"Mask type of {attn_mask_type} is not supported with context parallelism!" + assert (attn_bias is None or use_fused_attention + ), "Attention bias is only supported with FusedAttention!" out = AttnFuncWithCP.apply( is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, attn_mask_type, - deterministic, use_fused_attention + dropout_p, cp_group, cp_global_ranks, cp_stream, softmax_scale, qkv_format, + attn_mask_type, attn_bias_type, attn_bias, deterministic, use_fused_attention ) return out @@ -1857,6 +2072,7 @@ def forward( self.attention_dropout if self.training else 0.0, cp_group, cp_global_ranks, cp_stream, softmax_scale=1.0/self.norm_factor, + qkv_format="bshd" if qkv_format=="sbhd" else qkv_format, attn_mask_type=attn_mask_type, deterministic=self.deterministic ) @@ -2821,11 +3037,11 @@ def forward( assert (fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen ), f"{fused_attention_backend} does not work with context parallelism!" - assert (core_attention_bias_type == "no_bias"), \ - "Core attention bias has not been supported with context parallelism yet!" - if qkv_format == 'sbhd': - query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous() - for x in (query_layer, key_layer, value_layer)] + assert ( + core_attention_bias_type not in ["alibi"] + ), f"{core_attention_bias_type} is not supported with context parallelism!" 
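            # [Editorial sketch, not part of the patch] attn_forward_func_with_cp assumes the
            # load-balanced CP sharding used throughout this file: the sequence is split into
            # 2*cp_size chunks and rank r keeps chunks r and (2*cp_size - 1 - r). A full bias of
            # shape [b, h, sq, sk] therefore has to be pre-sliced along its query-sequence
            # dimension to this rank's two local chunks (the kv dimension stays full and is
            # re-sliced per P2P step inside AttnFuncWithCP), mirroring the CP test harness:
            #
            #     chunks = 2 * cp_size
            #     bias = bias.view(*bias.shape[:-2], chunks, bias.shape[-2] // chunks,
            #                      bias.shape[-1])
            #     bias = bias.index_select(2, torch.tensor([rank, chunks - 1 - rank],
            #                                              device=bias.device))
            #     bias = bias.view(*bias.shape[:2], -1, bias.shape[-1])
            #
            # (rank here is this rank's index within the CP group; names are illustrative.)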
+ query_layer, key_layer, value_layer = [x.contiguous() + for x in (query_layer, key_layer, value_layer)] with self.attention_dropout_ctx(): output = attn_forward_func_with_cp( self.training, @@ -2835,11 +3051,12 @@ def forward( self.attention_dropout if self.training else 0.0, cp_group, cp_global_ranks, cp_stream, softmax_scale=1.0/self.norm_factor, + qkv_format=qkv_format, attn_mask_type=attn_mask_type, + attn_bias_type=core_attention_bias_type, + attn_bias=core_attention_bias, use_fused_attention=True, ) - if qkv_format == 'sbhd': - output = output.transpose(0,1).contiguous() else: with self.prepare_forward(query_layer, is_first_microbatch, From 36297ef017c5fd411d62393df8d4ca1ea7af86c1 Mon Sep 17 00:00:00 2001 From: Zhenhuan Liu Date: Tue, 30 Apr 2024 04:11:20 +0800 Subject: [PATCH 037/244] FP8 Support for MCore MoE (#648) * Add support for MoE with FP8. Signed-off-by: Dennis Liu * Fix unittest. Signed-off-by: Dennis Liu * Fix error in linear backward. Signed-off-by: Dennis Liu --------- Signed-off-by: Dennis Liu Co-authored-by: Przemyslaw Tredak Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_sanity.py | 46 ++++++++++++++++++- .../pytorch/cpp_extensions/cast.py | 21 +++++---- .../pytorch/cpp_extensions/gemm.py | 4 ++ .../pytorch/cpp_extensions/transpose.py | 21 +++++---- .../pytorch/csrc/extensions/cast.cu | 3 ++ .../pytorch/csrc/extensions/transpose.cu | 5 ++ transformer_engine/pytorch/module/base.py | 8 +++- transformer_engine/pytorch/module/linear.py | 8 +++- 8 files changed, 92 insertions(+), 24 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 9f8c8f73cb..b6904b0c45 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -9,7 +9,11 @@ import torch import pytest -from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager +from transformer_engine.pytorch.fp8 import ( + fp8_autocast, + FP8GlobalStateManager, + fp8_model_init, +) from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, @@ -107,6 +111,7 @@ def is_fp8_supported(self): param_types.append(torch.bfloat16) all_boolean = [True, False] +batch_sizes_with_zero = [0, 1, 2] all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] all_normalizations = ["LayerNorm", "RMSNorm"] @@ -456,6 +461,45 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes_with_zero) +@pytest.mark.parametrize("model", ["small", "weird"]) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_model_params", all_boolean) +@pytest.mark.parametrize("use_bias", all_boolean) +def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias): + config = model_configs[model] + ffn_hidden_size = 4 * config.hidden_size + num_tokens = bs*config.seq_len + + if fp8_recipe is not None: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if not config.is_fp8_supported(): + pytest.skip("Model config does not support FP8") + + use_fp8 = fp8_recipe is not None + with fp8_model_init(enabled=use_fp8 and fp8_model_params): + te_linear = ( + Linear( + config.hidden_size, + ffn_hidden_size, + bias=use_bias, + params_dtype=dtype + ) + .cuda() + ) + + inp_hidden_states = torch.randn( + num_tokens, config.hidden_size, dtype=dtype, requires_grad=True + ).cuda() + with 
fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + out = te_linear(inp_hidden_states) + loss = out.sum() + loss.backward() + assert out.shape == (num_tokens, ffn_hidden_size) + + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("model", ["small", "weird"]) diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py index 3c80beff97..a86222d958 100644 --- a/transformer_engine/pytorch/cpp_extensions/cast.py +++ b/transformer_engine/pytorch/cpp_extensions/cast.py @@ -22,16 +22,18 @@ def cast_to_fp8( """Cast input to FP8""" if out is not None: - torch.ops.tex_ts.cast_to_fp8_noalloc_ts( - inp, - fp8_meta_tensor.scale, - out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, - otype - ) + if inp.nelement() > 0: + torch.ops.tex_ts.cast_to_fp8_noalloc_ts( + inp, + fp8_meta_tensor.scale, + out, + fp8_meta_tensor.amax_history, + fp8_meta_tensor.scale_inv, + fp8_tensor, + otype + ) return None + return torch.ops.tex_ts.cast_to_fp8_ts( inp, fp8_meta_tensor.scale, @@ -41,7 +43,6 @@ def cast_to_fp8( otype, ) - def cast_from_fp8( inp: torch.Tensor, fp8_meta_tensor: tex.FP8TensorMeta, diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 46ce244ce6..758d933401 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -64,6 +64,8 @@ def fp8_gemm( bias_dtype = TE_DType[bias_dtype] out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype + if A.nelement() == 0 or B.nelement() == 0: + return out, gelu_input args = ( A, @@ -191,6 +193,8 @@ def gemm( grad_bias = empty_tensor bias = bias if use_bias else empty_tensor + if A.nelement() == 0 or B.nelement() == 0: + return out, grad_bias, gelu_input assert A.dtype == dtype and B.dtype == dtype, \ f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}' diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index 3671f2e064..b264259fa5 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -39,16 +39,17 @@ def fp8_cast_transpose_fused( if noop_flag is None: noop_flag = torch.Tensor() - tex.fused_cast_transpose_noop( - inp, - noop_flag, - fp8_meta_tensor.scale[fp8_tensor], - fp8_meta_tensor.amax_history[0][fp8_tensor], - fp8_meta_tensor.scale_inv[fp8_tensor], - cast_out, - transpose_out, - otype, - ) + if inp.nelement() > 0: + tex.fused_cast_transpose_noop( + inp, + noop_flag, + fp8_meta_tensor.scale[fp8_tensor], + fp8_meta_tensor.amax_history[0][fp8_tensor], + fp8_meta_tensor.scale_inv[fp8_tensor], + cast_out, + transpose_out, + otype, + ) if return_outputs: return cast_out, transpose_out diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cu b/transformer_engine/pytorch/csrc/extensions/cast.cu index 80975069de..c798a39df5 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cu +++ b/transformer_engine/pytorch/csrc/extensions/cast.cu @@ -19,6 +19,9 @@ at::Tensor cast_to_fp8(const at::Tensor &input, auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); + if (input.numel() == 0) + return output; + auto input_cu = makeTransformerEngineTensor(input); auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax.data_ptr(), scale.data_ptr(), diff --git 
a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu index fc178adeb4..bf87fdb4bc 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -83,6 +83,9 @@ std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, grad_output.size(0), DType::kByte); + if (M == 0 || N == 0) + return {grad_bias, grad_output_cast, grad_output_transpose}; + auto input_cu = makeTransformerEngineTensor(grad_output); auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N}, otype, amax.data_ptr(), scale.data_ptr(), @@ -335,6 +338,8 @@ at::Tensor fp8_transpose(at::Tensor input, size_t M = static_cast(input.size(0)); size_t N = static_cast(input.size(1)); + if (M == 0 || N == 0) + return input; auto output = allocateTorchTensor(input.size(1), diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index e0bf5efbbf..0803b474f6 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -825,8 +825,12 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: if get_rng_state_tracker is None: init_fn(param) else: - with get_rng_state_tracker().fork(): - init_fn(param) + if hasattr(self, "rng_tracker_name") and self.rng_tracker_name: + with get_rng_state_tracker().fork(self.rng_tracker_name): + init_fn(param) + else: + with get_rng_state_tracker().fork(): + init_fn(param) # If primary weights are in fp8, wrap the parameter as Float8Tensor fp8_meta_index = self.param_init_meta[name].fp8_meta_index diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b48987f34c..ca5345bc69 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -152,7 +152,6 @@ def forward( inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) else: inputmat_total = inputmat - if fp8: if _NVTE_DEBUG: print('[Linear]: using FP8 forward') @@ -664,6 +663,10 @@ class Linear(TransformerEngineBaseModule): init_method : Callable, default = `None` used for initializing weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. + get_rng_state_tracker : Callable, default = `None` + used to get the random number generator state tracker for initilizeing weights. + rng_tracker_name : str, default = `None` + the param passed to get_rng_state_tracker to get the specific rng tracker. parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None Configuration for splitting the weight and bias tensors along dim 0 into multiple PyTorch parameters. If a list or tuple of strings is provided, @@ -723,6 +726,7 @@ def __init__( tp_group: Optional[dist_group_type] = None, tp_size: int = 1, get_rng_state_tracker: Optional[Callable] = None, + rng_tracker_name: Optional[str] = None, init_method: Optional[Callable] = None, bias: bool = True, return_bias: bool = False, @@ -753,6 +757,8 @@ def __init__( ), "Userbuffer communication backend not available." 
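        # [Editorial sketch, not part of the patch] rng_tracker_name lets weight initialization
        # fork a specific named RNG stream from the tracker instead of the default one (see the
        # reset_parameters change in module/base.py above). A hypothetical caller, e.g. an MoE
        # expert layer that wants its own initialization stream, might construct the module as:
        #
        #     linear = te.Linear(
        #         hidden_size, ffn_hidden_size,
        #         get_rng_state_tracker=get_cuda_rng_tracker,   # assumed tracker factory
        #         rng_tracker_name="expert-parallel-rng",       # assumed stream name
        #     )
        #
        # so that get_rng_state_tracker().fork(rng_tracker_name) governs parameter init.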
self.ub_name = ub_name self.get_rng_state_tracker = get_rng_state_tracker + self.rng_tracker_name = rng_tracker_name + if device == 'meta': assert parameters_split is None, ("Cannot split module parameters " "on 'meta' device.") From b394dee466076c2e0656aa69a6c27b7a992fb634 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 29 Apr 2024 13:17:02 -0700 Subject: [PATCH 038/244] Add module level filter for deprecation warning in common (#813) * Add module level filter for deprecation warning in common Signed-off-by: Kirthi Shankar Sivamani * Fix module Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- transformer_engine/common/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/utils.py b/transformer_engine/common/utils.py index a54e778171..339fa59f6c 100644 --- a/transformer_engine/common/utils.py +++ b/transformer_engine/common/utils.py @@ -6,7 +6,8 @@ import warnings from enum import Enum -warnings.simplefilter('default') +warnings.filterwarnings( + "module", category=DeprecationWarning, module="transformer_engine.common.utils") class DeprecatedEnum: # pylint: disable=too-few-public-methods From 086df06c1db173e9118c4ec9041bf080d0ea51e8 Mon Sep 17 00:00:00 2001 From: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 29 Apr 2024 13:22:54 -0700 Subject: [PATCH 039/244] [PyTorch] Fix tp_group_initialized error (#819) remove tp_size/tp_group as amax reduction is handled by fp8_group() Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index c4f9bd5301..1676a728db 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2153,7 +2153,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, use_FAv2_bwd, - fp8, fp8_meta, tp_size, tp_group): + fp8, fp8_meta): if fp8: if _NVTE_DEBUG: print('[DotProductAttention]: using FP8 forward') @@ -2227,8 +2227,6 @@ def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors) ctx.fp8_meta = fp8_meta - ctx.tp_size = tp_size - ctx.tp_group = tp_group ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen = max_seqlen ctx.qkv_dtype = qkv_dtype @@ -2349,7 +2347,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, - use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group): + use_FAv2_bwd, fp8, fp8_meta): if fp8: if _NVTE_DEBUG: print('[DotProductAttention]: using FP8 forward') @@ -2430,8 +2428,6 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors) ctx.fp8_meta 
= fp8_meta - ctx.tp_size = tp_size - ctx.tp_group = tp_group ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv @@ -2566,7 +2562,7 @@ class FusedAttnFunc(torch.autograd.Function): def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, - use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group): + use_FAv2_bwd, fp8, fp8_meta): if fp8: if _NVTE_DEBUG: print('[DotProductAttention]: using FP8 forward') @@ -2704,8 +2700,6 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors) ctx.fp8_meta = fp8_meta - ctx.tp_size = tp_size - ctx.tp_group = tp_group ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv @@ -2907,8 +2901,6 @@ def __init__( attention_type: str = "self", layer_number: Optional[int] = None, deterministic: bool = False, - tp_size: int = 1, - tp_group: Optional[dist_group_type] = None, ) -> None: super().__init__() @@ -2935,9 +2927,6 @@ def __init__( if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" - self.tp_size = tp_size - self.tp_group = tp_group - def get_fp8_weights_scratchpad( self, is_first_microbatch: Union[bool, None], @@ -3092,8 +3081,6 @@ def forward( use_FAv2_bwd, self.fp8 and self.fp8_meta["recipe"].fp8_dpa, self.fp8_meta, - self.tp_size, - self.tp_group, ) # ...hd -> ...(hd) @@ -3292,9 +3279,7 @@ def __init__( attention_type=attention_type, layer_number=layer_number, deterministic=self.deterministic, - **attn_kwargs, - tp_size=self.tp_size, - tp_group=self.tp_group) + **attn_kwargs) self.unfused_attention = UnfusedDotProductAttention( norm_factor, **attn_kwargs, layer_number=layer_number) From 9ac388a9bdb7ce5b83f9064b34a5e7514ac66d1d Mon Sep 17 00:00:00 2001 From: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 29 Apr 2024 15:11:55 -0700 Subject: [PATCH 040/244] [PyTorch] Skip context parallel tests on architectures below sm80 (#799) restrict context parallel tests to sm80+ as fused/flash attn backends require sm80+ Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Pawel Gadzinski --- tests/pytorch/fused_attn/test_fused_attn_with_cp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index ac571cd0e4..43280ecdde 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -10,6 +10,7 @@ _is_flash_attention_2_available, _cudnn_version, ) +from transformer_engine.pytorch.utils import get_device_compute_capability model_configs_flash_attn = { # test: b, h, hg, d, sq, skv, p, mask, bias @@ -29,6 +30,7 @@ def get_bash_arguments(**kwargs): return args @pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.") +@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") @pytest.mark.parametrize("dtype", ['bf16', 'fp16']) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd']) @@ -56,6 
+58,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format): } @pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.") +@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") @pytest.mark.parametrize("dtype", ['bf16', 'fp16']) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd']) From f6aca0af357ad9c84a1038b5bf5b0137883d8a4a Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 30 Apr 2024 11:13:48 -0700 Subject: [PATCH 041/244] [PyTorch] Fix linter warnings from unused args (#816) * Fix linter warnings from unused args Signed-off-by: Tim Moon * Update .gitignore Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- .gitignore | 2 +- transformer_engine/pytorch/attention.py | 20 +++++++++++--------- transformer_engine/pytorch/float8_tensor.py | 16 +++++++++++----- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index 4502c06264..54f5e0b2d7 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,4 @@ docs/_build .ipynb_checkpoints docs/doxygen *.log -CMakeFiles/CMakeSystem.cmake \ No newline at end of file +CMakeFiles/CMakeSystem.cmake diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 1676a728db..dbc26d538d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1757,11 +1757,12 @@ class _PrepareQKVForFA(torch.autograd.Function): to separate contiguous q, k, v tensors in (b, s, ...) layout.""" @staticmethod - def forward(ctx, - query_layer: torch.Tensor, - key_layer: torch.Tensor, - value_layer: torch.Tensor - ) -> torch.Tensor: + def forward( + _ctx: torch.autograd.function.FunctionCtx, # unused + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # All inputs received are non-contiguous tensors. # The `query_layer` tensor is used to access the # full memory region of the QKV tensor. 
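The linter fix above renames context arguments that are never used to _ctx without changing autograd behavior. A self-contained sketch of the same pattern on a toy stateless op (the class below is illustrative only, not TransformerEngine code):

    import torch
    from typing import Tuple

    class _MakeContiguous(torch.autograd.Function):
        """Stateless autograd op: nothing is saved on the context, so it is named _ctx."""

        @staticmethod
        def forward(_ctx: torch.autograd.function.FunctionCtx,  # unused
                    x: torch.Tensor) -> torch.Tensor:
            return x.contiguous()

        @staticmethod
        def backward(_ctx: torch.autograd.function.FunctionCtx,  # unused
                     grad: torch.Tensor) -> Tuple[torch.Tensor]:
            # Gradient of a layout change is the identity.
            return (grad,)

    y = _MakeContiguous.apply(torch.randn(2, 3, requires_grad=True).t())
    y.sum().backward()
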
@@ -1773,10 +1774,11 @@ def forward(ctx, return query_layer, key_layer, value_layer @staticmethod - def backward(ctx, - dq: torch.Tensor, - dk: torch.Tensor, - dv: torch.Tensor + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: dqkv = tex.fa_prepare_bwd(dq, dk, dv) dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index bbcbc2839c..719cc36739 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -46,7 +46,7 @@ class _FromFloat8Func(torch.autograd.Function): """Cast from FP8 to other dtype""" @staticmethod def forward( - ctx, + _ctx: torch.autograd.function.FunctionCtx, # unused tensor: Float8Tensor, dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: @@ -63,7 +63,10 @@ def forward( return out @staticmethod - def backward(ctx, grad): + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: # Assume that we want gradients in full precision return grad, None @@ -97,7 +100,7 @@ class _ToFloat8Func(torch.autograd.Function): """Cast to FP8 from other dtype""" @staticmethod def forward( - ctx, + _ctx: torch.autograd.function.FunctionCtx, # unused tensor: torch.Tensor, fp8_meta: Optional[Dict[str, Any]] = None, fp8_meta_forward: bool = True, @@ -106,7 +109,7 @@ def forward( scale: Optional[torch.Tensor] = None, amax: Optional[torch.Tensor] = None, scale_inv: Optional[torch.Tensor] = None, - ): + ) -> Float8Tensor: # Manually compute scale-inverse if needed if scale is not None and scale_inv is None: @@ -189,7 +192,10 @@ def forward( ) @staticmethod - def backward(ctx, grad): + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: # Assume that we want gradients in full precision return grad, None, None, None, None, None, None, None From 53be6336fcb131f814c189fcfd34f03e8f4deda9 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Tue, 30 Apr 2024 11:14:04 -0700 Subject: [PATCH 042/244] Added pull request template (#793) * Added pull request template Signed-off-by: Przemek Tredak * Changes from the review Signed-off-by: Przemek Tredak --------- Signed-off-by: Przemek Tredak Signed-off-by: Pawel Gadzinski --- .github/PULL_REQUEST_TEMPLATE.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000..d00d4adf49 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,29 @@ +# Description + +Please include a brief summary of the changes, relevant motivation and context. 
+ +Fixes # (issue) + +## Type of change + +- [ ] Documentation change (change only to the documentation, either a fix or a new content) +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) + +## Changes + +Please list the changes introduced in this PR: + +- Change A +- Change B + +# Checklist: + +- [ ] I have read and followed the [contributing guidelines](https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst) +- [ ] The functionality is complete +- [ ] I have commented my code, particularly in hard-to-understand areas +- [ ] I have made corresponding changes to the documentation +- [ ] My changes generate no new warnings +- [ ] I have added tests that prove my fix is effective or that my feature works +- [ ] New and existing unit tests pass locally with my changes From 996ed0d87c33cd757288ac1bb42908c92072ec68 Mon Sep 17 00:00:00 2001 From: vasunvidia <108759426+vasunvidia@users.noreply.github.com> Date: Tue, 30 Apr 2024 14:12:00 -0700 Subject: [PATCH 043/244] Fix ring_exchange RS to support CUDA graph capture (#811) Signed-off-by: Vasudevan Rengasamy Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- .../pytorch/csrc/comm_gemm_overlap.h | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index dfbcfe3e8a..814655a305 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -99,7 +99,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { } _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { cudaStream_t stream; cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1); @@ -596,7 +596,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ubuf_byte_ptr += ubuf_chunk_bytes; } - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); for (int i = 0; i < std::min(num_max_streams, tp_size); i++) { cudaStream_t stream; cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1); @@ -691,7 +691,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { assert(pre_gelu_out.numel() == 0); // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); @@ -974,7 +974,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { B_scale_inverse = B_scale_inverse[B_fp8_tensor]; // Catch up the main stream - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); 
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); @@ -1055,8 +1055,10 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { B_scale_inverse = B_scale_inverse[B_fp8_tensor]; // Catch up the main stream - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); for (int i = 0; i < _stream_compute.size(); i++) { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_compute[i], _start_compute, 0)); } @@ -1113,13 +1115,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); torch::sum_out(rs_output, reduce_buf, 0); } + for (int i = 0; i < _stream_compute.size(); i++) { + CHECK_CUDA( + cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); + } + CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); } /* ** Copy input to _ubufs[0] */ void copy_input_to_ubuf(torch::Tensor input, bool chunk) { - at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); if (chunk) { // Copy input to the target ubuf chunk by rank offset if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) { From 46fc3b05f85b44bf05620f3bd6837876db2bed09 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 30 Apr 2024 14:12:46 -0700 Subject: [PATCH 044/244] Avoid amax roll for non-run modules (#825) Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- .../common/recipe/delayed_scaling.cu | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 38e71b74de..de48a53ebf 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -197,16 +197,18 @@ kernel_bulk( const auto last_amax = ((amax_reduction_buffer != nullptr) && (amax_reduction_buffer[offset_in_buffer+count] != 0.0f)) ? amax_reduction_buffer[offset_in_buffer+count] : amax_history[0]; - for (size_t off = 0; off < length; off += bsize) { - const size_t i = off + tid; - float a = 0; - if (i < length) { - a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax; - amax = fmaxf(amax, a); - } - __syncthreads(); // Inplace roll - if (i < length) { - amax_history[i*stride] = (i > 0) ? a : 0; + if (last_amax != 0.0f) { + for (size_t off = 0; off < length; off += bsize) { + const size_t i = off + tid; + float a = 0; + if (i < length) { + a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax; + amax = fmaxf(amax, a); + } + __syncthreads(); // Inplace roll + if (i < length) { + amax_history[i*stride] = (i > 0) ? 
a : 0; + } } } From 850b79095f655b1557d4f6e79282867baf4fc7a8 Mon Sep 17 00:00:00 2001 From: Jinze Xue <155670984+jinzex@users.noreply.github.com> Date: Wed, 1 May 2024 10:28:35 -0700 Subject: [PATCH 045/244] Handle the scaling factor when amax is too tiny that leads to an infinite scale (#786) * Handle the scaling factor when amax is too tiny that leads to an infinite scale Signed-off-by: Jinze Xue * revert formatting changes Signed-off-by: Jinze Xue * fix comments Signed-off-by: Jinze Xue * Apply review suggestion Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Jinze Xue <155670984+jinzex@users.noreply.github.com> * Apply review suggestion Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Jinze Xue <155670984+jinzex@users.noreply.github.com> * Apply review suggestion Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Jinze Xue <155670984+jinzex@users.noreply.github.com> * apply review suggestion Signed-off-by: Jinze Xue * add test_recipe.py to qa/L0_pytorch_unittest/test.sh; fix unittest for is_first_microbatch=False Signed-off-by: Jinze Xue * revert changes to update_weight_scale_inv Signed-off-by: Jinze Xue * Debug test failures Signed-off-by: Tim Moon --------- Signed-off-by: Jinze Xue Signed-off-by: Jinze Xue <155670984+jinzex@users.noreply.github.com> Signed-off-by: Tim Moon Co-authored-by: Jinze Xue Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon Signed-off-by: Pawel Gadzinski --- qa/L0_pytorch_unittest/test.sh | 1 + tests/pytorch/test_recipe.py | 98 ++++++++++++++++++- .../common/recipe/delayed_scaling.cu | 26 +++++ transformer_engine/pytorch/fp8.py | 15 ++- 4 files changed, 138 insertions(+), 2 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index ded45dd377..2c14664dce 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -8,6 +8,7 @@ set -e pip install pytest==7.2 onnxruntime==1.13.1 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py +pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 6b65960ec6..92c7f26f59 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -9,9 +9,10 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te +import transformer_engine_extensions as tex from transformer_engine.pytorch.fp8 import ( FP8GlobalStateManager, - amax_and_scale_update, + _amax_and_scale_update, get_default_fp8_recipe, ) @@ -162,3 +163,98 @@ def test_amax_and_scale_update( fp8_meta[backward_key].scale_inv, ref_scale_inv_backward, ) + + @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"]) + @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"]) + @pytest.mark.parametrize( + "fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=["E4M3", "E5M2"] + ) + def test_scale_update_numeric_scenarios(self, amax_case, fused_update, fp8_dtype): + + if fp8_dtype == tex.DType.kFloat8E4M3: + fp8_format = transformer_engine.common.recipe.Format.E4M3 + fp8_max = fp8_format.value.max_fwd + 
elif fp8_dtype == tex.DType.kFloat8E5M2: + fp8_format = transformer_engine.common.recipe.Format.HYBRID + fp8_max = fp8_format.value.max_bwd + else: + raise ValueError(f"{fp8_dtype=} is not supported") + + scaling_factor_compute_algo = None + if fused_update: + scaling_factor_compute_algo = ( + lambda amax, scale, fp8_max, recipe: + te.fp8._default_sf_compute(amax, scale, fp8_max, recipe.margin) + ) + recipe = transformer_engine.common.recipe.DelayedScaling( + fp8_format=fp8_format, scaling_factor_compute_algo=scaling_factor_compute_algo + ) + + # Setup fp8_meta dictionary + def setup_fp8_meta(): + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + module = te.Linear(16, 16) + y = module(torch.zeros([16, 16], device="cuda")) + y.backward(torch.zeros_like(y)) + return module.fp8_meta + + fp8_meta = setup_fp8_meta() + forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + + # Replace the fp8_meta[forward_key] with a new TensorMeta for test purpose + fp8_meta[forward_key] = tex.FP8TensorMeta() + fp8_meta[forward_key].scale = torch.ones(1, dtype=torch.float32, device="cuda") + fp8_meta[forward_key].scale_inv = torch.ones(1, dtype=torch.float32, device="cuda") + + # test different scenarios + if amax_case == "zero": + fp8_meta[forward_key].amax_history = torch.tensor([[0]], dtype=torch.float32, device="cuda") + expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda") + elif amax_case == "tiny": + # calculate the minimum amax value that results in a FP32 maximum scale + fp32_max = torch.tensor(torch.finfo(torch.float32).max) + tiny_amax = fp8_max / fp32_max + # make the amax less than the minimum amax so that the scale will be infinite + amax_value = tiny_amax / 2 + fp8_meta[forward_key].amax_history = torch.tensor( + [[amax_value]], dtype=torch.float32, device="cuda" + ) + # expected scale is FP32_max + expected_scale = fp32_max.view(1).cuda() + elif amax_case == "normal": + # plus a small epsilon to avoid zero amax + amax_value = torch.rand(1, dtype=torch.float32, device="cuda") + 1e-5 + fp8_meta[forward_key].amax_history = amax_value.view(1, 1) + expected_scale = fp8_max / amax_value + elif amax_case == "inf": + fp8_meta[forward_key].amax_history = torch.tensor( + [[torch.inf]], dtype=torch.float32, device="cuda" + ) + expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda") + elif amax_case == "nan": + fp8_meta[forward_key].amax_history = torch.tensor( + [[torch.nan]], dtype=torch.float32, device="cuda" + ) + expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda") + + if fused_update: + tex.fused_amax_and_scale_update_after_reduction( + fp8_meta[forward_key].amax_history.clone().view(-1), + [fp8_meta[forward_key].amax_history], + [fp8_meta[forward_key].scale], + [fp8_meta[forward_key].scale_inv], + recipe.amax_compute_algo, + fp8_dtype, + recipe.margin, + ) + else: + _amax_and_scale_update( + fp8_meta[forward_key].amax_history, + fp8_meta[forward_key].scale, + fp8_meta[forward_key].scale_inv, + fp8_max, + recipe, + ) + + torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale) + torch.testing.assert_close(fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale)) diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index de48a53ebf..2e232f50e2 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -8,6 +8,7 @@ #include #include +#include #include "../common.h" 
#include "../util/logging.h" @@ -151,6 +152,13 @@ kernel(const float* amax_history_ptr, } else { scale = scale_ptr[bid]; } + // When the amax is too tiny that the scale becoming infinite in FP32, + // we set the scale to the max value of FP32. In this case, the tensor’s + // amax won't get mapped to the FP8 max representable, but rather + // something below that, but this is the best thing we can do. + if (isinf(scale)) { + scale = std::numeric_limits::max(); + } updated_scale_ptr[bid] = scale; // Update scale inverse @@ -239,12 +247,30 @@ kernel_bulk( // Update scale and scale inverse if (tid == 0) { + // Computing the scaling factor requires consideration of the following scenarios: + // 1. amax == 0: + // No action is possible, set scale to the previous scale (or 1). + // 2. 0 < amax < tiny_amax + // The amax is too tiny that the scale becomes infinite in FP32. + // Set scale = FP32_max + // 3. tiny_amax <= amax < FP32_max: + // Set scale = FP8_max (or scaled_max) / amax + // 4. When amax == inf or amax == nan: + // No action is possible, set scale to the previous scale (or 1). + float scale; if (isfinite(amax) && amax > 0) { scale = scaled_max / amax; } else { scale = p.param[bid].scale[count]; } + // When the amax is too tiny that the scale becoming infinite in FP32, + // we set the scale to the max value of FP32. In this case, the tensor’s + // amax won't get mapped to the FP8 max representable, but rather + // something below that, but this is the best thing we can do. + if (isinf(scale)) { + scale = std::numeric_limits::max(); + } p.param[bid].scale[count] = scale; p.param[bid].scale_inv[count] = 1 / scale; } diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 1f359d4864..b28e380473 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -598,11 +598,24 @@ def _default_sf_compute( scale: torch.Tensor, fp8_max: float, margin: int, + _fp32_max: float = torch.finfo(torch.float32).max, # finfo not available in jitter ) -> torch.Tensor: - """Default function to convert amax to scaling factor.""" + """Default function to convert amax to scaling factor. + Computing the scaling factor requires consideration of the following scenarios: + 1. amax == 0: + No action is possible, set scale to the previous scale (or 1). + 2. 0 < amax < tiny_amax + The amax is too tiny that the scale becomes infinite in FP32. + Set scale = FP32_max + 3. tiny_amax <= amax < FP32_max: + Set scale = FP8_max (or scaled_max) / amax + 4. When amax == inf or amax == nan: + No action is possible, set scale to the previous scale (or 1). + """ sf = (fp8_max / amax) / (2 ** margin) sf = torch.where(amax > 0.0, sf, scale) sf = torch.where(torch.isfinite(amax), sf, scale) + sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf) scale.copy_(sf) return scale From cd0f62fd2167c52aca7f420af800081d2c94b1e3 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 1 May 2024 16:56:09 -0400 Subject: [PATCH 046/244] [JAX] Support FP8 training for Pipeline Parallelism when Micro-batch > 1 on Paxml. (#774) * Support FP8 Meta Dtype (FM32) and Align FP8 Scale Update with PyTorch. Signed-off-by: Ming Huang * Modify with the feedback of code review Signed-off-by: Ming Huang * Hiding FlaxFloatMeta32 inside fp8.py Signed-off-by: Ming Huang * Make functions to be JAX tracable objects. Signed-off-by: Ming Huang * Rebased with mian. Signed-off-by: Ming Huang * Update jax images for github workflow. 
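The scale-update rule documented in the comments above (in both the CUDA kernels and _default_sf_compute) reduces to a few tensor ops. A hedged, standalone restatement for reference — not the library function itself, and the sample numbers below are only illustrative:

    import torch

    def compute_scale(amax: torch.Tensor, prev_scale: torch.Tensor,
                      fp8_max: float, margin: int = 0) -> torch.Tensor:
        sf = (fp8_max / amax) / (2 ** margin)
        # amax == 0: keep the previous scale (or 1).
        sf = torch.where(amax > 0.0, sf, prev_scale)
        # amax == inf or nan: keep the previous scale (or 1).
        sf = torch.where(torch.isfinite(amax), sf, prev_scale)
        # amax so tiny that fp8_max / amax overflows FP32: clamp to FP32 max.
        fp32_max = torch.finfo(torch.float32).max
        sf = torch.where(torch.isinf(sf), torch.full_like(sf, fp32_max), sf)
        return sf

    amax = torch.tensor([0.0, 1e-42, 2.0, float("inf")])
    prev = torch.ones(4)
    print(compute_scale(amax, prev, fp8_max=448.0))
    # approximately [1.0, 3.4028e+38, 224.0, 1.0]
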
Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Signed-off-by: Pawel Gadzinski --- .github/workflows/build.yml | 2 +- transformer_engine/jax/dot.py | 13 +- transformer_engine/jax/fp8.py | 46 ++++++ transformer_engine/jax/layernorm.py | 11 +- transformer_engine/jax/mlp.py | 209 ++++++++++++++++------------ 5 files changed, 184 insertions(+), 97 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 711980fa1c..cc302fbdf5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -31,7 +31,7 @@ jobs: name: 'JAX' runs-on: ubuntu-latest container: - image: ghcr.io/nvidia/jax:latest + image: ghcr.io/nvidia/jax:jax options: --user root steps: - name: 'Checkout' diff --git a/transformer_engine/jax/dot.py b/transformer_engine/jax/dot.py index 00d0bcb99f..bad0582085 100644 --- a/transformer_engine/jax/dot.py +++ b/transformer_engine/jax/dot.py @@ -103,12 +103,18 @@ def _fp8_dot_fwd_rule( fwd_dtype, bwd_dtype, # pylint: disable=unused-argument contracting_dims): + + maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \ + FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv) + fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv) + lhs_contracting_dims, rhs_contracting_dims = contracting_dims x_shape_suf = x.shape[min(lhs_contracting_dims):] kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1] assert x_shape_suf == kernel_shape_pre + scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) amax = FP8Helper.update_amax_history(amax) gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) @@ -130,7 +136,7 @@ def _fp8_dot_fwd_rule( get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) ctx = (casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax, - updated_kernel_amax, x.shape, kernel.shape) + updated_kernel_amax, x.shape, kernel.shape, maybe_fp32_to_fm32) return output, ctx @@ -138,7 +144,8 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # p lhs_contracting_dims, rhs_contracting_dims = contracting_dims casted_x, casted_kernel, fp8_max, amax, scale, scale_inv, \ - updated_x_amax, updated_kernel_amax, x_shape, kernel_shape = ctx + updated_x_amax, updated_kernel_amax, x_shape, kernel_shape, \ + maybe_fp32_to_fm32 = ctx gemm_x_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0) @@ -170,7 +177,7 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad): # p amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax) amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0]) - scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) + fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv) return dgrad, wgrad, fp8_max, amax, scale, scale_inv diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index 2c57ef426f..cbd357e22e 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -11,6 +11,7 @@ import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict +from flax.linen import fp8_ops from transformer_engine_jax import DType from transformer_engine_jax import get_cublasLt_version @@ -67,6 +68,15 @@ def _format2dtypes(format_: Format): return jnp.bfloat16, jnp.bfloat16 +# fm32 is a custom dtype to specify the "add" rules as max operation. +# This is typically used in Pipeline Parallelism + "MiconBatching > 1", +# which is implemented via nn.scan. 
Without this custom dtype, nn.scan +# would sum gradients from all micro-batches, and this is not the expected +# behavior for FP8 meta. Instead, the summation of FP8 meta gradients should +# be "MAX". +FlaxFloatMeta32 = fp8_ops.fm32 + + class FP8MetaPackage: """ A container that contains all required meta data for FP8 @@ -303,6 +313,42 @@ def _update_fp8_metas_impl(fp8_metas: Collection) -> Collection: return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays) + @staticmethod + def generate_fp8_meta_dtype_converter_pair(*args): + """ + Generate a pair of conversion fun in-between fm32 and fp32. + """ + + def identical_fun(*metas): + return metas + + def fm32_to_fp32_fun(*metas): + for meta in metas: + assert meta.dtype == FlaxFloatMeta32 + return [jax.lax.convert_element_type(meta, jnp.float32) for meta in metas] + + def fp32_to_fm32_fun(*metas): + for meta in metas: + assert meta.dtype == jnp.float32 + return [jax.lax.convert_element_type(meta, FlaxFloatMeta32) for meta in metas] + + # Make functions to be a vaild JAX type + partial_identical_fun = jax.tree_util.Partial(identical_fun) + partial_fm32_to_fp32_fun = jax.tree_util.Partial(fm32_to_fp32_fun) + partial_fp32_to_fm32_fun = jax.tree_util.Partial(fp32_to_fm32_fun) + + if len(args) < 1: + return partial_identical_fun, partial_identical_fun + + original_dtype = args[0].dtype + for arg in args: + assert arg.dtype == original_dtype + + if original_dtype == FlaxFloatMeta32: + return partial_fm32_to_fp32_fun, partial_fp32_to_fm32_fun + + return partial_identical_fun, partial_identical_fun + @staticmethod def update_amax_history(amax: jnp.ndarray) -> jnp.ndarray: """ diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index acf49639d4..707778e2c7 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -162,6 +162,11 @@ def _layernorm_fp8_dot_fwd_rule( k_contracting_dims = (0,) assert x.shape[-1] == kernel.shape[0] + maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \ + FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv) + fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv) + + scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) amax = FP8Helper.update_amax_history(amax) gemm_x_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) @@ -216,7 +221,7 @@ def _layernorm_fp8_dot_fwd_rule( ctx = (ln_out, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims, - k_contracting_dims) + k_contracting_dims, maybe_fp32_to_fm32) return output, ctx @@ -234,7 +239,7 @@ def _layernorm_fp8_dot_bwd_rule( ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \ updated_x_amax, updated_kernel_amax, \ x_shape, kernel_shape, mu, rsigma, x, gamma, \ - x_contracting_dims, k_contracting_dims = ctx + x_contracting_dims, k_contracting_dims, maybe_fp32_to_fm32 = ctx ln_out_t = transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1) @@ -282,7 +287,7 @@ def _layernorm_fp8_dot_bwd_rule( amax = amax.at[gemm_kernel_idx, 0].set(updated_kernel_amax[0]) amax = amax.at[gemm_grad_idx, 0].set(updated_grad_amax[0]) - scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) + fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv) return dx, wgrad, \ dgamma, dbeta, \ diff --git a/transformer_engine/jax/mlp.py b/transformer_engine/jax/mlp.py index 1900e3f441..a9761499c0 100644 --- 
a/transformer_engine/jax/mlp.py +++ b/transformer_engine/jax/mlp.py @@ -26,27 +26,42 @@ from .fp8 import FP8Helper, FP8MetaPackage from .sharding import with_sharding_constraint_by_logical_axes - activation_dict = { - ('gelu',): {'fwd': gelu, - "bwd": dgelu}, - ('gelu', 'linear'): {'fwd': gated_gelu, - 'bwd': dgated_gelu}, - ('silu',): {'fwd': silu, - "bwd": dsilu }, - ('silu', 'linear'): {'fwd': gated_silu, - 'bwd': dgated_silu} + ('gelu',): { + 'fwd': gelu, + "bwd": dgelu + }, + ('gelu', 'linear'): { + 'fwd': gated_gelu, + 'bwd': dgated_gelu + }, + ('silu',): { + 'fwd': silu, + "bwd": dsilu + }, + ('silu', 'linear'): { + 'fwd': gated_silu, + 'bwd': dgated_silu + } } activation_fp8_dict = { - ('gelu',): {'fwd': gelu_fp8, - 'bwd': dgelu_dbias_cast_transpose}, - ('gelu', 'linear'): {'fwd': gated_gelu_fp8, - 'bwd': dgated_gelu_cast_transpose}, - ('silu',): { 'fwd': silu_fp8, - 'bwd': dsilu_dbias_cast_transpose }, - ('silu', 'linear'): { 'fwd': gated_silu_fp8, - 'bwd': dgated_silu_cast_transpose } + ('gelu',): { + 'fwd': gelu_fp8, + 'bwd': dgelu_dbias_cast_transpose + }, + ('gelu', 'linear'): { + 'fwd': gated_gelu_fp8, + 'bwd': dgated_gelu_cast_transpose + }, + ('silu',): { + 'fwd': silu_fp8, + 'bwd': dsilu_dbias_cast_transpose + }, + ('silu', 'linear'): { + 'fwd': gated_silu_fp8, + 'bwd': dgated_silu_cast_transpose + } } @@ -55,10 +70,11 @@ def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable] Activation Unit """ if len(activation_type) > 1: - assert x.shape[-2] == 2 # Linear + GeLU + assert x.shape[-2] == 2 # Linear + GeLU output = _activation_lu(x, activation_type) return output + @partial(jax.custom_vjp, nondiff_argnums=(1,)) def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]): @@ -66,10 +82,12 @@ def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable return _output + def _activation_lu_fwd_rule(x, activation_type): fwd_output = activation_dict[activation_type]["fwd"](x) return fwd_output, (x,) + def _activation_lu_bwd_rule(activation_type, ctx, g): x, = ctx assert x.dtype == g.dtype @@ -78,11 +96,12 @@ def _activation_lu_bwd_rule(activation_type, ctx, g): dx = jnp.reshape(dx, x.shape) return (dx,) + _activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule) -def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, - scale_inv: jnp.ndarray, fwd_dtype:jnp.dtype, bwd_dtype: jnp.dtype, +def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, + fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, activation_type: Sequence[Union[str, Callable]]): """ Activation Unit @@ -91,43 +110,51 @@ def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype) dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype) - output = _activation_lu_fp8(x, dx_trans_no_use, dbias_no_use, amax, - scale, scale_inv, fwd_dtype, bwd_dtype, activation_type) + output = _activation_lu_fp8(x, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv, fwd_dtype, + bwd_dtype, activation_type) return output -@partial(jax.custom_vjp, nondiff_argnums=(6,7,8)) -def _activation_lu_fp8(x: jnp.ndarray, - dx_trans_no_use: jnp.ndarray, dbias_no_use: jnp.ndarray, + +@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8)) +def _activation_lu_fp8(x: jnp.ndarray, dx_trans_no_use: jnp.ndarray, dbias_no_use: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, 
bwd_dtype: jnp.dtype, activation_type: Sequence[Union[str, Callable]]): - output = _activation_lu_fp8_fwd_rule(x, dx_trans_no_use, dbias_no_use, amax, - scale, scale_inv, fwd_dtype, bwd_dtype, - activation_type) + output = _activation_lu_fp8_fwd_rule(x, dx_trans_no_use, dbias_no_use, amax, scale, scale_inv, + fwd_dtype, bwd_dtype, activation_type) return output -def _activation_lu_fp8_fwd_rule(x, - dx_trans_no_use, # pylint: disable=unused-argument - dbias_no_use, # pylint: disable=unused-argument - amax, - scale, scale_inv, - fwd_dtype, bwd_dtype, # pylint: disable=unused-argument - activation_type): - activation_lu_out, _ = activation_fp8_dict[activation_type ]["fwd"]( - x, amax, scale, scale_inv, fwd_dtype) + +def _activation_lu_fp8_fwd_rule( + x, + dx_trans_no_use, # pylint: disable=unused-argument + dbias_no_use, # pylint: disable=unused-argument + amax, + scale, + scale_inv, + fwd_dtype, + bwd_dtype, # pylint: disable=unused-argument + activation_type): + activation_lu_out, _ = activation_fp8_dict[activation_type]["fwd"](x, amax, scale, scale_inv, + fwd_dtype) activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv) ctx = (x, amax, scale, scale_inv) return activation_lu_out, ctx -def _activation_lu_fp8_bwd_rule(fwd_dtype, bwd_dtype, # pylint: disable=unused-argument - activation_type, ctx, g): + +def _activation_lu_fp8_bwd_rule( + fwd_dtype, # pylint: disable=unused-argument + bwd_dtype, + activation_type, + ctx, + g): x, amax, scale, scale_inv = ctx activation_lu_fp8_bwd = activation_fp8_dict[activation_type]["bwd"] - if len(activation_type) > 1: #gated, no bias + if len(activation_type) > 1: #gated, no bias dactivation_lu, dactivation_lu_trans, amax_out = \ activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1) dbias = jnp.empty(x.shape[-1], x.dtype) @@ -139,25 +166,26 @@ def _activation_lu_fp8_bwd_rule(fwd_dtype, bwd_dtype, # pylint: disable=unused ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv) return ctx + _activation_lu_fp8.defvjp(_activation_lu_fp8_fwd_rule, _activation_lu_fp8_bwd_rule) def fused_layernorm_fp8_mlp(x: jnp.ndarray, - gamma: jnp.ndarray, - beta: jnp.ndarray, - kernels: List[jnp.ndarray], - biases: List[jnp.ndarray], - fp8_gemm_pkg: FP8MetaPackage, - layernorm_type: str, - zero_centered_gamma: bool = False, - epsilon: float = 1e-6, - layernorm_input_axes: Tuple[str, ...] = None, - dot_1_input_axes: Tuple[str, ...] = None, - dot_2_input_axes: Tuple[str, ...] = None, - ffn1_ckpt_name: str = 'ffn1', - ffn2_ckpt_name: str = 'ffn2', - activation_type: Sequence[Union[str, Callable]] = ('gelu',), - use_bias: bool = True) -> jnp.ndarray: + gamma: jnp.ndarray, + beta: jnp.ndarray, + kernels: List[jnp.ndarray], + biases: List[jnp.ndarray], + fp8_gemm_pkg: FP8MetaPackage, + layernorm_type: str, + zero_centered_gamma: bool = False, + epsilon: float = 1e-6, + layernorm_input_axes: Tuple[str, ...] = None, + dot_1_input_axes: Tuple[str, ...] = None, + dot_2_input_axes: Tuple[str, ...] 
= None, + ffn1_ckpt_name: str = 'ffn1', + ffn2_ckpt_name: str = 'ffn2', + activation_type: Sequence[Union[str, Callable]] = ('gelu',), + use_bias: bool = True) -> jnp.ndarray: """ Layernorm + GEMM1 + bias + activation + GEMM2 + bias """ @@ -184,31 +212,28 @@ def fused_layernorm_fp8_mlp(x: jnp.ndarray, "if layernorm_type is 'rmsnorm'" output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, - amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type, - zero_centered_gamma, epsilon, layernorm_input_axes, - dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, - ffn2_ckpt_name, activation_type, use_bias) + amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type, + zero_centered_gamma, epsilon, layernorm_input_axes, + dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, + ffn2_ckpt_name, activation_type, use_bias) return output @partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22)) def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, - kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray, - bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray, - scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, - bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool, - epsilon: float, layernorm_input_axes: Tuple[str, ...], - dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], - ffn1_ckpt_name: str, ffn2_ckpt_name: str, - activation_type: Sequence[Union[str, Callable]], - use_bias: bool): - output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1, - bias_2, fp8_max, amax, scale, scale_inv, - fwd_dtype, bwd_dtype, layernorm_type, - zero_centered_gamma, epsilon, - layernorm_input_axes, dot_1_input_axes, - dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, - activation_type, use_bias) + kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray, + bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray, + scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, + bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool, + epsilon: float, layernorm_input_axes: Tuple[str, ...], + dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], + ffn1_ckpt_name: str, ffn2_ckpt_name: str, + activation_type: Sequence[Union[str, Callable]], use_bias: bool): + output, _ = _fused_layernorm_fp8_mlp_fwd_rule( + x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, amax, scale, scale_inv, + fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon, layernorm_input_axes, + dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + use_bias) return output @@ -256,6 +281,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule( if not is_gated: kernel_1 = jnp.squeeze(kernel_1, axis=-2) + maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \ + FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv) + fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv) + + scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) amax = FP8Helper.update_amax_history(amax) gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) @@ -324,8 +354,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule( activation_lu_fwd_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale, activation_lu_out_scale_inv, fwd_dtype) - casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out, 
- dot_2_input_axes) + casted_activation_lu_out = with_sharding_constraint_by_logical_axes( + casted_activation_lu_out, dot_2_input_axes) kernel_2_scale = scale[gemm2_kernel_idx] kernel_2_scale_inv = scale_inv[gemm2_kernel_idx] @@ -335,8 +365,8 @@ def _fused_layernorm_fp8_mlp_fwd_rule( # (batch..., hidden_in) x (hidden_out, hidden_in) dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2, - activation_lu_out_scale_inv, - kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)), + activation_lu_out_scale_inv, kernel_2_scale_inv, x.dtype, + (x_contracting_dims, (0,)), get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) if use_bias: @@ -348,7 +378,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule( ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, - x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape) + x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape, maybe_fp32_to_fm32) return dot_2_output, ctx @@ -371,7 +401,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule( x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \ casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \ updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ - x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape= ctx + x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx is_gated = len(activation_type) > 1 @@ -481,8 +511,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule( xt_batch_dims_2 = xt_batch_dims if not is_gated \ else tuple(i + 1 for i in xt_batch_dims) wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv, - dactivation_lu_scale_inv, grad.dtype, - (xt_batch_dims, xt_batch_dims_2), + dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2), get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) # Expand act axis to match the shape with the given kernel_1 if not is_gated: @@ -490,14 +519,13 @@ def _fused_layernorm_fp8_mlp_bwd_rule( # (batch..., hidden_out) x (hidden_in, hidden_out) if is_gated: - x_contracting_dims = ((min(x_contracting_dims),) + tuple( - i + 1 for i in x_contracting_dims), (1,2)) + x_contracting_dims = ((min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims), + (1, 2)) else: x_contracting_dims = (x_contracting_dims, (1,)) kernel_1_scale_inv = scale_inv[gemm1_kernel_idx] - dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, - dactivation_lu_scale_inv, kernel_1_scale_inv, - grad.dtype, x_contracting_dims, + dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv, + kernel_1_scale_inv, grad.dtype, x_contracting_dims, get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes) @@ -523,10 +551,11 @@ def _fused_layernorm_fp8_mlp_bwd_rule( amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax) amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0]) - scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) + fp8_max, amax, scale, scale_inv = maybe_fp32_to_fm32(fp8_max, amax, scale, scale_inv) + return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \ fp8_max, amax, scale, scale_inv _fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule, - _fused_layernorm_fp8_mlp_bwd_rule) + 
_fused_layernorm_fp8_mlp_bwd_rule) From da9ee4de6e9dce4d68dad132f842d01c702fe707 Mon Sep 17 00:00:00 2001 From: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 1 May 2024 20:41:59 -0700 Subject: [PATCH 047/244] [PyTorch] Miscellanous fixes for FP8 DPA module (#804) * initialize tp_group for FP8 DPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix cuDNN version in unit tests for cuDNN v9 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add hook to ignore missing fused_attn._extra_states if training from old checkpoints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove test and redundant implementation from last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove warning message and replace with docstring Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove tp_size/tp_group in FusedAttention; amax reduction is handled with fp8_group Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * move core_attention.fused_attention._extra_state to core_attention._extra_state Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * simplify post_state_dict_hooks between FU and DPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add temporary test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove previous attempts to move core_attention.fused_attention to core_attention; keep the test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove the test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable pylint self arg for hook which is required by hook Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Pawel Gadzinski --- tests/pytorch/fused_attn/test_fused_attn.py | 3 ++- transformer_engine/pytorch/attention.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 40cfdd34b7..caba385d46 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -70,7 +70,8 @@ def reset_global_fp8_state(): def _cudnn_version() -> Tuple[int, int, int]: """Runtime cuDNN version (major, minor, patch)""" encoded_version = ext.get_cudnn_version() - major, encoded_version = divmod(encoded_version, 1000) + major_version_magnitude = 1000 if encoded_version < 90000 else 10000 + major, encoded_version = divmod(encoded_version, major_version_magnitude) minor, patch = divmod(encoded_version, 100) return (major, minor, patch) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index dbc26d538d..2f5a6aa671 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2929,6 +2929,17 @@ def __init__( if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" + def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument + """ + Temporarily remove fused_attention._extra_state as a missing key + when loading older TransformerEngine checkpoints. Will phase out + this hook in TransformerEngine 2.0. 
+ """ + for key in incompatible_keys.missing_keys: + if 'fused_attention._extra_state' in key: + incompatible_keys.missing_keys.remove(key) + self.register_load_state_dict_post_hook(remove_extra_states_check) + def get_fp8_weights_scratchpad( self, is_first_microbatch: Union[bool, None], @@ -3282,6 +3293,7 @@ def __init__( layer_number=layer_number, deterministic=self.deterministic, **attn_kwargs) + self.unfused_attention = UnfusedDotProductAttention( norm_factor, **attn_kwargs, layer_number=layer_number) From 4afb291e09da80fe53de117c5593f53126cc43ad Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Fri, 3 May 2024 00:40:26 +0800 Subject: [PATCH 048/244] [JAX] Enhance JAX unit tests (#796) * Add layernorm_fp8_dot unit test Signed-off-by: Reese Wang * Update the softmax primitives support conditions Signed-off-by: Reese Wang * Add tests for the softmax primitives Signed-off-by: Reese Wang * Round1 refactor of test_layer Signed-off-by: Reese Wang * Split dropout arguments of ref code and add hidden/intermediate dropout elementwise comparison Signed-off-by: Reese Wang * Add dropout_braodcast_dim, self_attn_mask tests and clean a few code Signed-off-by: Reese Wang * Abstract test layer and fix a rope reference code diff Signed-off-by: Reese Wang * Add bias tests Signed-off-by: Reese Wang * Add epsilon and float32 tests Signed-off-by: Reese Wang * Add relpos_bias and attention dropout tests Signed-off-by: Reese Wang * Loose the atol Signed-off-by: Reese Wang * Move common fixtures to conftest.py Signed-off-by: Reese Wang * Add doc string for test_layer Signed-off-by: Reese Wang * Add doc string for test_layer Signed-off-by: Reese Wang * Fix conflicts of test_layer Signed-off-by: Reese Wang * Avoid to left bias parameters in graph when use_bias=False Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang Signed-off-by: Pawel Gadzinski --- tests/jax/conftest.py | 16 + tests/jax/test_custom_call_compute.py | 236 ++++--- tests/jax/test_fused_attn.py | 8 +- tests/jax/test_layer.py | 759 +++++++-------------- tests/jax/test_praxis_layers.py | 10 - tests/jax/test_softmax.py | 165 +++++ tests/jax/utils.py | 449 +++++++----- transformer_engine/jax/cpp_extensions.py | 13 +- transformer_engine/jax/flax/module.py | 34 +- transformer_engine/jax/flax/transformer.py | 5 +- 10 files changed, 889 insertions(+), 806 deletions(-) create mode 100644 tests/jax/conftest.py create mode 100644 tests/jax/test_softmax.py diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py new file mode 100644 index 0000000000..5f1aaa4c39 --- /dev/null +++ b/tests/jax/conftest.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""conftest for tests/jax""" +import jax +import pytest + + +@pytest.fixture(autouse=True, scope='function') +def clear_live_arrays(): + """ + Clear all live arrays to keep the resource clean + """ + yield + for arr in jax.live_arrays(): + arr.delete() diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 2d4c9b7e32..6555aa29ac 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. 
+from contextlib import nullcontext import functools import operator from typing import Callable, Sequence, Union @@ -10,7 +11,6 @@ import jax.numpy as jnp import numpy as np import pytest -from jax import lax from jax import jit, value_and_grad from flax import linen as nn @@ -18,7 +18,7 @@ from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper from transformer_engine.jax.fp8 import is_fp8_available -from transformer_engine.jax.layernorm import layernorm +from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot from transformer_engine.jax.mlp import activation_lu, activation_lu_fp8, fused_layernorm_fp8_mlp @@ -45,16 +45,6 @@ def _convert_to_activation_function(fn_or_string): raise ValueError(f"don't know how to convert {fn_or_string} to an activation function") -@pytest.fixture(autouse=True, scope='function') -def clear_live_arrays(): - """ - Clear all live arrays to keep the resource clean - """ - yield - for arr in jax.live_arrays(): - arr.delete() - - class TestFP8Dot: @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -416,88 +406,150 @@ def test_activation_lu(self, random_inputs, activation_type): dtype=FP8Helper.BWD_DTYPE) -class TestRMSNorm: - - @pytest.mark.parametrize('n, hidden', LN_CASES) - @pytest.mark.parametrize('dtype', DTYPES) - def test_forward_backward(self, n, hidden, dtype): - key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, 2) - - x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -2, 1) - scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, -2, 1) - scale = jnp.asarray(scale, dtype) - epsilon = 1e-6 - - def reference_rmsnorm(x, scale): - x = jnp.asarray(x, jnp.float32) - mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) - y = jnp.asarray(x * lax.rsqrt(mean2 + epsilon), dtype) - return y * scale - - jitted_primitive = jit( - value_and_grad(lambda x, scale: jnp.mean(layernorm(x, scale, None, "rmsnorm")), (0, 1))) - - jitted_reference = jit( - value_and_grad(lambda x, scale: jnp.mean(reference_rmsnorm(x, scale)), (0, 1))) - - primitive_out, (primitive_dx, primitive_dgamma) = jitted_primitive(x, scale) - reference_out, (reference_dx, reference_dgamma) = jitted_reference(x, scale) - - assert_allclose(primitive_out, reference_out, dtype=dtype) - assert_allclose(primitive_dx, reference_dx, dtype=dtype) - assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype) - +class TestNorm: + """ + Test transformer_engine.jax.layernorm APIs + """ -class TestLayerNorm: + def reference_layernorm(self, x, scale, bias, zero_centered_gamma, eps): + """ + JAX native layernorm implementations + - bias is not None: layernorm + - bias is None: rmsnorm + """ + x_ = jnp.asarray(x, jnp.float32) + if bias is None: + mean = 0. + else: + mean = jnp.mean(x_, axis=-1, keepdims=True) + var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) + normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps) + if zero_centered_gamma: + scale += 1. + if bias is None: + bias = 0. 
+ return jnp.asarray(normed_input * scale + bias).astype(x.dtype) @pytest.mark.parametrize('n, hidden', LN_CASES) @pytest.mark.parametrize('dtype', DTYPES) + @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm']) @pytest.mark.parametrize('zero_centered_gamma', [False, True]) - def test_forward_backward(self, n, hidden, zero_centered_gamma, dtype): - key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, 3) - - x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1) - scale_range = (-1, 1) if zero_centered_gamma else (0, 2) - scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *scale_range) - scale = jnp.asarray(scale, dtype) - bias = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1) - bias = jnp.asarray(bias, dtype) - epsilon = 1e-6 - - def reference_layernorm(x, scale, bias, zero_centered_gamma, eps): - x_ = jnp.asarray(x, jnp.float32) - mean = jnp.mean(x_, axis=-1, keepdims=True) - var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) - normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps) - # Align TE implementation - if zero_centered_gamma: - return jnp.asarray(normed_input * (scale + 1) + bias).astype(x.dtype) - return jnp.asarray(normed_input * scale + bias).astype(x.dtype) - - def compute_loss(x): - # Higher precision to compute the loss - x_ = x.astype(jnp.float32) - return jnp.mean(jnp.square(x_)).astype(x.dtype) - - jitted_primitive = jit( - value_and_grad( - lambda x, scale, bias: compute_loss( - layernorm(x, scale, bias, "layernorm", zero_centered_gamma, epsilon)), - (0, 1, 2))) - - jitted_reference = jit( - value_and_grad( - lambda x, scale, bias: compute_loss( - reference_layernorm(x, scale, bias, zero_centered_gamma, epsilon)), (0, 1, 2))) - - primitive_out, (primitive_dx, primitive_dgamma, - primitive_dbeta) = jitted_primitive(x, scale, bias) - reference_out, (reference_dx, reference_dgamma, - reference_dbeta) = jitted_reference(x, scale, bias) - - assert_allclose(primitive_out, reference_out, dtype=dtype) - assert_allclose(primitive_dx, reference_dx, dtype=dtype) - assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype) - assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype) + @pytest.mark.parametrize('epsilon', [1e-2, 1e-6]) + def test_layernorm_forward_backward(self, n, hidden, ln_type, zero_centered_gamma, epsilon, + dtype): + """ + Test transformer_engine.jax.layernorm.layernorm + """ + expect_assert = False + if ln_type == 'rmsnorm' and zero_centered_gamma: + # zero_centered_gamma is not supported for rmsnorm, expect an assertion. 
+ expect_assert = True + + with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*" + ) if expect_assert else nullcontext(): + key = jax.random.PRNGKey(0) + subkeys = jax.random.split(key, 3) + + x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1) + gamma_range = (-1, 1) if zero_centered_gamma else (0, 2) + gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range) + gamma = jnp.asarray(gamma, dtype) + if ln_type == 'layernorm': + beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1) + beta = jnp.asarray(beta, dtype) + else: + beta = None + + def compute_loss(x): + # Higher precision to compute the loss + x_ = x.astype(jnp.float32) + return jnp.mean(jnp.square(x_)).astype(x.dtype) + + jitted_primitive = jit( + value_and_grad( + lambda x, gamma, beta: compute_loss( + layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)), + (0, 1, 2))) + + jitted_reference = jit( + value_and_grad( + lambda x, gamma, beta: compute_loss( + self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)), + (0, 1, 2))) + + primitive_out, (primitive_dx, primitive_dgamma, + primitive_dbeta) = jitted_primitive(x, gamma, beta) + reference_out, (reference_dx, reference_dgamma, + reference_dbeta) = jitted_reference(x, gamma, beta) + + assert_allclose(primitive_out, reference_out, dtype=dtype) + assert_allclose(primitive_dx, reference_dx, dtype=dtype) + assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype) + if beta is not None: + assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype) + + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.parametrize('m,n,k', GEMM_CASES) + @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm']) + @pytest.mark.parametrize('zero_centered_gamma', [True, False]) + @pytest.mark.parametrize('epsilon', [1e-2, 1e-6]) + def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon): + """ + Test transformer_engine.jax.layernorm.layernorm_fp8_dot + """ + expect_assert = False + if ln_type == 'rmsnorm' and zero_centered_gamma: + # zero_centered_gamma is not supported for rmsnorm, expect an assertion. 
+ expect_assert = True + + with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*" + ) if expect_assert else nullcontext(): + key = jax.random.PRNGKey(0) + subkeys = jax.random.split(key, 4) + + a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16) + b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) + + gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16) + if ln_type == 'layernorm': + beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16) + else: + beta = None + + fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM) + fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN), + jnp.float32) + fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) + fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32) + + def primitive_func(x, y, gamma, beta, fp8_max, fp8_metas_amax, fp8_metas_scale, + fp8_metas_scale_inv): + fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale, + fp8_metas_scale_inv) + primitive_out = layernorm_fp8_dot(x, y, gamma, beta, fp8_meta_pkg, ln_type, + zero_centered_gamma) + return jnp.mean(primitive_out) + + def ref_func(x, y, gamma, beta, zero_centered_gamma): + x = self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon) + return jnp.mean(jnp.dot(x, y)) + + value_n_grad_primitive_func = value_and_grad(primitive_func, range(8)) + value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3)) + + ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, + ref_beta_grad) = value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma) + + for _ in range(3): + primitive_out, (primitive_a_grad, primitive_b_grad, primitive_gamma_grad, + primitive_beta_grad, fp8_max, fp8_metas_amax, fp8_metas_scale, + fp8_metas_scale_inv) = value_n_grad_primitive_func( + a, b, gamma, beta, fp8_max, fp8_metas_amax, fp8_metas_scale, + fp8_metas_scale_inv) + + assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) + assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE) + assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE) + assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE) + if beta is not None: + assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 483f070559..bcf69e70cc 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -27,16 +27,14 @@ from utils import assert_allclose -@pytest.fixture(autouse=True, scope='function') -def clear_live_arrays(): +@pytest.fixture(autouse=True, scope='module') +def init(): """ - Clear all live arrays to keep the resource clean + WAR for CUDA uninitialize error """ # Calling customcalls before jax may cause CUDA uninitialize error _ = jnp.zeros(0) yield - for arr in jax.live_arrays(): - arr.delete() def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike, diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 70602ccbb8..1493b50cf0 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -1,16 +1,17 @@ # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
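# The unified reference above treats RMSNorm as LayerNorm with the mean pinned to 0 and no
# bias, and models TE's zero-centered-gamma convention by shifting the scale: the normalized
# activation is multiplied by (gamma + 1), so a gamma initialized at zero acts like an
# ordinary gamma of one. A minimal sketch of that convention (the function name is
# illustrative, not part of the patch):
import jax
import jax.numpy as jnp

def zero_centered_layernorm(x, gamma, beta, eps=1e-6):
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
    x_hat = (x - mean) * jax.lax.rsqrt(var + eps)
    # gamma == 0 here behaves like gamma == 1 in the conventional parameterization.
    return x_hat * (gamma + 1.0) + beta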
- +"""Test transformer_engine.jax.flax.TransformerLayer""" import os from functools import partial +from typing import Dict import flax import jax import jax.numpy as jnp import pytest -from utils import assert_allclose +from utils import assert_allclose, assert_tree_like_allclose, sync_params_values from utils import DecoderLayer as RefDecoderLayer from utils import EncoderLayer as RefEncoderLayer @@ -21,68 +22,18 @@ is_fp8_supported, reason = is_fp8_available() -@pytest.fixture(autouse=True, scope='module') +@pytest.fixture(autouse=True, scope='function') def enable_fused_attn(): - """ - Enable fused attention - """ + """Enable fused attention""" os.environ["NVTE_FUSED_ATTN"] = "1" yield del os.environ["NVTE_FUSED_ATTN"] -@pytest.fixture(autouse=True, scope='function') -def clear_live_arrays(): - """ - Clear all live arrays to keep the resource clean - """ - yield - for arr in jax.live_arrays(): - arr.delete() - - -def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs): - output = model.apply({"params": params, **others}, *diff_xs, *no_diff_xs, rngs=rngs) - return jnp.mean(output) - - -def generate_test_rngs(): - data_rng = jax.random.PRNGKey(0) - init_rng = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)} - apply_rng = {'dropout': jax.random.PRNGKey(3)} - return data_rng, init_rng, apply_rng - - -def generate_layer(layer_cls, init_rng, diff_inputs, no_diff_inputs): - layer = layer_cls() - variables = layer.init(init_rng, *diff_inputs, *no_diff_inputs) - others, params = flax.core.pop(variables, 'params') - del variables - return layer, params, others - - -def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): - # To be compatible with both Flax>=0.7.1 or <0.7.1 - # since Flax 0.7.1 removed FrozenDict. - ref_fd = flax.core.unfreeze(ref_fd) - test_fd = flax.core.unfreeze(test_fd) - for key in ref_fd: - assert key in test_fd, \ - f"{key} not found in test dict {test_fd}" - assert isinstance(test_fd[key], type(ref_fd[key])), \ - f"The data type is not match between ref and test " \ - f"dict on {key=}" - if isinstance(ref_fd[key], dict): - compare_dict(ref_fd[key], test_fd[key], rtol, atol) - else: - assert_allclose(ref_fd[key], - test_fd[key], - rtol=rtol, - atol=atol, - err_msg=f"{key=} is not close") - - -DATA_SHAPE = [(32, 128, 1024), (32, 512, 1024)] # (batch, seqlen, emb_dim) +DATA_SHAPE = [ # (batch, seqlen, emb_dim) + pytest.param((32, 128, 1024), id='32-128-1024'), + pytest.param((32, 512, 1024), id='32-512-1024'), +] DTYPE = [jnp.float32, jnp.bfloat16] FP8_FORMATS = [Format.E4M3, Format.HYBRID] @@ -90,31 +41,42 @@ def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): _KEY_OF_OUTPUT_LAYERNORM = "output_layernorm" _KEY_OF_DROP_PATH = "drop_path" _KEY_OF_FUSE_QKV_PARAMS = "fuse_qkv_params" -_KEY_OF_DROPOUT_RATE = "dropout_rate" +_KEY_OF_HIDDEN_DROPOUT = "hidden_dropout" +_KEY_OF_ATTENTION_DROPOUT = "attention_dropout" +_KEY_OF_INTERMEDIATE_DROPOUT = "intermediate_dropout" +_KEY_OF_HIDDEN_DROPOUT_DIMS = "hidden_dropout_dims" +_KEY_OF_INTERMEDIATE_DROPOUT_DIMS = "intermediate_dropout_dims" _KEY_OF_MLP_ACTIVATIONS = "mlp_activations" -_KEY_OF_FUSE_MLP_WI = "fuse_mlp_wi" -_KEY_OF_LAYERNORM_TYPE = 'layernorm_type' -_KEY_OF_ZERO_CENTERED_GAMMA = 'zero_centered_gamma' -_KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence' +_KEY_OF_LAYERNORM_TYPE = "layernorm_type" +_KEY_OF_LAYERNORM_EPS = "layernorm_epsilon" +_KEY_OF_ZERO_CENTERED_GAMMA = "zero_centered_gamma" +_KEY_OF_TRANSPOSE_BS = "transpose_batch_sequence" _KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits" 
-_KEY_OF_NUM_HEADS = 'num_attention_heads' -_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups' +_KEY_OF_NUM_HEADS = "num_attention_heads" +_KEY_OF_NUM_GQA_GROUPS = "num_gqa_groups" _KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb" _KEY_OF_ROPE_GROUP_METHOD = "rotary_pos_emb_group_method" +_KEY_OF_SELF_ATTN_BIAS_TYPE = "self_attn_bias_type" +_KEY_OF_SELF_ATTN_MASK_TYPE = "self_attn_mask_type" +_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits" +_KEY_OF_USE_BIAS = "use_bias" +_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding" BASE_ATTRS = { _KEY_OF_TRANSPOSE_BS: True, _KEY_OF_NUM_HEADS: 8, - _KEY_OF_DROPOUT_RATE: 0, + _KEY_OF_HIDDEN_DROPOUT: 0, + _KEY_OF_ATTENTION_DROPOUT: 0, + _KEY_OF_INTERMEDIATE_DROPOUT: 0, + _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal", + _KEY_OF_LAYERNORM_TYPE: 'layernorm', } -ATTRS = [{ +ATTRS = [{}, { _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', }, { - _KEY_OF_LAYERNORM_TYPE: 'layernorm', -}, { - _KEY_OF_LAYERNORM_TYPE: 'layernorm', - _KEY_OF_ZERO_CENTERED_GAMMA: True + _KEY_OF_ZERO_CENTERED_GAMMA: True, + _KEY_OF_LAYERNORM_EPS: 1e-2, }, { _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_RESIDUAL_POST_LAYERNORM: True @@ -133,518 +95,323 @@ def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): _KEY_OF_FUSE_QKV_PARAMS: False }, { _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_DROPOUT_RATE: 0.0, - _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')), - _KEY_OF_FUSE_MLP_WI: True + _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'), }, { _KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_DROPOUT_RATE: 0.8, - _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')), - _KEY_OF_FUSE_MLP_WI: True + _KEY_OF_HIDDEN_DROPOUT: 0.8, + _KEY_OF_INTERMEDIATE_DROPOUT: 0.5, + _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'), + _KEY_OF_USE_BIAS: True, }, { _KEY_OF_TRANSPOSE_BS: False, _KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_DROPOUT_RATE: 0.0, - _KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')), - _KEY_OF_FUSE_MLP_WI: True + _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'), }, { _KEY_OF_NUM_HEADS: 8, _KEY_OF_NUM_GQA_GROUPS: 4, _KEY_OF_TRANSPOSE_BS: False, _KEY_OF_SCALE_ATTN_LOGITS: True, - _KEY_OF_LAYERNORM_TYPE: 'layernorm', - _KEY_OF_DROPOUT_RATE: 0.0, - _KEY_OF_MLP_ACTIVATIONS: (('gelu',)), - _KEY_OF_FUSE_MLP_WI: True + _KEY_OF_MLP_ACTIVATIONS: ('gelu',), + _KEY_OF_USE_BIAS: True, }, { _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_DROPOUT_RATE: 0.0, _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), - _KEY_OF_FUSE_MLP_WI: True }, { _KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_DROPOUT_RATE: 0.8, + _KEY_OF_HIDDEN_DROPOUT: 0.8, + _KEY_OF_INTERMEDIATE_DROPOUT: 0.5, _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), - _KEY_OF_FUSE_MLP_WI: True + _KEY_OF_USE_BIAS: True, }, { _KEY_OF_TRANSPOSE_BS: False, _KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', - _KEY_OF_DROPOUT_RATE: 0.0, _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), - _KEY_OF_FUSE_MLP_WI: True }, { _KEY_OF_NUM_HEADS: 8, _KEY_OF_NUM_GQA_GROUPS: 4, _KEY_OF_TRANSPOSE_BS: False, _KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_LAYERNORM_TYPE: 'layernorm', - _KEY_OF_DROPOUT_RATE: 0.0, _KEY_OF_MLP_ACTIVATIONS: (('silu',)), - _KEY_OF_FUSE_MLP_WI: True + _KEY_OF_USE_BIAS: True, }, { _KEY_OF_TRANSPOSE_BS: False, - _KEY_OF_LAYERNORM_TYPE: 'layernorm', - _KEY_OF_DROPOUT_RATE: 0.0, - _KEY_OF_FUSE_MLP_WI: True, + _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', + _KEY_OF_NUM_GQA_GROUPS: 1, _KEY_OF_ENABLE_ROPE: True, - _KEY_OF_ROPE_GROUP_METHOD: "consecutive" + _KEY_OF_ROPE_GROUP_METHOD: "consecutive", + 
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True, }, { _KEY_OF_TRANSPOSE_BS: True, - _KEY_OF_LAYERNORM_TYPE: 'layernorm', - _KEY_OF_DROPOUT_RATE: 0.0, - _KEY_OF_FUSE_MLP_WI: True, _KEY_OF_ENABLE_ROPE: True, - _KEY_OF_ROPE_GROUP_METHOD: "consecutive" + _KEY_OF_ROPE_GROUP_METHOD: "consecutive", + _KEY_OF_USE_BIAS: True, }, { _KEY_OF_TRANSPOSE_BS: False, _KEY_OF_LAYERNORM_TYPE: 'layernorm', - _KEY_OF_DROPOUT_RATE: 0.0, - _KEY_OF_FUSE_MLP_WI: True, + _KEY_OF_NUM_GQA_GROUPS: 2, _KEY_OF_ENABLE_ROPE: True, - _KEY_OF_ROPE_GROUP_METHOD: "alternate" + _KEY_OF_ROPE_GROUP_METHOD: "alternate", + _KEY_OF_USE_BIAS: True, + _KEY_OF_FLOAT32_ATTENTION_LOGITS: True, }, { _KEY_OF_TRANSPOSE_BS: True, - _KEY_OF_LAYERNORM_TYPE: 'layernorm', - _KEY_OF_DROPOUT_RATE: 0.0, - _KEY_OF_FUSE_MLP_WI: True, + _KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_ENABLE_ROPE: True, - _KEY_OF_ROPE_GROUP_METHOD: "alternate" + _KEY_OF_ROPE_GROUP_METHOD: "alternate", + _KEY_OF_USE_BIAS: True, +}, { + _KEY_OF_HIDDEN_DROPOUT: 0.3, + _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,), + _KEY_OF_INTERMEDIATE_DROPOUT: 0.5, + _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,), +}, { + _KEY_OF_SELF_ATTN_MASK_TYPE: "padding", + _KEY_OF_USE_BIAS: True, +}, { + _KEY_OF_RELATIVE_EMBEDDING: False, + _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias", +}, { + _KEY_OF_ATTENTION_DROPOUT: 0.3, }] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] -class TestEncoderLayer: - - @staticmethod - def sync_params(ref, target): - unfreeze_target = flax.core.unfreeze(target) - unfreeze_attn_scope = unfreeze_target['attention'] - ref_attn_scope = ref['attention'] - for key in ref_attn_scope.keys(): - unfreeze_attn_scope[key]['kernel'] = \ - ref_attn_scope[key]['kernel'].reshape(unfreeze_attn_scope[key]['kernel'].shape) - unfreeze_target['mlp']['wi_kernel'] = \ - jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape) - unfreeze_target['mlp']['wo_kernel'] = \ - ref['mlp']['wo']['kernel'] - return ref, unfreeze_target - - def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): - transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS] - batch, seqlen = data_shape[:2] - if transpose_batch_sequence: - data_shape = (data_shape[1], data_shape[0], *data_shape[2:]) - sequence_dim = 0 if transpose_batch_sequence else 1 +class BaseRunner: + """Base runner to define forward and backward tests""" + layer_type: TransformerLayerType = None + reference_layer: flax.linen.Module = None + transformations: Dict[str, str] = None - data_rng, init_rng, apply_rng = generate_test_rngs() - inputs = (jax.random.normal(data_rng, data_shape, dtype),) + def __init__(self, attrs): + self.attrs = attrs + self._generate_test_rngs() + # Disable fused attention for attention dropout because the different dropout impl + if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv('NVTE_FUSED_ATTN'): + os.environ['NVTE_FUSED_ATTN'] = "0" - padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) - ref_masks = (1 - padded_mask,) - test_masks = (None, padded_mask) # The second arg of Transformer is encoded tokens. 
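# The two layers under comparison take opposite mask conventions: TransformerLayer receives a
# mask whose 1s mark positions to drop, while the reference layer receives the complementary
# "may attend" mask, which is why the runners build ref_masks as `1 - mask`. A concrete look
# at the causal mask used in these tests (1s strictly above the diagonal are the future
# positions); sizes here are illustrative only:
import jax.numpy as jnp

seqlen = 4
causal_mask = jnp.triu(jnp.ones((1, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
# causal_mask[0, 0] is
#   [[0 1 1 1]
#    [0 0 1 1]
#    [0 0 0 1]
#    [0 0 0 0]]
te_mask = causal_mask        # passed to TransformerLayer as-is
ref_mask = 1 - causal_mask   # passed to the reference layer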
- - te_layer_attrs = {} - for k, v in attrs.items(): - if k == 'dropout_rate': - te_layer_attrs['attention_dropout'] = v - te_layer_attrs['hidden_dropout'] = v - te_layer_attrs['intermediate_dropout'] = v - elif k == 'fuse_mlp_wi': - continue - else: - te_layer_attrs[k] = v - ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs) - layer_cls = partial(TransformerLayer, - hidden_dropout_dims=(sequence_dim,), - intermediate_dropout_dims=(sequence_dim,), - layer_type=TransformerLayerType.ENCODER, - self_attn_mask_type='padding', - dtype=dtype, - **te_layer_attrs) - - ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs, - ref_masks) - test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs, - test_masks) - - ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params) - - ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng) - test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng) - - if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) - - del data_rng, init_rng, apply_rng - - def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): - transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS] - batch, seqlen = data_shape[:2] - if transpose_batch_sequence: - data_shape = (data_shape[1], data_shape[0], *data_shape[2:]) - sequence_dim = 0 if transpose_batch_sequence else 1 + def _generate_test_rngs(self): + root_rng = jax.random.PRNGKey(0) + params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3) + self.init_rng = {'params': params_rng, 'dropout': init_dropout_rng} + self.apply_rng = {'dropout': apply_dropout_rng} - data_rng, init_rng, apply_rng = generate_test_rngs() - inputs = (jax.random.normal(data_rng, data_shape, dtype),) + def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs): + layer = layer_cls() + variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs) + others, params = flax.core.pop(variables, 'params') + del variables + return layer, params, others - padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) - ref_masks = (1 - padded_mask,) - test_masks = (None, padded_mask) # The second arg of Transformer is encoded tokens. 
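# The inline mapping removed from both runners here fanned a single 'dropout_rate' knob out
# into TE's three dropout arguments and dropped the reference-only 'fuse_mlp_wi' flag; the
# refactor instead lists hidden/attention/intermediate dropout explicitly in ATTRS so both
# layers read the same dict. A sketch of the retired fan-out, kept only to document the
# old-to-new attribute mapping (the helper name is illustrative):
def expand_legacy_attrs(attrs: dict) -> dict:
    te_layer_attrs = {}
    for key, value in attrs.items():
        if key == 'dropout_rate':
            te_layer_attrs['attention_dropout'] = value
            te_layer_attrs['hidden_dropout'] = value
            te_layer_attrs['intermediate_dropout'] = value
        elif key == 'fuse_mlp_wi':
            continue  # reference-layer-only knob; utils.MlpBlock now defaults to fused wi
        else:
            te_layer_attrs[key] = value
    return te_layer_attrs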
- - te_layer_attrs = {} - for k, v in attrs.items(): - if k == 'dropout_rate': - te_layer_attrs['attention_dropout'] = v - te_layer_attrs['hidden_dropout'] = v - te_layer_attrs['intermediate_dropout'] = v - elif k == 'fuse_mlp_wi': - continue - else: - te_layer_attrs[k] = v - ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs) - layer_cls = partial(TransformerLayer, - hidden_dropout_dims=(sequence_dim,), - intermediate_dropout_dims=(sequence_dim,), - layer_type=TransformerLayerType.ENCODER, - self_attn_mask_type='padding', - dtype=dtype, - **te_layer_attrs) - ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs, - ref_masks) - test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs, - test_masks) - - ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params) + def _loss_fn(self, diff_xs, no_diff_xs, params, others, model): + variables = {'params': params, **others} + output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng) + return jnp.mean(output, dtype=jnp.float32).astype(output.dtype) + + def _sync_params(self, ref, target): + """Copy the reference params to target""" + target = sync_params_values(target, ref, self.transformations) + return ref, target + + def test_forward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): + """Test only the forward""" + inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype) + + ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs) + layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs) + + ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks) + test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks) + ref_params, test_params = self._sync_params(ref_params, test_params) + + ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer) + test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer) + + assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) + + def test_backward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): + """Test forward and backward through value_and_grad()""" + inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype) + + ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs) + layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs) + + ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks) + test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks) + + ref_params, test_params = self._sync_params(ref_params, test_params) if FP8Helper.is_fp8_enabled(): for _ in range(4): - _, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,), - has_aux=False)(inputs, test_masks, test_params, - test_others, test_layer, apply_rng) + _, tmp_grad = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)( + inputs, + test_masks, + test_params, + test_others, + test_layer, + ) _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME) test_others = FP8Helper.update_collections( {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others) test_others = FP8Helper.update_fp8_metas(test_others) del tmp_grad, fp8_meta_grad - grad_fn = jax.value_and_grad(loss_fn, argnums=(0, 2), has_aux=False) - - ref_out, ref_grads = grad_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, - apply_rng) - test_out, 
test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer, - apply_rng) - - def reorganize_test_wgrad(test_wgrad, attrs): - num_heads = attrs.get(_KEY_OF_NUM_HEADS) - num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads) - fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \ - num_heads == num_gqa_groups - - attn_name = 'attention' - unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad) - if "output_layernorm" not in attrs: - unfreeze_test_wgrad['pre_attention_layer_norm'] = {} - pre_attn_layer_key = 'qkv' if fuse_qkv else 'query' - unfreeze_test_wgrad['pre_attention_layer_norm']['scale'] = \ - unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale'] - del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale'] - if 'ln_bias' in unfreeze_test_wgrad[attn_name][pre_attn_layer_key]: - unfreeze_test_wgrad['pre_attention_layer_norm']['ln_bias'] = \ - unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias'] - del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias'] - - for key in unfreeze_test_wgrad[attn_name].keys(): - unfreeze_test_wgrad[attn_name][key]['kernel'] = \ - jnp.reshape(unfreeze_test_wgrad[attn_name][key]['kernel'], - (unfreeze_test_wgrad[attn_name][key]['kernel'].shape[0], -1)) - - unfreeze_test_wgrad['pre_mlp_layer_norm'] = {} - unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \ - unfreeze_test_wgrad['mlp']['scale'] - del unfreeze_test_wgrad['mlp']['scale'] - if 'ln_bias' in unfreeze_test_wgrad['mlp']: - unfreeze_test_wgrad['pre_mlp_layer_norm']['ln_bias'] = \ - unfreeze_test_wgrad['mlp']['ln_bias'] - del unfreeze_test_wgrad['mlp']['ln_bias'] - unfreeze_test_wgrad['mlp']['wi'] = {} - unfreeze_test_wgrad['mlp']['wi']['kernel'] = \ - jnp.reshape(unfreeze_test_wgrad['mlp']['wi_kernel'], - (unfreeze_test_wgrad['mlp']['wi_kernel'].shape[0], -1)) - del unfreeze_test_wgrad['mlp']['wi_kernel'] - unfreeze_test_wgrad['mlp']['wo'] = {} - unfreeze_test_wgrad['mlp']['wo']['kernel'] = \ - unfreeze_test_wgrad['mlp']['wo_kernel'] - del unfreeze_test_wgrad['mlp']['wo_kernel'] - return unfreeze_test_wgrad - - if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) - assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad - - compare_dict(ref_grads[1], - reorganize_test_wgrad(test_grads[1], attrs), - rtol=rtol, - atol=atol) # wgrad - - del data_rng, init_rng, apply_rng - - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', ATTRS) - def test_forward(self, data_shape, dtype, attrs): - FP8Helper.finalize() # Ensure FP8 disabled. - self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04) - - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('fp8_format', FP8_FORMATS) - @pytest.mark.parametrize('attrs', ATTRS) - def test_forward_with_fp8(self, data_shape, dtype, fp8_format, attrs): - FP8Helper.initialize(fp8_format=fp8_format) - self.forward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=1e-03) - FP8Helper.finalize() - - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', ATTRS) - def test_forward_backward(self, data_shape, dtype, attrs): - FP8Helper.finalize() # Ensure FP8 disabled. 
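# The `transformations` table above pairs TE parameter paths with the reference model's
# paths: layernorm scales that TE folds into its fused layers map back to standalone
# pre-*-layer-norm entries, and the fused 'mlp/wi_kernel'/'mlp/wo_kernel' map to the nested
# 'mlp/wi/kernel'/'mlp/wo/kernel'. The actual copying is done by sync_params_values from
# tests/jax/utils.py; the helper below is only a simplified sketch of that idea, assuming
# one-to-one, same-size entries, and is not the real implementation.
import jax.numpy as jnp
from flax import traverse_util

def remap_params_sketch(target_params, ref_params, transformations):
    """Copy reference values into the target tree along renamed '/'-joined paths."""
    flat_target = traverse_util.flatten_dict(target_params, sep='/')
    flat_ref = traverse_util.flatten_dict(ref_params, sep='/')
    for target_path, ref_path in transformations.items():
        if target_path in flat_target and ref_path in flat_ref:
            # Reshape because TE may store a flattened view of the same weights.
            flat_target[target_path] = jnp.reshape(flat_ref[ref_path],
                                                   flat_target[target_path].shape)
    return traverse_util.unflatten_dict(flat_target, sep='/')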
- self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04) - - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('fp8_format', FP8_FORMATS) - @pytest.mark.parametrize('attrs', ATTRS) - def test_forward_backward_with_fp8(self, data_shape, dtype, fp8_format, attrs): - FP8Helper.initialize(fp8_format=fp8_format) - self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=1e-03) - FP8Helper.finalize() - - -class TestDecoderLayer: - - @staticmethod - def sync_params(ref, target): - unfreeze_target = flax.core.unfreeze(target) - for scope in ['self_attention', 'encoder_decoder_attention']: - unfreeze_scope = unfreeze_target[scope] - ref_scope = ref[scope] - for key in unfreeze_scope.keys(): - unfreeze_scope[key]['kernel'] = \ - ref_scope[key]['kernel'].reshape(unfreeze_scope[key]['kernel'].shape) - unfreeze_target['mlp']['wi_kernel'] = \ - jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape) - unfreeze_target['mlp']['wo_kernel'] = \ - ref['mlp']['wo']['kernel'] - return ref, unfreeze_target - - def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): - transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS] + grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False) + + ref_out, (ref_dgrads, ref_wgrads) = grad_fn(inputs, ref_masks, ref_params, ref_others, + ref_layer) + test_out, (test_dgrads, test_wgrads) = grad_fn(inputs, test_masks, test_params, test_others, + test_layer) + + assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) + assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol) + + _, restructed_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads) + assert_tree_like_allclose(restructed_ref_wgrads, test_wgrads, rtol=rtol, atol=atol) + + +class EncoderRunner(BaseRunner): + """Encoder runner implementations""" + layer_type = TransformerLayerType.ENCODER + reference_layer = RefEncoderLayer + transformations = { + 'attention/qkv/scale': 'pre_attention_layer_norm/scale', + 'attention/qkv/ln_bias': 'pre_attention_layer_norm/ln_bias', + 'attention/query/scale': 'pre_attention_layer_norm/scale', + 'attention/query/ln_bias': 'pre_attention_layer_norm/ln_bias', + 'mlp/wi_kernel': 'mlp/wi/kernel', + 'mlp/wi_bias': 'mlp/wi/bias', + 'mlp/wo_kernel': 'mlp/wo/kernel', + 'mlp/wo_bias': 'mlp/wo/bias', + 'mlp/scale': 'pre_mlp_layer_norm/scale', + 'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias', + } + + def generate_inputs(self, data_shape, dtype): + """ + Return inputs, (ref_masks, test_masks) + """ + transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS] batch, seqlen = data_shape[:2] if transpose_batch_sequence: data_shape = (data_shape[1], data_shape[0], *data_shape[2:]) - sequence_dim = 0 if transpose_batch_sequence else 1 - data_rng, init_rng, apply_rng = generate_test_rngs() - inputs = (jax.random.normal(data_rng, data_shape, - dtype), jax.random.normal(data_rng, data_shape, dtype)) + data_rng = jax.random.PRNGKey(2024) + inputs = (jax.random.normal(data_rng, data_shape, dtype),) padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) - ref_masks = (1 - causal_mask, 1 - padded_mask) - test_masks = (causal_mask, padded_mask) - - te_layer_attrs = {} - for k, v in attrs.items(): - if k == 'dropout_rate': - 
te_layer_attrs['attention_dropout'] = v - te_layer_attrs['hidden_dropout'] = v - te_layer_attrs['intermediate_dropout'] = v - elif k == 'fuse_mlp_wi': - continue - else: - te_layer_attrs[k] = v - ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs) - layer_cls = partial(TransformerLayer, - hidden_dropout_dims=(sequence_dim,), - intermediate_dropout_dims=(sequence_dim,), - layer_type=TransformerLayerType.DECODER, - self_attn_mask_type='padding_causal', - dtype=dtype, - **te_layer_attrs) - ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs, - ref_masks) - test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs, - test_masks) - - ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params) - - ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng) - test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng) - - if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) - - del data_rng, init_rng, apply_rng - - def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): - transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS] + if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['casual', 'padding_causal']: + mask = causal_mask + else: + mask = padded_mask + + ref_masks = (1 - mask,) + test_masks = (None, mask) # The second arg of Transformer is encoded tokens. + + return inputs, (ref_masks, test_masks) + + +class DecoderRunner(BaseRunner): + """ + Decoder runner implementations + """ + layer_type = TransformerLayerType.DECODER + reference_layer = RefDecoderLayer + transformations = { + 'encoder_decoder_attention/qkv/scale': 'pre_cross_attention_layer_norm/scale', + 'encoder_decoder_attention/qkv/ln_bias': 'pre_cross_attention_layer_norm/ln_bias', + 'encoder_decoder_attention/query/scale': 'pre_cross_attention_layer_norm/scale', + 'encoder_decoder_attention/query/ln_bias': 'pre_cross_attention_layer_norm/ln_bias', + 'self_attention/qkv/scale': 'pre_self_attention_layer_norm/scale', + 'self_attention/qkv/ln_bias': 'pre_self_attention_layer_norm/ln_bias', + 'self_attention/query/scale': 'pre_self_attention_layer_norm/scale', + 'self_attention/query/ln_bias': 'pre_self_attention_layer_norm/ln_bias', + 'mlp/wi_kernel': 'mlp/wi/kernel', + 'mlp/wi_bias': 'mlp/wi/bias', + 'mlp/wo_kernel': 'mlp/wo/kernel', + 'mlp/wo_bias': 'mlp/wo/bias', + 'mlp/scale': 'pre_mlp_layer_norm/scale', + 'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias', + } + + def generate_inputs(self, data_shape, dtype): + """ + Return inputs, (ref_masks, test_masks) + """ + transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS] batch, seqlen = data_shape[:2] if transpose_batch_sequence: data_shape = (data_shape[1], data_shape[0], *data_shape[2:]) - sequence_dim = 0 if transpose_batch_sequence else 1 - data_rng, init_rng, apply_rng = generate_test_rngs() - inputs = (jax.random.normal(data_rng, data_shape, - dtype), jax.random.normal(data_rng, data_shape, dtype)) + data_rng = jax.random.PRNGKey(0) + data_rng_0, data_rng_1 = jax.random.split(data_rng, 2) + inputs = (jax.random.normal(data_rng_0, data_shape, + dtype), jax.random.normal(data_rng_1, data_shape, dtype)) padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) - ref_masks = (1 - causal_mask, 1 - padded_mask) - 
test_masks = (causal_mask, padded_mask) - - te_layer_attrs = {} - for k, v in attrs.items(): - if k == 'dropout_rate': - te_layer_attrs['attention_dropout'] = v - te_layer_attrs['hidden_dropout'] = v - te_layer_attrs['intermediate_dropout'] = v - elif k == 'fuse_mlp_wi': - continue - else: - te_layer_attrs[k] = v - ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs) - layer_cls = partial(TransformerLayer, - hidden_dropout_dims=(sequence_dim,), - intermediate_dropout_dims=(sequence_dim,), - layer_type=TransformerLayerType.DECODER, - self_attn_mask_type='padding_causal', - dtype=dtype, - **te_layer_attrs) - ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs, - ref_masks) - test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs, - test_masks) - - ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params) + if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['casual', 'padding_causal']: + self_mask = causal_mask + else: + self_mask = padded_mask - if FP8Helper.is_fp8_enabled(): - for _ in range(4): - _, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,), - has_aux=False)(inputs, test_masks, test_params, - test_others, test_layer, apply_rng) - _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME) - test_others = FP8Helper.update_collections( - {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others) - test_others = FP8Helper.update_fp8_metas(test_others) - del tmp_grad, fp8_meta_grad + ref_masks = (1 - self_mask, 1 - padded_mask) + test_masks = (self_mask, padded_mask) + + return inputs, (ref_masks, test_masks) + + +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('dtype', DTYPE) +@pytest.mark.parametrize('attrs', ATTRS) +class BaseTester(): + """ + Pytest interface to invoke the runner + """ + runner = BaseRunner - grad_fn = jax.value_and_grad(loss_fn, argnums=(0, 2), has_aux=False) - - ref_out, ref_grads = grad_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, - apply_rng) - test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer, - apply_rng) - - def reorganize_test_wgrad(test_wgrad, attrs): - num_heads = attrs.get(_KEY_OF_NUM_HEADS) - num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads) - fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \ - num_heads == num_gqa_groups - - unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad) - if "output_layernorm" not in attrs: - attn_name = 'self_attention' - unfreeze_test_wgrad['pre_self_attention_layer_norm'] = {} - pre_attn_layer_key = 'qkv' if fuse_qkv else 'query' - unfreeze_test_wgrad['pre_self_attention_layer_norm']['scale'] = \ - unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale'] - del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale'] - if 'ln_bias' in unfreeze_test_wgrad[attn_name][pre_attn_layer_key]: - unfreeze_test_wgrad['pre_self_attention_layer_norm']['ln_bias'] = \ - unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias'] - del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias'] - - for scope in ['self_attention', 'encoder_decoder_attention']: - for key in unfreeze_test_wgrad[scope].keys(): - unfreeze_test_wgrad[scope][key]['kernel'] = \ - jnp.reshape(unfreeze_test_wgrad[scope][key]['kernel'], - (unfreeze_test_wgrad[scope][key]['kernel'].shape[0], -1)) - - unfreeze_test_wgrad['pre_cross_attention_layer_norm'] = {} - unfreeze_test_wgrad['pre_cross_attention_layer_norm']['scale'] = \ - 
unfreeze_test_wgrad['encoder_decoder_attention']['query']['scale'] - del unfreeze_test_wgrad['encoder_decoder_attention']['query']['scale'] - if 'ln_bias' in unfreeze_test_wgrad['encoder_decoder_attention']['query']: - unfreeze_test_wgrad['pre_cross_attention_layer_norm']['ln_bias'] = \ - unfreeze_test_wgrad['encoder_decoder_attention']['query']['ln_bias'] - del unfreeze_test_wgrad['encoder_decoder_attention']['query']['ln_bias'] - unfreeze_test_wgrad['pre_mlp_layer_norm'] = {} - unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \ - unfreeze_test_wgrad['mlp']['scale'] - del unfreeze_test_wgrad['mlp']['scale'] - if 'ln_bias' in unfreeze_test_wgrad['mlp']: - unfreeze_test_wgrad['pre_mlp_layer_norm']['ln_bias'] = \ - unfreeze_test_wgrad['mlp']['ln_bias'] - del unfreeze_test_wgrad['mlp']['ln_bias'] - unfreeze_test_wgrad['mlp']['wi'] = {} - unfreeze_test_wgrad['mlp']['wi']['kernel'] = \ - jnp.reshape(unfreeze_test_wgrad['mlp']['wi_kernel'], - (unfreeze_test_wgrad['mlp']['wi_kernel'].shape[0], -1)) - del unfreeze_test_wgrad['mlp']['wi_kernel'] - unfreeze_test_wgrad['mlp']['wo'] = {} - unfreeze_test_wgrad['mlp']['wo']['kernel'] = \ - unfreeze_test_wgrad['mlp']['wo_kernel'] - del unfreeze_test_wgrad['mlp']['wo_kernel'] - return unfreeze_test_wgrad - - if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) - assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad - compare_dict(ref_grads[1], - reorganize_test_wgrad(test_grads[1], attrs), - rtol=rtol, - atol=atol) # wgrad - - del data_rng, init_rng, apply_rng - - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', ATTRS) def test_forward(self, data_shape, dtype, attrs): + """Test normal datatype forward""" FP8Helper.finalize() # Ensure FP8 disabled. - self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04) + self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-5, atol=7e-5) + + def test_backward(self, data_shape, dtype, attrs): + """Test normal datatype backward""" + FP8Helper.finalize() # Ensure FP8 disabled. + self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize('fp8_format', FP8_FORMATS) - @pytest.mark.parametrize('attrs', ATTRS) - def test_forward_with_fp8(self, data_shape, dtype, fp8_format, attrs): + def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format): + """Test forward with fp8 enabled""" FP8Helper.initialize(fp8_format=fp8_format) - self.forward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=3e-02) + self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) FP8Helper.finalize() - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) - @pytest.mark.parametrize('attrs', ATTRS) - def test_forward_backward(self, data_shape, dtype, attrs): - FP8Helper.finalize() # Ensure FP8 disabled. 
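# Under FP8 the runners first take a few throwaway gradient steps: each step yields a
# gradient for the FP8 meta collection (amax history), which update_collections and
# update_fp8_metas fold back into the non-trainable variables so the scaling factors have
# settled before outputs are compared with the reference. A sketch of that warm-up, written
# as a free function over the same pieces the runner passes around (the function name and
# argument order here are illustrative, not part of the test):
import jax
import flax
from transformer_engine.jax.fp8 import FP8Helper

def warm_up_fp8_metas(loss_fn, inputs, masks, params, others, layer, num_steps=4):
    for _ in range(num_steps):
        # argnums=(3,) differentiates w.r.t. `others`, which holds the FP8 meta collection.
        _, grads = jax.value_and_grad(loss_fn, argnums=(3,), has_aux=False)(
            inputs, masks, params, others, layer)
        _, fp8_meta_grad = flax.core.pop(grads[0], FP8Helper.FP8_COLLECTION_NAME)
        others = FP8Helper.update_collections(
            {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, others)
        others = FP8Helper.update_fp8_metas(others)
    return others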
- self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=3e-04) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('data_shape', DATA_SHAPE) - @pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize('fp8_format', FP8_FORMATS) - @pytest.mark.parametrize('attrs', ATTRS) - def test_forward_backward_with_fp8(self, data_shape, dtype, fp8_format, attrs): + def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format): + """Test backward with fp8 enabled""" FP8Helper.initialize(fp8_format=fp8_format) - self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=3e-02) + self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) FP8Helper.finalize() + + +class TestEncoderLayer(BaseTester): + """ + Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder) + """ + runner = EncoderRunner + + +class TestDecoderLayer(BaseTester): + """ + Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder) + """ + runner = DecoderRunner diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index 1bc32d1251..df2c0d582b 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -56,16 +56,6 @@ def enable_fused_attn(): del os.environ["NVTE_FUSED_ATTN"] -@pytest.fixture(autouse=True, scope='function') -def clear_live_arrays(): - """ - Clear all live arrays to keep the resource clean - """ - yield - for arr in jax.live_arrays(): - arr.delete() - - def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): for key in ref_fd: assert key in test_fd, \ diff --git a/tests/jax/test_softmax.py b/tests/jax/test_softmax.py new file mode 100644 index 0000000000..bb5eecd654 --- /dev/null +++ b/tests/jax/test_softmax.py @@ -0,0 +1,165 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Tests for the softmax primitives""" +from contextlib import nullcontext +from dataclasses import dataclass +from functools import wraps + +import jax +import jax.numpy as jnp +import pytest +from jax import lax +from jax import nn +from jax import value_and_grad, jit +from jax.typing import DTypeLike + +from utils import assert_allclose + +from transformer_engine.jax.softmax import is_softmax_kernel_available +from transformer_engine.jax.softmax import SoftmaxType, softmax + + +def catch_unsupported(method): + """ + The unsupported case should raise error instead of running it incorrectly. + This helper function is to check if the unsupported case raises the assertion error. 
+ """ + + @wraps(method) + def wrapper(self, *args, **kwargs): + if not self._is_support(): + assertion_checker = pytest.raises(AssertionError) + else: + assertion_checker = nullcontext() + with assertion_checker: + return method(self, *args, **kwargs) + + return wrapper + + +@dataclass +class SoftmaxRunner: + """ + Softmax runner + """ + batch_size: int + max_seqlen_q: int + max_seqlen_kv: int + num_heads: int + scale_factor: float + softmax_type: SoftmaxType + dtype: DTypeLike + + @staticmethod + def reference_softmax(logits, mask, scale_factor, **_): + """ + Jax softmax as the reference + """ + if mask is not None: + logits += lax.select(mask > 0, + jnp.full(mask.shape, -1e10).astype(logits.dtype), + jnp.full(mask.shape, 0.).astype(logits.dtype)) + return nn.softmax(logits * scale_factor) + + def _is_support(self): + return is_softmax_kernel_available(self.softmax_type, self.batch_size, self.num_heads, + self.max_seqlen_q, self.max_seqlen_kv, self.dtype) + + def _setup_inputs(self): + key = jax.random.PRNGKey(0) + logits_key, mask_key = jax.random.split(key, 2) + + logits_shape = (self.batch_size, self.num_heads, self.max_seqlen_q, self.max_seqlen_kv) + mask_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv) + + self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.) + + match self.softmax_type: + case SoftmaxType.SCALED: + self.mask = None + case SoftmaxType.SCALED_MASKED: + self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8) + case SoftmaxType.SCALED_UPPER_TRIANG_MASKED: + self.mask = (1. - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8) + case _: + raise ValueError(f"Unknown {self.softmax_type=}") + + @catch_unsupported + def test_forward(self): + """ + Test transformer_engine.jax.softmax.softmax fwd rule + """ + self._setup_inputs() + primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_type) + reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor) + assert_allclose(primitive_out, reference_out, dtype=self.dtype) + + @catch_unsupported + def test_backward(self): + """ + Test transformer_engine.jax.softmax.softmax bwd rule + """ + self._setup_inputs() + + def grad_func(func, *args, **kwargs): + fwd_out = func(*args, **kwargs) + return jnp.mean(fwd_out, dtype=jnp.float32).astype(self.dtype) + + args = [self.logits, self.mask] + kwargs = { + 'scale_factor': self.scale_factor, + 'softmax_type': self.softmax_type, + } + + # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation + jitted_primitive = jit( + value_and_grad(lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), + (0,))) + jitted_reference = jit( + value_and_grad( + lambda logits, *args: grad_func(__class__.reference_softmax, self.logits, *args, ** + kwargs), (0,))) + + primitive_out, (primitive_grad_logits,) = jitted_primitive(*args) + reference_out, (reference_grad_logits,) = jitted_reference(*args) + + assert_allclose(primitive_out, reference_out, dtype=self.dtype) + assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype) + + +@pytest.mark.parametrize('b, s_q, s_kv, h', [ + pytest.param(8, 16, 16, 16, id='8-16-16-16'), + pytest.param(8, 512, 512, 16, id='8-512-512-16'), + pytest.param(2, 8, 16384, 8, id='2-8-16384-8') +]) +@pytest.mark.parametrize('scale_factor', [0.125]) +@pytest.mark.parametrize('softmax_type', [ + pytest.param(SoftmaxType.SCALED, id='SCALED'), + pytest.param(SoftmaxType.SCALED_MASKED, id='SCALED_MASKED'), + 
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id='SCALED_UPPER_TRIANG_MASKED') +]) +@pytest.mark.parametrize('dtype', [ + pytest.param(jnp.bfloat16, id="BF16"), + pytest.param(jnp.float16, id="FP16"), +]) +class TestSoftmax: + """ + Test transformer_engine.jax.softmax.softmax + """ + + @staticmethod + def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype): + """ + Test forward with parameterized configs + """ + runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) + runner.test_forward() + + @staticmethod + def test_backward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype): + """ + Test forward with parameterized configs + """ + runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype) + runner.test_backward() diff --git a/tests/jax/utils.py b/tests/jax/utils.py index c8e1b1b183..12b462fb8a 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -13,6 +13,7 @@ import numpy as np from flax import linen as nn from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import combine_masks from jax import lax, vmap from jax import nn as jax_nn from jax import random as jax_random @@ -64,27 +65,6 @@ def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Calla raise ValueError(f"don't know how to convert {fn_or_string} to an activation function") -def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32): - """Combine attention masks. - - Args: - *masks: set of attention mask arguments to combine, some can be None. - dtype: final mask dtype - - Returns: - Combined mask, reduced by logical and, returns None if no masks given. - """ - masks = [m for m in masks if m is not None] - if not masks: - return None - assert all(map(lambda x: x.ndim == masks[0].ndim, - masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') - mask, *other_masks = masks - for other_mask in other_masks: - mask = jnp.logical_and(mask, other_mask) - return mask.astype(dtype) - - def combine_biases(*masks: Optional[Array]): """Combine attention biases. @@ -105,96 +85,109 @@ def combine_biases(*masks: Optional[Array]): return mask -def dot_product_attention(query: Array, - key: Array, - value: Array, - transpose_batch_sequence: bool, - bias: Optional[Array] = None, - dropout_rng: Optional[PRNGKey] = None, - dropout_rate: float = 0., - deterministic: bool = False, - dtype: DType = jnp.float32, - float32_logits: bool = False): +class DotProductAttention(nn.Module): + transpose_batch_sequence: bool = True + scale_attn_logits: bool = True + dropout_rate: float = 0. + dtype: DType = jnp.float32 + float32_logits: bool = False """Computes dot-product attention given query, key, and value. - This is the core function for applying attention based on - https://arxiv.org/abs/1706.03762. It calculates the attention weights given - query and key and combines the values using the attention weights. + This is the core function for applying attention based on + https://arxiv.org/abs/1706.03762. It calculates the attention weights given + query and key and combines the values using the attention weights. - Args: - query: queries for calculating attention with shape of `[batch, q_length, - num_heads, qk_depth_per_head]`. - key: keys for calculating attention with shape of `[batch, kv_length, - num_gqa_groups, qk_depth_per_head]`. - value: values to be used in attention with shape of `[batch, kv_length, - num_gqa_groups, v_depth_per_head]`. - bias: bias for the attention weights. 
This should be broadcastable to the - shape `[batch, num_heads, q_length, kv_length]` This can be used for - incorporating causal masks, padding masks, proximity bias, etc. - dropout_rng: JAX PRNGKey: to be used for dropout - dropout_rate: dropout rate - deterministic: bool, deterministic or not (to apply dropout) - dtype: the dtype of the computation (default: float32) - float32_logits: bool, if True then compute logits in float32 to avoid - numerical issues with bfloat16. + Args: + dropout_rate: dropout rate + dtype: the dtype of the computation (default: float32) + float32_logits: bool, if True then compute logits in float32 to avoid + numerical issues with bfloat16. + """ - Returns: - Output of shape `[batch, length, num_heads, v_depth_per_head]`. - """ - assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' - batch_dim = 1 if transpose_batch_sequence else 0 - assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], ( - 'q, k, v batch dims must match.') - sequence_dim = 0 if transpose_batch_sequence else 1 - assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.' - assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.' - assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.' - - # Casting logits and softmax computation for float32 for model stability. - if float32_logits: - query = query.astype(jnp.float32) - key = key.astype(jnp.float32) - - # `attn_weights`: [batch, num_heads, groups, q_length, kv_length] - h_q, h_kv = query.shape[-2], key.shape[-2] - assert (h_q % h_kv == 0) and (h_q >= h_kv) - group_size = h_q // h_kv - grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) - - if transpose_batch_sequence: - attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key) - else: - attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key) - - # reshape back to normal DPA shape for bias/softmax/dropout - b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape - attn_weights_without_groups_shape = (b, h * g, q, k) - attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) - - # Apply attention bias: masking, dropout, proximity bias, etc. - if bias is not None: - attn_weights = attn_weights + bias.astype(attn_weights.dtype) - - # Normalize the attention weights across `kv_length` dimension. - attn_weights = jax_nn.softmax(attn_weights).astype(dtype) - - # Apply attention dropout. - if not deterministic and dropout_rate > 0.: - keep_prob = 1.0 - dropout_rate - # T5 broadcasts along the "length" dim, but unclear which one that - # corresponds to in positional dimensions here, assuming query dim. - dropout_shape = list(attn_weights.shape) - keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) - multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)) - attn_weights = attn_weights * multiplier - - attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) - - # Take the linear combination of `value`. - if transpose_batch_sequence: - return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape) - - return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape) + @nn.compact + def __call__(self, + query: Array, + key: Array, + value: Array, + bias: Optional[Array] = None, + deterministic: bool = False): + """ + Args: + query: queries for calculating attention with shape of `[batch, q_length, + num_heads, qk_depth_per_head]`. 
+ key: keys for calculating attention with shape of `[batch, kv_length, + num_gqa_groups, qk_depth_per_head]`. + value: values to be used in attention with shape of `[batch, kv_length, + num_gqa_groups, v_depth_per_head]`. + bias: bias for the attention weights. This should be broadcastable to the + shape `[batch, num_heads, q_length, kv_length]` This can be used for + incorporating causal masks, padding masks, proximity bias, etc. + dropout_rng: JAX PRNGKey: to be used for dropout + deterministic: bool, deterministic or not (to apply dropout) + Returns: + Output of shape `[batch, length, num_heads, v_depth_per_head]`. + """ + assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' + batch_dim = 1 if self.transpose_batch_sequence else 0 + assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], ( + 'q, k, v batch dims must match.') + sequence_dim = 0 if self.transpose_batch_sequence else 1 + assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.' + assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.' + assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.' + + if self.scale_attn_logits: + head_dim = query.shape[-1] + depth_scaling = jnp.sqrt(head_dim).astype(self.dtype) + query = query / depth_scaling + + # Casting logits and softmax computation for float32 for model stability. + if self.float32_logits: + query = query.astype(jnp.float32) + key = key.astype(jnp.float32) + + # `attn_weights`: [batch, num_heads, groups, q_length, kv_length] + h_q, h_kv = query.shape[-2], key.shape[-2] + assert (h_q % h_kv == 0) and (h_q >= h_kv) + group_size = h_q // h_kv + grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) + + if self.transpose_batch_sequence: + attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key) + else: + attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key) + + # reshape back to normal DPA shape for bias/softmax/dropout + b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape + attn_weights_without_groups_shape = (b, h * g, q, k) + attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) + + # Apply attention bias: masking, dropout, proximity bias, etc. + if bias is not None: + attn_weights = attn_weights + bias.astype(attn_weights.dtype) + + # Normalize the attention weights across `kv_length` dimension. + attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype) + + # Apply attention dropout. + if not deterministic and self.dropout_rate > 0.: + keep_prob = 1.0 - self.dropout_rate + # T5 broadcasts along the "length" dim, but unclear which one that + # corresponds to in positional dimensions here, assuming query dim. + dropout_shape = list(attn_weights.shape) + dropout_rng = self.make_rng('dropout') + keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) + multiplier = (keep.astype(attn_weights.dtype) / + jnp.asarray(keep_prob, dtype=self.dtype)) + attn_weights = attn_weights * multiplier + + attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) + + # Take the linear combination of `value`. 
+ if self.transpose_batch_sequence: + return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape) + + return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape) class DenseGeneral(nn.Module): @@ -253,8 +246,9 @@ def __call__(self, inputs: Array) -> Array: bias = nn_partitioning.param_with_axes('bias', self.bias_init, self.features, - self.dtype, + jnp.float32, axes=self.bias_axes) + bias = bias.astype(self.dtype) else: bias = None @@ -284,8 +278,10 @@ class MlpBlock(nn.Module): activations: Sequence[Union[str, Callable]] = ('relu',) kernel_init: Initializer = None intermediate_dropout_rate: float = 0.1 + intermediate_dropout_dims: Sequence[int] = () + use_bias: bool = False dtype: Any = jnp.float32 - fuse_wi: bool = False + fuse_wi: bool = True def __post_init__(self): if self.kernel_init is None: @@ -306,6 +302,8 @@ def __call__(self, inputs, deterministic: bool = False): dtype=self.dtype, kernel_init=self.kernel_init, kernel_axes=('embed', 'mlp'), + use_bias=self.use_bias, + bias_axes=('mlp'), name=dense_name)(inputs) x = jnp.split(x, num_activations, axis=-1) for idx, act_fn in enumerate(self.activations): @@ -318,16 +316,18 @@ def __call__(self, inputs, deterministic: bool = False): dtype=self.dtype, kernel_init=self.kernel_init, kernel_axes=('embed', 'mlp'), + use_bias=self.use_bias, + bias_axes=('mlp'), name=dense_name)(inputs) x = _convert_to_activation_function(act_fn)(x) activations.append(x) # Take elementwise product of above intermediate activations. x = functools.reduce(operator.mul, activations) - dropout_broadcast_dims = (0,) if self.transpose_batch_sequence else (1,) # Apply dropout and final dense output projection. - x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=dropout_broadcast_dims)( - x, deterministic=deterministic) # Broadcast along length. + x = nn.Dropout(rate=self.intermediate_dropout_rate, + broadcast_dims=self.intermediate_dropout_dims)( + x, deterministic=deterministic) # Broadcast along length. 
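For reference, the `broadcast_dims` argument that the new `intermediate_dropout_dims` field feeds into makes Flax sample one dropout mask and reuse it along the listed axes. A minimal, self-contained sketch of that behaviour (the shapes and rate below are illustrative assumptions, not taken from this patch):

# Illustrative sketch of flax.linen.Dropout broadcast_dims (assumed example values).
import jax
import jax.numpy as jnp
from flax import linen as nn

x = jnp.ones((4, 8))
rngs = {'dropout': jax.random.PRNGKey(0)}

# Independent mask per element vs. one mask broadcast along axis 0.
y_indep = nn.Dropout(rate=0.5).apply({}, x, deterministic=False, rngs=rngs)
y_bcast = nn.Dropout(rate=0.5, broadcast_dims=(0,)).apply({}, x, deterministic=False, rngs=rngs)

# With broadcast_dims=(0,), every row shares the same keep/drop pattern, so each
# column is either fully zeroed or fully kept (and rescaled by 1/keep_prob).
cols_all_dropped = (y_bcast == 0.0).all(axis=0)
cols_all_kept = (y_bcast != 0.0).all(axis=0)
assert bool(jnp.all(cols_all_dropped | cols_all_kept))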
if self.transpose_batch_sequence: x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'mlp')) else: @@ -336,6 +336,8 @@ def __call__(self, inputs, deterministic: bool = False): dtype=self.dtype, kernel_init=self.kernel_init, kernel_axes=('mlp', 'embed'), + use_bias=self.use_bias, + bias_axes=('embed'), name='wo')(x) return output @@ -369,7 +371,6 @@ def apply_rotary_pos_emb_consecutive( min_timescale: int = 1, max_timescale: int = 10000, ): - embedding_dim = inputs.shape[-1] half_embedding_dim = embedding_dim // 2 fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim @@ -429,6 +430,7 @@ class MultiHeadAttention(nn.Module): enable_rotary_pos_emb: bool = False rotary_pos_emb_group_method: str = 'consecutive' fuse_qkv: bool = True + use_bias: bool = False def __post_init__(self): if self.kernel_init is None: @@ -478,12 +480,16 @@ def __call__(self, axis=-1, features=self.num_heads * self.head_dim, kernel_axes=('embed', 'joined_kv'), + use_bias=self.use_bias, + bias_axes=('joined_kv'), dtype=self.dtype) kv_projection = functools.partial(DenseGeneral, axis=-1, features=self.num_gqa_groups * self.head_dim, kernel_axes=('embed', 'joined_kv'), + use_bias=self.use_bias, + bias_axes=('joined_kv'), dtype=self.dtype) # NOTE: T5 does not explicitly rescale the attention logits by @@ -519,26 +525,27 @@ def qkv_init(key, shape, dtype): features=self.num_heads * self.head_dim * 3, kernel_axes=('embed', 'joined_kv'), kernel_init=qkv_init, + use_bias=self.use_bias, + bias_axes=('joined_kv'), name='qkv', dtype=self.dtype)(inputs_kv) query, key, value = jnp.split( qkv_proj, [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2], axis=-1) - if self.scale_attn_logits: - query = query / depth_scaling else: - query = q_projection(kernel_init=query_init, name='query')( \ - (inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q) + query = q_projection(kernel_init=query_init, name='query')(inputs_q) + kv_proj = DenseGeneral(axis=-1, features=self.num_gqa_groups * self.head_dim * 2, kernel_axes=('embed', 'joined_kv'), kernel_init=self.kernel_init, + use_bias=self.use_bias, + bias_axes=('joined_kv'), name='kv', dtype=self.dtype)(inputs_kv) key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1) else: - query = q_projection(kernel_init=query_init, name='query')( \ - (inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q) + query = q_projection(kernel_init=query_init, name='query')(inputs_q) key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv) value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv) @@ -546,15 +553,18 @@ def qkv_init(key, shape, dtype): batch_dim = 1 if self.transpose_batch_sequence else 0 seq_dim = 1 - batch_dim - position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim) + q_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim) + k_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim) if self.rotary_pos_emb_group_method == 'alternate': apply_rotary_pos_emb = apply_rotary_pos_emb_alternate else: apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive - query = apply_rotary_pos_emb(query, position) - key = apply_rotary_pos_emb(key, position) + query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim)) + key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) + query = apply_rotary_pos_emb(query, q_position) + key = apply_rotary_pos_emb(key, k_position) query = 
query.reshape((*query.shape[:2], self.num_heads, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) @@ -656,21 +666,16 @@ def qkv_init(key, shape, dtype): if bias is not None: attention_bias = combine_biases(attention_bias, bias) - dropout_rng = None - if not deterministic and self.dropout_rate > 0.: - dropout_rng = self.make_rng('dropout') - # Apply attention. - x = dot_product_attention(query, - key, - value, - transpose_batch_sequence=self.transpose_batch_sequence, - bias=attention_bias, - dropout_rng=dropout_rng, - dropout_rate=self.dropout_rate, - deterministic=deterministic, - dtype=self.dtype, - float32_logits=self.float32_logits) + x = DotProductAttention(transpose_batch_sequence=self.transpose_batch_sequence, + scale_attn_logits=self.scale_attn_logits, + dropout_rate=self.dropout_rate, + dtype=self.dtype, + float32_logits=self.float32_logits)(query, + key, + value, + bias=attention_bias, + deterministic=deterministic) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) @@ -685,6 +690,8 @@ def qkv_init(key, shape, dtype): axis=-1, kernel_init=self.kernel_init, kernel_axes=('joined_kv', 'embed'), + use_bias=self.use_bias, + bias_axes=('embed'), dtype=self.dtype, name='out')(x) return out @@ -858,27 +865,36 @@ def __call__(self, qlen, klen, bidirectional=True): class EncoderLayer(nn.Module): """Transformer encoder layer.""" + enable_relative_embedding: bool = True relative_embedding: nn.Module = None num_attention_heads: int = 8 num_gqa_groups: int | None = None head_dim: int = 64 - dropout_rate: float = 0.1 + hidden_dropout: float = 0.1 + hidden_dropout_dims: Sequence[int] = () + attention_dropout: float = 0.1 + intermediate_dropout: float = 0.1 + intermediate_dropout_dims: Sequence[int] = () transpose_batch_sequence: bool = True float32_attention_logits: bool = False scale_attn_logits: bool = False scaled_query_init: bool = True mlp_dim: int = 2048 mlp_activations: Sequence[str] = ('relu',) + use_bias: bool = False dtype: Any = jnp.float32 apply_residual_connection_post_layernorm: bool = False layernorm_type: str = 'layernorm' + layernorm_epsilon: float = 1e-6 zero_centered_gamma: bool = False output_layernorm: bool = False drop_path: float = 0.0 enable_rotary_pos_emb: bool = False rotary_pos_emb_group_method: str = 'consecutive' fuse_qkv_params: bool = True - fuse_mlp_wi: bool = False + fuse_mlp_wi: bool = True + self_attn_bias_type: Any = None + self_attn_mask_type: Any = None def __post_init__(self): if self.num_gqa_groups is None: @@ -887,21 +903,25 @@ def __post_init__(self): @nn.compact def __call__(self, inputs, encoder_mask=None, deterministic=False): + del self.self_attn_mask_type # dummy, just align to TE's impl # Relative position embedding as attention biases. 
sequence_dim = 0 if self.transpose_batch_sequence else 1 batch_dim = 1 - sequence_dim - if self.relative_embedding is None: - rel_emb = RelativePositionBiases(num_buckets=32, - max_distance=128, - num_heads=self.num_attention_heads, - dtype=self.dtype, - embedding_init=nn.initializers.variance_scaling( - 1.0, 'fan_avg', 'uniform'), - name='relpos_bias') + if self.enable_relative_embedding: + if self.relative_embedding is None: + rel_emb = RelativePositionBiases(num_buckets=32, + max_distance=128, + num_heads=self.num_attention_heads, + dtype=self.dtype, + embedding_init=nn.initializers.variance_scaling( + 1.0, 'fan_avg', 'uniform'), + name='relpos_bias') + else: + rel_emb = self.relative_embedding + encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True) else: - rel_emb = self.relative_embedding - encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True) + encoder_bias = None # Attention block. residual = inputs @@ -909,6 +929,7 @@ def __call__(self, inputs, encoder_mask=None, deterministic=False): if not self.output_layernorm: # Attention block. x = LayerNorm(layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, zero_centered_gamma=self.zero_centered_gamma, dtype=self.dtype, name="pre_attention_layer_norm")(inputs) @@ -924,20 +945,21 @@ def __call__(self, inputs, encoder_mask=None, deterministic=False): dtype=self.dtype, head_dim=self.head_dim, transpose_batch_sequence=self.transpose_batch_sequence, - dropout_rate=self.dropout_rate, + dropout_rate=self.attention_dropout, float32_logits=self.float32_attention_logits, scale_attn_logits=self.scale_attn_logits, scaled_query_init=self.scaled_query_init, fuse_qkv=self.fuse_qkv_params, enable_rotary_pos_emb=self.enable_rotary_pos_emb, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, + use_bias=self.use_bias, name='attention')(x, x, encoder_mask, encoder_bias, deterministic=deterministic) - x = nn.Dropout(rate=self.dropout_rate, - broadcast_dims=(sequence_dim,))(x, deterministic=deterministic) + x = nn.Dropout(rate=self.hidden_dropout, + broadcast_dims=self.hidden_dropout_dims)(x, deterministic=deterministic) if self.drop_path > 0.0: drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim) x = nn.Dropout(rate=self.drop_path, @@ -947,6 +969,7 @@ def __call__(self, inputs, encoder_mask=None, deterministic=False): # MLP block. 
residual = x y = LayerNorm(layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, zero_centered_gamma=self.zero_centered_gamma, dtype=self.dtype, name='pre_mlp_layer_norm')(x) @@ -959,13 +982,15 @@ def __call__(self, inputs, encoder_mask=None, deterministic=False): transpose_batch_sequence=self.transpose_batch_sequence, intermediate_dim=self.mlp_dim, activations=self.mlp_activations, - intermediate_dropout_rate=self.dropout_rate, + intermediate_dropout_rate=self.intermediate_dropout, + intermediate_dropout_dims=self.intermediate_dropout_dims, + use_bias=self.use_bias, dtype=self.dtype, fuse_wi=self.fuse_mlp_wi, name='mlp', )(y, deterministic=deterministic) - y = nn.Dropout(rate=self.dropout_rate, - broadcast_dims=(sequence_dim,))(y, deterministic=deterministic) + y = nn.Dropout(rate=self.hidden_dropout, + broadcast_dims=self.hidden_dropout_dims)(y, deterministic=deterministic) if self.drop_path > 0.0: drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim) y = nn.Dropout(rate=self.drop_path, @@ -974,6 +999,7 @@ def __call__(self, inputs, encoder_mask=None, deterministic=False): if self.output_layernorm: y = LayerNorm(layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, zero_centered_gamma=self.zero_centered_gamma, dtype=self.dtype, name="output_layernorm")(y) @@ -982,27 +1008,36 @@ def __call__(self, inputs, encoder_mask=None, deterministic=False): class DecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" + enable_relative_embedding: bool = True relative_embedding: nn.Module = None num_attention_heads: int = 8 num_gqa_groups: int | None = None head_dim: int = 64 - dropout_rate: float = 0.1 + hidden_dropout: float = 0.1 + hidden_dropout_dims: Sequence[int] = () + attention_dropout: float = 0.1 + intermediate_dropout: float = 0.1 + intermediate_dropout_dims: Sequence[int] = () transpose_batch_sequence: bool = True float32_attention_logits: bool = False scale_attn_logits: bool = False scaled_query_init: bool = True mlp_dim: int = 2048 mlp_activations: Sequence[str] = ('relu',) + use_bias: bool = False dtype: Any = jnp.float32 apply_residual_connection_post_layernorm: bool = False output_layernorm: bool = False layernorm_type: str = 'layernorm' + layernorm_epsilon: float = 1e-6 zero_centered_gamma: bool = False drop_path: float = 0.0 enable_rotary_pos_emb: bool = False rotary_pos_emb_group_method: str = 'consecutive' fuse_qkv_params: bool = True - fuse_mlp_wi: bool = False + fuse_mlp_wi: bool = True + self_attn_bias_type: Any = None + self_attn_mask_type: Any = None def __post_init__(self): if self.num_gqa_groups is None: @@ -1018,22 +1053,26 @@ def __call__(self, deterministic=False, decode=False, max_decode_length=None): - + del self.self_attn_mask_type # dummy, just align to TE's impl # Relative position embedding as attention biases. 
sequence_dim = 0 if self.transpose_batch_sequence else 1 batch_dim = 1 - sequence_dim - l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim] - if self.relative_embedding is None: - rel_emb = RelativePositionBiases(num_buckets=32, - max_distance=128, - num_heads=self.num_attention_heads, - dtype=self.dtype, - embedding_init=nn.initializers.variance_scaling( - 1.0, 'fan_avg', 'uniform'), - name='relpos_bias') + + if self.enable_relative_embedding: + l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim] + if self.relative_embedding is None: + rel_emb = RelativePositionBiases(num_buckets=32, + max_distance=128, + num_heads=self.num_attention_heads, + dtype=self.dtype, + embedding_init=nn.initializers.variance_scaling( + 1.0, 'fan_avg', 'uniform'), + name='relpos_bias') + else: + rel_emb = self.relative_embedding + decoder_bias = rel_emb(l, l, False) else: - rel_emb = self.relative_embedding - decoder_bias = rel_emb(l, l, False) + decoder_bias = None # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] residual = inputs @@ -1041,6 +1080,7 @@ def __call__(self, if not self.output_layernorm: # Attention block. x = LayerNorm(layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, zero_centered_gamma=self.zero_centered_gamma, dtype=self.dtype, name="pre_self_attention_layer_norm")(inputs) @@ -1056,21 +1096,22 @@ def __call__(self, dtype=self.dtype, head_dim=self.head_dim, transpose_batch_sequence=self.transpose_batch_sequence, - dropout_rate=self.dropout_rate, + dropout_rate=self.attention_dropout, float32_logits=self.float32_attention_logits, scale_attn_logits=self.scale_attn_logits, scaled_query_init=self.scaled_query_init, enable_rotary_pos_emb=self.enable_rotary_pos_emb, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, fuse_qkv=self.fuse_qkv_params, + use_bias=self.use_bias, name='self_attention')(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode) - x = nn.Dropout(rate=self.dropout_rate, - broadcast_dims=(sequence_dim,))(x, deterministic=deterministic) + x = nn.Dropout(rate=self.hidden_dropout, + broadcast_dims=self.hidden_dropout_dims)(x, deterministic=deterministic) if self.drop_path > 0.0: drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim) x = nn.Dropout(rate=self.drop_path, @@ -1080,6 +1121,7 @@ def __call__(self, # Encoder-Decoder block. residual = x y = LayerNorm(layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, zero_centered_gamma=self.zero_centered_gamma, dtype=self.dtype, name='pre_cross_attention_layer_norm')(x) @@ -1091,24 +1133,26 @@ def __call__(self, dtype=self.dtype, head_dim=self.head_dim, transpose_batch_sequence=self.transpose_batch_sequence, - dropout_rate=self.dropout_rate, + dropout_rate=self.attention_dropout, float32_logits=self.float32_attention_logits, scale_attn_logits=self.scale_attn_logits, scaled_query_init=self.scaled_query_init, enable_rotary_pos_emb=self.enable_rotary_pos_emb, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, fuse_qkv=self.fuse_qkv_params, + use_bias=self.use_bias, name='encoder_decoder_attention')(y, encoded, encoder_decoder_mask, deterministic=deterministic) - y = nn.Dropout(rate=self.dropout_rate, - broadcast_dims=(sequence_dim,))(y, deterministic=deterministic) + y = nn.Dropout(rate=self.hidden_dropout, + broadcast_dims=self.hidden_dropout_dims)(y, deterministic=deterministic) y = y + residual # MLP block. 
residual = y z = LayerNorm(layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, zero_centered_gamma=self.zero_centered_gamma, dtype=self.dtype, name='pre_mlp_layer_norm')(y) @@ -1118,13 +1162,15 @@ def __call__(self, transpose_batch_sequence=self.transpose_batch_sequence, intermediate_dim=self.mlp_dim, activations=self.mlp_activations, - intermediate_dropout_rate=self.dropout_rate, + intermediate_dropout_rate=self.intermediate_dropout, + intermediate_dropout_dims=self.intermediate_dropout_dims, + use_bias=self.use_bias, dtype=self.dtype, fuse_wi=self.fuse_mlp_wi, name='mlp', )(z, deterministic=deterministic) - z = nn.Dropout(rate=self.dropout_rate, - broadcast_dims=(sequence_dim,))(z, deterministic=deterministic) + z = nn.Dropout(rate=self.hidden_dropout, + broadcast_dims=self.hidden_dropout_dims)(z, deterministic=deterministic) if self.drop_path > 0.0: drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim) z = nn.Dropout(rate=self.drop_path, @@ -1133,6 +1179,7 @@ def __call__(self, if self.output_layernorm: z = LayerNorm(layernorm_type=self.layernorm_type, + epsilon=self.layernorm_epsilon, zero_centered_gamma=self.zero_centered_gamma, dtype=self.dtype, name="output_layernorm")(z) @@ -1210,6 +1257,21 @@ def assert_allclose( np.testing.assert_allclose(actual, desired, **tols, **kwargs) +def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08): + flatten_expected, _ = jax.tree_util.tree_flatten_with_path(expected) + flatten_actual, _ = jax.tree_util.tree_flatten_with_path(actual) + + for (expected_path, expected_value), (actual_path, + actual_value) in zip(flatten_expected, flatten_actual): + assert expected_path == actual_path + key_str = jax.tree_util.keystr(expected_path) + assert_allclose(expected_value, + actual_value, + rtol=rtol, + atol=atol, + err_msg=f'Value of expected{key_str} and actual{key_str} is not close') + + def dtype_tols( dtype: Union[DType, TEDType, np.dtype], reference_value: float = 1.0, @@ -1259,3 +1321,36 @@ def dtype_tols( rtol=eps_relaxed, atol=max(ulp, eps_relaxed), ) + + +def sync_params_values(dst, src, transformations, sep='/'): + """ + This function will reconstuct a tree with dst's tree_def/shape and src's value. + transformations is a map that records the key mappings between dst and src. + If no dst key found in the transformerations, it will fall back to src key = dst key. + transformations = { + dst key map 0: src key map 0, + dst key map 1: src key map 1, + ... 
+ # if dst key = src key, we don't need to add it + } + """ + src_values = {} + for key, value in jax.tree_util.tree_leaves_with_path(src): + normalized_key = sep.join(x.key for x in key) + src_values[normalized_key] = value + + flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst) + synced_dst_values = [] + + for key, value in flatten_dst: + normalized_key = sep.join(x.key for x in key) + if normalized_key in transformations: + corresponding_src_key = transformations[normalized_key] + else: + corresponding_src_key = normalized_key + synced_dst_values.append(src_values[corresponding_src_key]) + + synced_dst = jax.tree_util.tree_unflatten(dst_tree_def, synced_dst_values) + + return jax.tree_util.tree_map(lambda x, y: x.reshape(y.shape), synced_dst, dst) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index 87c5e5fe29..00e3d81481 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -1069,7 +1069,7 @@ class SoftmaxPrimitive(BasePrimitive): """ Softmax Primitive """ - max_k_seqlen_supported = 4096 + max_k_seqlen_supported = 16384 name = "te_softmax_internal_placeholder" @staticmethod @@ -1324,8 +1324,7 @@ def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype = dtypes.canonicalize_dtype(dtype) if (dtype in [jnp.float16, jnp.bfloat16] - and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported - # k_seqlen must be 16 ~ 4096 + and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 and attn_batches % 4 == 0 # batch * heads must be divisor of 4 ): @@ -1483,8 +1482,7 @@ def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype = dtypes.canonicalize_dtype(dtype) if (dtype in [jnp.float16, jnp.bfloat16] - and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported - # k_seqlen must be 16 ~ 4096 + and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 and attn_batches % 4 == 0 # batch * heads must be divisor of 4 ): @@ -1695,11 +1693,10 @@ def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype = dtypes.canonicalize_dtype(dtype) if (dtype in [jnp.float16, jnp.bfloat16] - and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported - # k_seqlen must be 16 ~ 4096 + and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported and q_seqlen % 4 == 0 # q_seqlen must be divisor of 4 and attn_batches % 4 == 0 # batch * heads must be divisor of 4 - ): + and k_seqlen == q_seqlen): if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported: batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen) return attn_batches % batch_per_block == 0 diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index b95689f6b0..66cf91c3de 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1035,21 +1035,25 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): if use_fused_layernorm_mlp: assert self.axis == -1 # Only support axis = =-1 at this moment - bias_1_shape = intermediate_dim if self.use_bias else 0 - bias_1 = nn_partitioning.param_with_axes('wi_bias', - self.bias_init, - bias_1_shape, - jnp.float32, - axes=self.bias_axes_1) - bias_1 = bias_1.astype(self.dtype) - - bias_2_shape = (hidden_size,) if self.use_bias else (0,) - bias_2 = nn_partitioning.param_with_axes('wo_bias', - self.bias_init, - 
bias_2_shape, - jnp.float32, - axes=self.bias_axes_2) - bias_2 = bias_2.astype(self.dtype) + if self.use_bias: + bias_1_shape = intermediate_dim + bias_1 = nn_partitioning.param_with_axes('wi_bias', + self.bias_init, + bias_1_shape, + jnp.float32, + axes=self.bias_axes_1) + bias_1 = bias_1.astype(self.dtype) + + bias_2_shape = (hidden_size,) + bias_2 = nn_partitioning.param_with_axes('wo_bias', + self.bias_init, + bias_2_shape, + jnp.float32, + axes=self.bias_axes_2) + bias_2 = bias_2.astype(self.dtype) + else: + bias_1 = jnp.empty(0, self.dtype) + bias_2 = jnp.empty(0, self.dtype) out = fused_layernorm_fp8_mlp(y, scale, diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index cacb360a27..6898d1989a 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1103,7 +1103,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): else: assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD - # No changes to memory layout, should trigger bicast only (Ideally no Perf impact) + # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact) query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) @@ -1161,8 +1161,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2) - scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0 - LEADING_AXES = (BATCH_AXES, SEQLEN_AXES) if self.transpose_batch_sequence: LEADING_AXES = (SEQLEN_AXES, BATCH_AXES) @@ -1192,6 +1190,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint) dpa_args = [query, key, value] + scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0 x = DotProductAttention(head_dim=self.head_dim, num_attention_heads=self.num_attention_heads, num_gqa_groups=self.num_gqa_groups, From 5db9ed957a768f538792d4c393d7f3b68919702c Mon Sep 17 00:00:00 2001 From: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Date: Fri, 3 May 2024 09:20:18 -0700 Subject: [PATCH 049/244] [JAX] Generalizing Activation Primitives (#810) * templated primitives and respective C++ functions Signed-off-by: Phuong Nguyen * fixes for LayerNormMLP, tests in test_custom_compute all passed Signed-off-by: Phuong Nguyen * added default arg for pybind get_workspace_size funcs Signed-off-by: Phuong Nguyen * fixes for TestTransFormer with non-gated act tests Signed-off-by: Phuong Nguyen * renamed gelu to act Signed-off-by: Phuong Nguyen * improved enum implementation, avoid using magic numbers Signed-off-by: Phuong Nguyen * Exposed C++ ActivationEnum to python side Signed-off-by: Phuong Nguyen * Changed error messages Signed-off-by: Phuong Nguyen * changed conditional check on input shape for dbias_cast_transpose Signed-off-by: Phuong Nguyen * changed dtype (tol) for bias grad tests Signed-off-by: Phuong Nguyen * fixes so that layer_norm_fp8_mlp can take bias = None Signed-off-by: Phuong Nguyen * Set bias = None in flax modules Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Signed-off-by: Pawel Gadzinski --- tests/jax/test_custom_call_compute.py | 26 +- .../common/transpose/cast_transpose_fusion.cu | 11 +- transformer_engine/jax/cpp_extensions.py | 2012 +++-------------- transformer_engine/jax/csrc/extensions.cpp | 38 +- 
transformer_engine/jax/csrc/modules.cpp | 489 ++-- transformer_engine/jax/csrc/modules.h | 53 +- transformer_engine/jax/flax/module.py | 10 +- transformer_engine/jax/mlp.py | 158 +- 8 files changed, 531 insertions(+), 2266 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 6555aa29ac..8779058080 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -194,8 +194,8 @@ def test_grad_fused_layernorm_fp8_mlp(self, m, n, k, b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16) b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16) else: - b1 = jax.random.normal(subkeys[3], (0,), jnp.bfloat16) - b2 = jax.random.normal(subkeys[4], (0,), jnp.bfloat16) + b1 = None + b2 = None init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2) init_fp8_metas_amax = jnp.zeros( @@ -300,19 +300,19 @@ def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, assert_allclose(jnp.asarray(primitive_k1_grad, np.float32), jnp.asarray(ref_k1_grad, np.float32), dtype=FP8Helper.BWD_DTYPE) - assert_allclose(jnp.asarray(primitive_k2_grad, np.float32), - jnp.asarray(ref_k2_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) assert_allclose(jnp.asarray(primitive_s_grad, np.float32), jnp.asarray(ref_s_grad, np.float32), dtype=FP8Helper.BWD_DTYPE) + assert_allclose(jnp.asarray(primitive_k2_grad, np.float32), + jnp.asarray(ref_k2_grad, np.float32), + dtype=FP8Helper.BWD_DTYPE) if use_bias: - assert_allclose(jnp.asarray(primitive_b1_grad, np.float32), - jnp.asarray(ref_b1_grad, np.float32), - dtype=jnp.bfloat16) assert_allclose(jnp.asarray(primitive_b2_grad, np.float32), jnp.asarray(ref_b2_grad, np.float32), - dtype=jnp.bfloat16) + dtype=FP8Helper.BWD_DTYPE) + assert_allclose(jnp.asarray(primitive_b1_grad, np.float32), + jnp.asarray(ref_b1_grad, np.float32), + dtype=FP8Helper.BWD_DTYPE) @pytest.fixture(name="random_inputs") @@ -341,13 +341,14 @@ def ref_act_lu(inputs): def primitive_func(self, inputs): return jnp.mean(activation_lu(inputs, activation_type = self.activation_type)) - @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)]) + @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)]) @pytest.mark.parametrize('activation_type', [('gelu',), ('gelu', 'linear'), ('silu',), ('silu', 'linear')]) def test_activation_lu(self, random_inputs, activation_type): x = random_inputs + x = jnp.repeat(x, len(activation_type), axis=1) self.activation_type = activation_type value_n_grad_primitive_func = jit( @@ -355,8 +356,6 @@ def test_activation_lu(self, random_inputs, activation_type): prim_out, (prim_grad,) = value_n_grad_primitive_func(x) ref_out, (ref_grad,) = self.ref_func(x, activation_type) - """ prim_grad, = prim_grad """ - """ ref_grad, = ref_grad """ assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype) @@ -372,7 +371,7 @@ def primitive_func(self, inputs, dx_trans_no_use, dbias_no_use, amax, scale, sca activation_type = self.activation_type)) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('shape', [(32, 2, 64), (64, 2, 256)]) + @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)]) @pytest.mark.parametrize('activation_type', [('gelu',), ('gelu', 'linear'), ('silu',), @@ -384,6 +383,7 @@ def test_activation_lu(self, random_inputs, activation_type): self.activation_type = activation_type x = random_inputs + x = jnp.repeat(x, len(activation_type), axis=1) 
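The `jnp.repeat` above builds the stacked input layout that the generalized activation primitives expect: `(N, 1, H)` for plain activations and `(N, 2, H)` for gated ones, where the second slice acts as the linear gate. A rough reference of what that layout encodes (an illustrative sketch only; the actual kernels live in the C++ extensions):

# Illustrative reference for the stacked (N, len(activation_type), H) layout.
import jax.numpy as jnp
from jax import nn as jax_nn

def ref_act_lu(x, activation_type):
    act_map = {'gelu': jax_nn.gelu, 'silu': jax_nn.silu}
    if len(activation_type) == 1:                    # e.g. ('gelu',): x is (N, 1, H)
        return act_map[activation_type[0]](x[:, 0, :])
    # e.g. ('gelu', 'linear'): x is (N, 2, H); the second slice is the linear gate.
    return act_map[activation_type[0]](x[:, 0, :]) * x[:, 1, :]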
value_n_grad_primitive_func = jit( value_and_grad(self.primitive_func, (0, 1, 2, 3, 4, 5,))) diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index 0a0560d470..66bed83aa0 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -529,11 +529,12 @@ void cast_transpose_dbias(const Tensor &input, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { - // TODO - // CheckInputTensor(input, "cast_transpose_dbias_input"); - // CheckOutputTensor(*cast_output, "cast_output"); - // CheckOutputTensor(*transposed_output, "transposed_output"); - // CheckOutputTensor(*dbias, "dbias"); + if (workspace->data.dptr != nullptr) { + CheckInputTensor(input, "cast_transpose_dbias_input"); + CheckOutputTensor(*cast_output, "cast_output"); + CheckOutputTensor(*transposed_output, "transposed_output"); + CheckOutputTensor(*dbias, "dbias"); + } NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index 00e3d81481..8f4ed045d0 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -4,7 +4,7 @@ """JAX te custom call""" from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Tuple +from typing import Tuple, Sequence, Union, Callable from functools import partial, reduce import operator import os @@ -27,6 +27,7 @@ from transformer_engine_jax import NVTE_Mask_Type from transformer_engine_jax import NVTE_QKV_Layout from transformer_engine_jax import NVTE_Fused_Attn_Backend +from transformer_engine_jax import NVTE_Activation_Enum from .sharding import all_reduce_max_along_all_axes_except_PP from .sharding import all_reduce_sum_along_dp_fsdp @@ -124,6 +125,14 @@ def _check_valid_batch_dims(bdims): f"but got {dim=}" +ActivationEnum = { + ('gelu',): NVTE_Activation_Enum.GELU, + ('gelu', 'linear'): NVTE_Activation_Enum.GEGLU, + ('silu',): NVTE_Activation_Enum.SILU, + ('silu', 'linear'): NVTE_Activation_Enum.SWIGLU +} + + class BasePrimitive(metaclass=ABCMeta): """ jax primitive @@ -2556,244 +2565,28 @@ def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda is_training=is_training) -class GeluPrimitive(BasePrimitive): +class ActLuPrimitive(BasePrimitive): """ - Gelu Froward Primitive + Activation Forward Primitive """ - name = "te_gelu" + name = "te_act_lu" multiple_results = False inner_primitive = None outer_primitive = None - impl_static_args = () + impl_static_args = (1,) @staticmethod - def abstract(x_aval): + def abstract(x_aval, *, act_enum): # pylint: disable=unused-argument """ - gated_gelu abstract + act_lu abstract """ dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - out_aval = core.raise_to_shaped(x_aval) - return out_aval - - @staticmethod - def lowering(ctx, x): - """ - gated_gelu lowering rules - """ - (x_aval,) = ctx.avals_in - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - out_shape = ir_x_shape - - out_types = [ - ir.RankedTensorType.get(out_shape, ir_x_type.element_type), - ] - operands = [x] - operand_shapes = [ir_x_shape] - args = CustomCallArgsWrapper(out_types, operands, 
operand_shapes) - - hidden_size = ir_x_shape[-1] - batch_size = reduce(operator.mul, ir_x_shape[:-1]) - in_dtype = jax_dtype_to_te_dtype(x_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype, - in_dtype) - - out = custom_caller(GeluPrimitive.name, args, opaque, False) - - return [out] - - @staticmethod - def impl(x): - assert GeluPrimitive.inner_primitive is not None - out = GeluPrimitive.inner_primitive.bind(x) - return out - - @staticmethod - def batcher(batched_args, batch_dims): - """ - gated_gelu batcher - """ - _check_valid_batch_dims(batch_dims) - assert GeluPrimitive.outer_primitive is not None - inputs, = batched_args - inputs_bdim, = batch_dims - - out_bdims = inputs_bdim - return GeluPrimitive.outer_primitive.bind(inputs), out_bdims - - @staticmethod - def infer_sharding_from_operands(mesh, arg_infos, result_infos): - """ - gated_gelu infer_sharding_from_operands - """ - del result_infos # Unused. - x_spec = get_padded_spec(arg_infos[0]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - return out_sharding - - @staticmethod - def partition(mesh, arg_infos, result_infos): - """ - gated_gelu partitioning - """ - del result_infos - x_spec = get_padded_spec(arg_infos[0]) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - impl = GeluPrimitive.impl - return mesh, impl, out_sharding, arg_shardings - - -register_primitive(GeluPrimitive) - - -def gelu(inputs: jnp.ndarray) -> jnp.ndarray: - """ - gelu wrapper - Return geglu(inputs) - Assume inputs has two dimensions shape and the memory layout is (N..., H) - """ - return GeluPrimitive.outer_primitive.bind(inputs) - - -class DGeluPrimitive(BasePrimitive): - """ - Dgated Gelu Primitive - """ - name = "te_dgelu" - multiple_results = False - inner_primitive = None - outer_primitive = None - impl_static_args = () - - @staticmethod - def abstract(dz_aval, x_aval): - """ - dgelu abstract - """ - dtype = dtypes.canonicalize_dtype(dz_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert x_aval.dtype == dtype - assert dz_aval.shape == x_aval.shape - - out_aval = core.raise_to_shaped(x_aval) - return out_aval - - @staticmethod - def lowering(ctx, dz, x): - """ - dgelu lowering rules - """ - in_aval, gi_aval = ctx.avals_in - assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert gi_aval.dtype == in_aval.dtype - ir_in_type = ir.RankedTensorType(dz.type) - ir_in_shape = ir_in_type.shape - gi_type = ir.RankedTensorType(x.type) - gi_shape = gi_type.shape - assert ir_in_shape == gi_shape - - ir_batch_size = reduce(operator.mul, ir_in_shape[:-1]) - i_hidden_size = ir_in_shape[-1] - out_dtype = ir_in_type.element_type - out_shape = gi_shape - - out_types = [ - ir.RankedTensorType.get(out_shape, out_dtype), - ] - operands = [dz, x] - operand_shapes = [ir_in_shape, gi_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size), - in_dtype, in_dtype) - - out = custom_caller(DGeluPrimitive.name, args, opaque, False) - - return [out] - - @staticmethod - def impl(dz, x): - """ - dgelu implementation - """ - assert DGeluPrimitive.inner_primitive is not None - dx = DGeluPrimitive.inner_primitive.bind(dz, x) - return dx - - @staticmethod - def batcher(batched_args, batch_dims): - """ - dgelu batcher - """ - 
_check_valid_batch_dims(batch_dims) - assert DGeluPrimitive.outer_primitive is not None - dz, x = batched_args - _, x_bdim = batch_dims - - out_bdims = x_bdim - return DGeluPrimitive.outer_primitive.bind(dz, x), out_bdims - - @staticmethod - def infer_sharding_from_operands(mesh, arg_infos, result_infos): - """ - dgelu infer_sharding_from_operands - """ - del result_infos # Unused. - gelu_out_spec = get_padded_spec(arg_infos[1]) - dx_sharding = NamedSharding(mesh, PartitionSpec(*gelu_out_spec)) - return dx_sharding - - @staticmethod - def partition(mesh, arg_infos, result_infos): - """ - dgelu partition - """ - del result_infos - dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = dx_sharding - impl = DGeluPrimitive.impl - return mesh, impl, out_shardings, arg_shardings - - -register_primitive(DGeluPrimitive) - - -def dgelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray: - """ - dgelu fusion wrapper - Return dgeglu(inputs) - """ - return DGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs) - - -class GatedGeluPrimitive(BasePrimitive): - """ - Gated Gelu Froward Primitive - """ - name = "te_gated_gelu" - multiple_results = False - inner_primitive = None - outer_primitive = None - impl_static_args = () - - @staticmethod - def abstract(x_aval): - """ - gated_gelu abstract - """ - dtype = dtypes.canonicalize_dtype(x_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] x_shape = x_aval.shape - assert x_shape[-2] == 2 # Assume x in (....., 2, hidden) + assert (x_shape[-2] == 2 or x_shape[-2] == 1) hidden_size = x_shape[-1] batch_shapes = x_shape[:-2] - x_shape = x_aval.shape out_aval = core.raise_to_shaped(x_aval) out_shape = (batch_shapes) + (hidden_size,) out_aval = out_aval.update(shape=out_shape, dtype=dtype) @@ -2801,9 +2594,9 @@ def abstract(x_aval): return out_aval @staticmethod - def lowering(ctx, x): + def lowering(ctx, x, *, act_enum): """ - gated_gelu lowering rules + act_lu lowering rules """ (x_aval,) = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] @@ -2821,100 +2614,101 @@ def lowering(ctx, x): hidden_size = ir_x_shape[-1] batch_size = reduce(operator.mul, ir_x_shape[:-2]) in_dtype = jax_dtype_to_te_dtype(x_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype, - in_dtype) + opaque = transformer_engine_jax.pack_common_descriptor( + (batch_size, hidden_size), in_dtype, in_dtype, act_enum) - out = custom_caller(GatedGeluPrimitive.name, args, opaque, False) + out = custom_caller(ActLuPrimitive.name, args, opaque, False) return [out] @staticmethod - def impl(x): - assert GatedGeluPrimitive.inner_primitive is not None - out = GatedGeluPrimitive.inner_primitive.bind(x) + def impl(x, act_enum): + assert ActLuPrimitive.inner_primitive is not None + out = ActLuPrimitive.inner_primitive.bind(x, act_enum=act_enum) return out @staticmethod - def batcher(batched_args, batch_dims): + def batcher(batched_args, batch_dims, *, act_enum): """ - gated_gelu batcher + act_lu batcher """ _check_valid_batch_dims(batch_dims) - assert GatedGeluPrimitive.outer_primitive is not None + assert ActLuPrimitive.outer_primitive is not None inputs, = batched_args inputs_bdim, = batch_dims out_bdims = inputs_bdim - return GatedGeluPrimitive.outer_primitive.bind(inputs), out_bdims + return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_enum), out_bdims @staticmethod - def 
infer_sharding_from_operands(mesh, arg_infos, result_infos): + def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos): """ - gated_gelu infer_sharding_from_operands + act_lu infer_sharding_from_operands """ - del result_infos # Unused. + del result_infos, act_enum # Unused. x_spec = get_padded_spec(arg_infos[0]) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) return out_sharding @staticmethod - def partition(mesh, arg_infos, result_infos): + def partition(act_enum, mesh, arg_infos, result_infos): """ - gated_gelu partitioning + act_lu partitioning """ - del result_infos + del result_infos, act_enum x_spec = get_padded_spec(arg_infos[0]) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) - impl = GatedGeluPrimitive.impl + impl = ActLuPrimitive.impl return mesh, impl, out_sharding, arg_shardings -register_primitive(GatedGeluPrimitive) - +register_primitive(ActLuPrimitive) -def gated_gelu(inputs: jnp.ndarray) -> jnp.ndarray: +def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray: """ - gated gelu wrapper - Return FP8(geglu(inputs)) - Assume inputs has two dimensions shape and the memory layout is (N, 2, H) + act_lu wrapper + Return act_lu(inputs) + Input shape: (N, 1, H) for non-gated activations + (N, 2, H) for gated activations """ - return GatedGeluPrimitive.outer_primitive.bind(inputs) + act_type_id = ActivationEnum[activation_type] + return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id) -class DgatedGeluPrimitive(BasePrimitive): +class DActLuPrimitive(BasePrimitive): """ - Dgated Gelu Primitive + Dgated ActLu Primitive """ - name = "te_dgated_gelu" + name = "te_dact_lu" multiple_results = False inner_primitive = None outer_primitive = None - impl_static_args = () + impl_static_args = (2,) @staticmethod - def abstract(dz_aval, x_aval): + def abstract(dz_aval, x_aval, *, act_enum): # pylint: disable=unused-argument """ - dgated_gelu abstract + dact_lu abstract """ dtype = dtypes.canonicalize_dtype(dz_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype == dtype for axis in range(len(dz_aval.shape) - 1): assert dz_aval.shape[axis] == x_aval.shape[axis] - - assert x_aval.shape[-2] == 2 # Assume x in (....., 2, hidden) + assert (x_aval.shape[-2] == 2 or x_aval.shape[-2] == 1) i_hidden_size = dz_aval.shape[-1] g_hidden_size = x_aval.shape[-1] assert i_hidden_size == g_hidden_size out_aval = core.raise_to_shaped(x_aval) + return out_aval @staticmethod - def lowering(ctx, dz, x): + def lowering(ctx, dz, x, *, act_enum): """ - dgated_gelu lowering rules + dact_lu lowering rules """ in_aval, gi_aval = ctx.avals_in assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] @@ -2942,66 +2736,68 @@ def lowering(ctx, dz, x): in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size), - in_dtype, in_dtype) + in_dtype, in_dtype, act_enum) - out = custom_caller(DgatedGeluPrimitive.name, args, opaque, False) + out = custom_caller(DActLuPrimitive.name, args, opaque, False) return [out] @staticmethod - def impl(dz, x): + def impl(dz, x, act_enum): """ - dgated_gelu implementation + dact_lu implementation """ - assert DgatedGeluPrimitive.inner_primitive is not None - dx = DgatedGeluPrimitive.inner_primitive.bind(dz, x) + assert DActLuPrimitive.inner_primitive is not None + dx = 
DActLuPrimitive.inner_primitive.bind(dz, x, act_enum=act_enum) return dx @staticmethod - def batcher(batched_args, batch_dims): + def batcher(batched_args, batch_dims, *, act_enum): """ - dgated_gelu batcher + dact_lu batcher """ _check_valid_batch_dims(batch_dims) - assert DgatedGeluPrimitive.outer_primitive is not None + assert DActLuPrimitive.outer_primitive is not None dz, x = batched_args _, x_bdim = batch_dims out_bdims = x_bdim - return DgatedGeluPrimitive.outer_primitive.bind(dz, x), out_bdims + return DActLuPrimitive.outer_primitive.bind(dz, x, act_enum=act_enum), out_bdims @staticmethod - def infer_sharding_from_operands(mesh, arg_infos, result_infos): + def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos): """ - dgated_gelu infer_sharding_from_operands + dact_lu infer_sharding_from_operands """ - del result_infos # Unused. - gelu_out_spec = get_padded_spec(arg_infos[1]) - dx_sharding = NamedSharding(mesh, PartitionSpec(*gelu_out_spec)) + del result_infos, act_enum # Unused. + act_lu_out_spec = get_padded_spec(arg_infos[1]) + dx_sharding = NamedSharding(mesh, PartitionSpec(*act_lu_out_spec)) return dx_sharding @staticmethod - def partition(mesh, arg_infos, result_infos): + def partition(act_enum, mesh, arg_infos, result_infos): """ - dgated_gelu partition + dact_lu partition """ - del result_infos + del result_infos, act_enum dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = dx_sharding - impl = DgatedGeluPrimitive.impl + impl = DActLuPrimitive.impl return mesh, impl, out_shardings, arg_shardings -register_primitive(DgatedGeluPrimitive) +register_primitive(DActLuPrimitive) -def dgated_gelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray: +def dact_lu(inputs: jnp.ndarray, act_lu_inputs: jnp.ndarray, + activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray: """ - dgated_gelu fusion wrapper - Return dgeglu(inputs) + dact_lu fusion wrapper + Return dgated_act_lu(inputs) """ - return DgatedGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs) + act_type_id = ActivationEnum[activation_type] + return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id) def _normalize_axis_boundary(axis, ndim): @@ -3958,20 +3754,21 @@ def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale epsilon=epsilon) -class GeluFp8Primitive(BasePrimitive): +class ActLuFp8Primitive(BasePrimitive): """ - Gelu FP8 Primitive + ActLu FP8 Primitive """ - name = "te_gelu_fp8" + name = "te_act_lu_fp8" multiple_results = True - impl_static_args = (4,) #out_dtype + impl_static_args = (4, 5) #out_dtype, act_enum inner_primitive = None outer_primitive = None @staticmethod - def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype): + def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, + act_enum): # pylint: disable=unused-argument """ - te_gelu_p abstract + te_act_lu_p abstract """ dtype = dtypes.canonicalize_dtype(x_aval.dtype) # Currently only support casting to E4M3 only in C side. 
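For a sense of how the renamed wrappers are meant to be called, a hedged usage sketch follows (the shapes, the FP8 dtype, and the metadata buffers are assumptions for illustration, not taken from this patch):

# Hedged usage sketch of the generalized activation wrappers (assumed shapes/dtypes).
import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions import act_lu, act_lu_fp8

x_gated = jnp.ones((16, 2, 128), jnp.bfloat16)   # (N, 2, H): activation input + linear gate
x_plain = jnp.ones((16, 1, 128), jnp.bfloat16)   # (N, 1, H): non-gated activation

y_geglu = act_lu(x_gated, activation_type=('gelu', 'linear'))   # -> (16, 128)
y_silu = act_lu(x_plain, activation_type=('silu',))             # -> (16, 128)

# The FP8 path also threads the scaling metadata and returns the updated amax.
amax = jnp.zeros((1,), jnp.float32)       # assumed per-tensor amax buffer
scale = jnp.ones((1,), jnp.float32)
scale_inv = jnp.ones((1,), jnp.float32)
y_fp8, amax = act_lu_fp8(x_gated, amax, scale, scale_inv,
                         out_dtype=jnp.float8_e4m3fn,
                         activation_type=('gelu', 'linear'))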
@@ -3981,15 +3778,19 @@ def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype): assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) + assert (x_aval.shape[-2] == 1 or x_aval.shape[-2] == 2) + hidden_size = x_aval.shape[-1] + batch_shape = x_aval.shape[:-2] + out_shape = (batch_shape) + (hidden_size,) + out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) return out_aval, updated_amax_aval @staticmethod - def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): + def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, act_enum): """ - te_gated_gelu_p lowering rules + te_gated_act_lu_p lowering rules """ x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] @@ -4006,8 +3807,9 @@ def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): ir_scale_inv_shape = ir_amax_shape hidden_size = ir_x_shape[-1] - batch_size = reduce(operator.mul, ir_x_shape[:-1]) - out_shape = ir_x_shape + batch_shape = ir_x_shape[:-2] + batch_size = reduce(operator.mul, batch_shape) + out_shape = batch_shape + [hidden_size] out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), @@ -4016,11 +3818,13 @@ def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(out_dtype)) + opaque = transformer_engine_jax.pack_common_descriptor(( + batch_size, hidden_size), + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + act_enum) - out = custom_caller(GeluFp8Primitive.name, + out = custom_caller(ActLuFp8Primitive.name, args, opaque, False, @@ -4029,55 +3833,58 @@ def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): return out @staticmethod - def impl(x, amax, scale, scale_inv, out_dtype): + def impl(x, amax, scale, scale_inv, out_dtype, act_enum): """ to describe implementation """ - assert GeluFp8Primitive.inner_primitive is not None - out, updated_amax = GeluFp8Primitive.inner_primitive.bind(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype) + assert ActLuFp8Primitive.inner_primitive is not None + out, updated_amax = ActLuFp8Primitive.inner_primitive.bind(x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + act_enum=act_enum) return out, updated_amax @staticmethod - def batcher(batched_args, batch_dims, *, out_dtype): + def batcher(batched_args, batch_dims, *, out_dtype, act_enum): """ to describe batch rules for vmap """ _check_valid_batch_dims(batch_dims) - assert GeluFp8Primitive.outer_primitive is not None + assert ActLuFp8Primitive.outer_primitive is not None x, amax, scale, scale_inv = batched_args x_bdim, amax_bdim, _, _ = batch_dims out_bdims = x_bdim, amax_bdim - return GeluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, - out_dtype=out_dtype), out_bdims + return ActLuFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, + out_dtype=out_dtype, + act_enum=act_enum), out_bdims @staticmethod - def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos): - del out_dtype, result_infos + def 
infer_sharding_from_operands(out_dtype, act_enum, mesh, arg_infos, result_infos): + del out_dtype, result_infos, act_enum x_spec = get_padded_spec(arg_infos[0]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) return (out_sharding, amax_sharding) @staticmethod - def partition(out_dtype, mesh, arg_infos, result_infos): + def partition(out_dtype, act_enum, mesh, arg_infos, result_infos): del result_infos x_spec = get_padded_spec(arg_infos[0]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = (out_sharding, amax_sharding) def sharded_impl(x, amax, scale, scale_inv): - local_x, local_amax = GeluFp8Primitive.impl(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype) + local_x, local_amax = ActLuFp8Primitive.impl(x, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + act_enum=act_enum) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) return local_x, global_updated_amax @@ -4085,34 +3892,40 @@ def sharded_impl(x, amax, scale, scale_inv): return mesh, sharded_impl, out_shardings, arg_shardings -register_primitive(GeluFp8Primitive) +register_primitive(ActLuFp8Primitive) -def gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, - out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: +def act_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, + out_dtype: jnp.dtype, activation_type: Sequence[Union[str, Callable]] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """ - gated gelu wrapper - Return FP8(geglu(x)) + act wrapper + Return FP8(act_lu(x)) + Input shape: (N, 1, H) for non-gated activations + (N, 2, H) for gated activations """ - return GeluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype) + act_type_id = ActivationEnum[activation_type] + return ActLuFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype, + act_enum = act_type_id) -class DGeluDBiasCastTransposePrimitive(BasePrimitive): +class DActLuDBiasCastTransposePrimitive(BasePrimitive): """ - DGelu DBias Cast Transpose Primitive + DActLu DBias Cast Transpose Primitive """ - name = "te_dgelu_dbias_cast_transpose" + name = "te_dact_lu_dbias_cast_transpose" multiple_results = True - # out_dtype, static_axis_boundary, transpose_axis_boundary - impl_static_args = (5, 6, 7) + # out_dtype, static_axis_boundary, transpose_axis_boundary, act_enum + impl_static_args = (5, 6, 7, 8) inner_primitive = None outer_primitive = None @staticmethod def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, - static_axis_boundary, transpose_axis_boundary): + static_axis_boundary, transpose_axis_boundary, + act_enum): # pylint: disable=unused-argument """ - te_dgelu_dbais_cast_transpose_p abstract + te_dact_lu_dbais_cast_transpose_p abstract """ dtype = dtypes.canonicalize_dtype(dz_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] @@ -4123,7 +3936,8 @@ def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtyp ir_hidden_szie = dz_aval.shape[-1] gi_hidden_size = x_aval.shape[-1] assert ir_hidden_szie == 
gi_hidden_size - t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary) + t_shape = _multidim_transpose(x_aval.shape, + static_axis_boundary, transpose_axis_boundary) out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) @@ -4146,18 +3960,18 @@ def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtyp @staticmethod def outer_abstract(*args, **kwargs): """ - te_dgelu_dbais_cast_transpose_p outer abstract + te_dact_lu_dbais_cast_transpose_p outer abstract """ out, t_out, dbias, updated_amax_aval, _ = \ - DGeluDBiasCastTransposePrimitive.abstract(*args, **kwargs) + DActLuDBiasCastTransposePrimitive.abstract(*args, **kwargs) return out, t_out, dbias, updated_amax_aval @staticmethod def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, - transpose_axis_boundary): + transpose_axis_boundary, act_enum): """ - te_dgated_gelu_cast_transpose_p lowering rules + te_dgated_act_lu_cast_transpose_p lowering rules """ dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] @@ -4169,11 +3983,11 @@ def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_bound ir_dz_shape = ir_dz_type.shape x_type = ir.RankedTensorType(x.type) x_shape = x_type.shape - assert ir_dz_shape == x_shape - - batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) + dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) + x_batch_size = reduce(operator.mul, x_shape[:-2]) + assert dz_batch_szie == x_batch_size ir_hidden_szie = ir_dz_shape[-1] - contracted_x_shape = (batch_szie, ir_hidden_szie) + contracted_x_shape = (x_batch_size, ir_hidden_szie) ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) ir_amax_type = ir.RankedTensorType(amax.type) @@ -4199,9 +4013,10 @@ def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_bound args = CustomCallArgsWrapper(out_types, operands, operand_shapes) opaque = transformer_engine_jax.pack_common_wk_descriptor( contracted_x_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype)) + jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), + act_enum) - out = custom_caller(DGeluDBiasCastTransposePrimitive.name, + out = custom_caller(DActLuDBiasCastTransposePrimitive.name, args, opaque, False, @@ -4211,12 +4026,12 @@ def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_bound @staticmethod def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, - transpose_axis_boundary): + transpose_axis_boundary, act_enum): """ to describe implementation """ - assert DGeluDBiasCastTransposePrimitive.inner_primitive is not None - out, t_out, dbias, updated_amax, _ = DGeluDBiasCastTransposePrimitive.inner_primitive.bind( + assert DActLuDBiasCastTransposePrimitive.inner_primitive is not None + out, t_out, dbias, updated_amax, _ = DActLuDBiasCastTransposePrimitive.inner_primitive.bind( dz, x, amax, @@ -4224,18 +4039,19 @@ def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) + transpose_axis_boundary=transpose_axis_boundary, + act_enum=act_enum) return out, t_out, dbias, updated_amax @staticmethod def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, - 
transpose_axis_boundary): + transpose_axis_boundary, act_enum): """ to describe batch rules for vmap """ del static_axis_boundary _check_valid_batch_dims(batch_dims) - assert DGeluDBiasCastTransposePrimitive.outer_primitive is not None + assert DActLuDBiasCastTransposePrimitive.outer_primitive is not None dz, x, amax, scale, scale_inv = batched_args x_bdim, _, amax_bdim, _, _ = batch_dims @@ -4244,7 +4060,7 @@ def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary += 1 # Plus batch dim out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim - return DGeluDBiasCastTransposePrimitive.outer_primitive.bind( + return DActLuDBiasCastTransposePrimitive.outer_primitive.bind( dz, x, amax, @@ -4252,12 +4068,13 @@ def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, scale_inv, out_dtype=out_dtype, static_axis_boundary=x_bdim, - transpose_axis_boundary=transpose_axis_boundary), out_bdims + transpose_axis_boundary=transpose_axis_boundary, + act_enum=act_enum), out_bdims @staticmethod - def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, - arg_infos, result_infos): - del out_dtype, result_infos + def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, + act_enum, mesh, arg_infos, result_infos): + del out_dtype, result_infos, act_enum x_spec = get_padded_spec(arg_infos[1]) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) @@ -4268,8 +4085,8 @@ def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) @staticmethod - def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, - result_infos): + def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, + act_enum, mesh, arg_infos, result_infos): del result_infos x_spec = get_padded_spec(arg_infos[1]) casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) @@ -4285,7 +4102,8 @@ def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, ar amax_sharding) def sharded_impl(dz, x, amax, scale, scale_inv): - local_out, local_t_out, local_dbias, local_amax = DGeluDBiasCastTransposePrimitive.impl( + local_out, local_t_out, local_dbias, local_amax =\ + DActLuDBiasCastTransposePrimitive.impl( dz, x, amax, @@ -4293,7 +4111,8 @@ def sharded_impl(dz, x, amax, scale, scale_inv): scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) + transpose_axis_boundary=transpose_axis_boundary, + act_enum=act_enum) global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) return local_out, local_t_out, global_dbias, global_updated_amax @@ -4301,26 +4120,30 @@ def sharded_impl(dz, x, amax, scale, scale_inv): return mesh, sharded_impl, out_shardings, arg_shardings -register_primitive(DGeluDBiasCastTransposePrimitive) +register_primitive(DActLuDBiasCastTransposePrimitive) -def dgelu_dbias_cast_transpose( - dz: jnp.ndarray, - x: jnp.ndarray, - amax: jnp.ndarray, - scale: jnp.ndarray, - scale_inv: jnp.ndarray, - out_dtype: TEDType, - static_axis_boundary: int, - transpose_axis_boundary: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: +def dact_lu_dbias_cast_transpose( + dz: jnp.ndarray, + x: jnp.ndarray, + amax: jnp.ndarray, + scale: 
jnp.ndarray, + scale_inv: jnp.ndarray, + out_dtype: TEDType, + static_axis_boundary: int, + transpose_axis_boundary: int = -1, + activation_type: Sequence[Union[str, Callable]] = ('gelu',) + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """ - cast transpose dgelu and dbias fusion wrapper - Return FP8(dgeglu(inputs)), dbias + cast transpose dact_lu and dbias fusion wrapper + Return FP8(dact_lu(inputs)), dbias + ONLY support non-gated activation type """ if static_axis_boundary < 0: static_axis_boundary = -1 # means no static axes - return DGeluDBiasCastTransposePrimitive.outer_primitive.bind( + act_type_id = ActivationEnum[activation_type] + return DActLuDBiasCastTransposePrimitive.outer_primitive.bind( dz, x, amax, @@ -4328,7 +4151,8 @@ def dgelu_dbias_cast_transpose( scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) + transpose_axis_boundary=transpose_axis_boundary, + act_enum=act_type_id) class DBiasCastTransposePrimitive(BasePrimitive): @@ -4353,13 +4177,11 @@ def abstract(dz_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, assert amax_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - gi_hidden_size = dz_aval.shape[-1] + gi_hidden_size = reduce(operator.mul, dz_aval.shape[transpose_axis_boundary:]) t_shape = _multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary) out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype) t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) - if dz_aval.shape[-2] == 2: - gi_hidden_size *= 2 dbias_shape = (*dz_aval.shape[:static_axis_boundary + 1], gi_hidden_size) dbias = dz_aval.update(shape=dbias_shape, dtype=dtype) @@ -4398,13 +4220,9 @@ def lowering(ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary assert scale_inv_aval.dtype == jnp.float32 ir_dz_type = ir.RankedTensorType(dz.type) ir_dz_shape = ir_dz_type.shape - ir_hidden_szie = ir_dz_shape[-1] - if dz_aval.shape[-2] == 2: - batch_szie = reduce(operator.mul, ir_dz_shape[:-2]) - ir_hidden_szie *= 2 - else: - batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) - contracted_dz_shape = (batch_szie, ir_hidden_szie) + batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary]) + ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:]) + contracted_dz_shape = (batch_size, ir_hidden_size) ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) ir_amax_type = ir.RankedTensorType(amax.type) ir_amax_dtype = ir_amax_type.element_type @@ -4413,7 +4231,7 @@ def lowering(ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary ir_scale_inv_shape = ir_amax_shape transposed_dz_shape = _multidim_transpose(ir_dz_shape, static_axis_boundary, transpose_axis_boundary) - dbias_shape = (*ir_dz_shape[:static_axis_boundary + 1], ir_hidden_szie) + dbias_shape = (*ir_dz_shape[:static_axis_boundary + 1], ir_hidden_size) wkspace_aval = ctx.avals_out[-1] @@ -4556,1356 +4374,60 @@ def dbias_cast_transpose( transpose_axis_boundary=transpose_axis_boundary) -class GatedGeluFp8Primitive(BasePrimitive): +class DgatedActLuCastTransposePrimitive(BasePrimitive): """ - Gated Gelu FP8 Primitive + Dgated ActLu Cast Transpose Primitive """ - name = "te_gated_gelu_fp8" + name = "te_dgated_act_lu_cast_transpose" multiple_results = True - impl_static_args = (4,) #out_dtype + impl_static_args = (5, 6, 7) # out_dtype, static_axis_boundary, act_enum inner_primitive = None outer_primitive = None @staticmethod - def 
abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype): + def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, + static_axis_boundary, act_enum): # pylint: disable=unused-argument """ - te_gated_gelu_p abstract + te_dgated_act_lu_cast_transpose_p abstract """ - dtype = dtypes.canonicalize_dtype(x_aval.dtype) - # Currently only support casting to E4M3 only in C side. - assert out_dtype == jnp.float8_e4m3fn + dtype = dtypes.canonicalize_dtype(dz_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dtype + assert x_aval.shape[-2] == 2 # Linear + GeLU assert amax_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - - assert x_aval.shape[-2] == 2 # Assume x in (....., 2, hidden) - hidden_size = x_aval.shape[-1] - batch_shape = x_aval.shape[:-2] - out_shape = (batch_shape) + (hidden_size,) - out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) + ir_hidden_szie = dz_aval.shape[-1] + gi_hidden_size = x_aval.shape[-1] + assert ir_hidden_szie == gi_hidden_size + t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2) + out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) + t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - - return out_aval, updated_amax_aval + return out, t_out, updated_amax_aval @staticmethod - def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): + def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum): """ - te_gated_gelu_p lowering rules + te_dgated_act_lu_cast_transpose_p lowering rules """ - x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert x_aval.dtype == dz_aval.dtype assert amax_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - - hidden_size = ir_x_shape[-1] - batch_shape = ir_x_shape[:-2] - batch_size = reduce(operator.mul, batch_shape) - out_shape = batch_shape + [hidden_size] - out_types = [ - ir.RankedTensorType.get(out_shape, ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ] - operands = [x, amax, scale, scale_inv] - operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - opaque = transformer_engine_jax.pack_common_descriptor((batch_size, out_shape[-1]), - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(out_dtype)) - - out = custom_caller(GatedGeluFp8Primitive.name, - args, - opaque, - False, - operand_output_aliases={1: 1}) - - return out - - @staticmethod - def impl(x, amax, scale, scale_inv, out_dtype): - """ - to describe implementation - """ - assert GatedGeluFp8Primitive.inner_primitive is not None - out, updated_amax = GatedGeluFp8Primitive.inner_primitive.bind(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype) - return out, 
updated_amax - - @staticmethod - def batcher(batched_args, batch_dims, *, out_dtype): - """ - to describe batch rules for vmap - """ - _check_valid_batch_dims(batch_dims) - assert GatedGeluFp8Primitive.outer_primitive is not None - x, amax, scale, scale_inv = batched_args - x_bdim, amax_bdim, _, _ = batch_dims - - out_bdims = x_bdim, amax_bdim - return GatedGeluFp8Primitive.outer_primitive.bind(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype), out_bdims - - @staticmethod - def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos): - del out_dtype, result_infos - x_spec = get_padded_spec(arg_infos[0]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - return (out_sharding, amax_sharding) - - @staticmethod - def partition(out_dtype, mesh, arg_infos, result_infos): - del result_infos - x_spec = get_padded_spec(arg_infos[0]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (out_sharding, amax_sharding) - - def sharded_impl(x, amax, scale, scale_inv): - local_x, local_amax = GatedGeluFp8Primitive.impl(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) - - return local_x, global_updated_amax - - return mesh, sharded_impl, out_shardings, arg_shardings - - -register_primitive(GatedGeluFp8Primitive) - - -def gated_gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, - out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - """ - gated gelu wrapper - Return FP8(geglu(x)) - """ - return GatedGeluFp8Primitive.outer_primitive.bind(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype) - - -class DgatedGeluCastTransposePrimitive(BasePrimitive): - """ - Dgated Gelu Cast Transpose Primitive - """ - name = "te_dgated_gelu_cast_transpose" - multiple_results = True - impl_static_args = (5, 6) # out_dtype, static_axis_boundary - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, - static_axis_boundary): - """ - te_dgated_gelu_cast_transpose_p abstract - """ - dtype = dtypes.canonicalize_dtype(dz_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert x_aval.dtype == dtype - assert x_aval.shape[-2] == 2 # Linear + GeLU - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - ir_hidden_szie = dz_aval.shape[-1] - gi_hidden_size = x_aval.shape[-1] - assert ir_hidden_szie == gi_hidden_size - t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2) - out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) - t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) - updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - return out, t_out, updated_amax_aval - - @staticmethod - def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary): - """ - te_dgated_gelu_cast_transpose_p lowering rules - """ - dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert x_aval.dtype == dz_aval.dtype - assert amax_aval.dtype == 
jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - ir_dz_type = ir.RankedTensorType(dz.type) - ir_dz_shape = ir_dz_type.shape - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) - x_batch_size = reduce(operator.mul, x_shape[:-2]) - assert dz_batch_szie == x_batch_size - assert x_shape[-2] == 2 # Linear + GeLU - ir_hidden_szie = ir_dz_shape[-1] - gi_hidden_size = x_shape[-1] - assert ir_hidden_szie == gi_hidden_size - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary, -2) - out_types = [ - ir.RankedTensorType.get(x_shape, ir_out_dtype), - ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ] - operands = [dz, x, amax, scale, scale_inv] - operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - contracted_x_shape = (x_batch_size, x_shape[-1]) - opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape, - jax_dtype_to_te_dtype(dz_aval.dtype), - jax_dtype_to_te_dtype(out_dtype)) - - out = custom_caller(DgatedGeluCastTransposePrimitive.name, - args, - opaque, - False, - operand_output_aliases={2: 2}) - - return out - - @staticmethod - def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary): - """ - to describe implementation - """ - assert DgatedGeluCastTransposePrimitive.inner_primitive is not None - out, t_out, updated_amax = DgatedGeluCastTransposePrimitive.inner_primitive.bind( - dz, - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary) - return out, t_out, updated_amax - - @staticmethod - def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary): - """ - to describe batch rules for vmap - """ - del static_axis_boundary - _check_valid_batch_dims(batch_dims) - assert DgatedGeluCastTransposePrimitive.outer_primitive is not None - dz, x, amax, scale, scale_inv = batched_args - x_bdim, _, amax_bdim, _, _ = batch_dims - - out_bdims = x_bdim, x_bdim, amax_bdim - return DgatedGeluCastTransposePrimitive.outer_primitive.bind( - dz, x, amax, scale, scale_inv, out_dtype=out_dtype, - static_axis_boundary=x_bdim), out_bdims - - @staticmethod - def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_infos, - result_infos): - del out_dtype, result_infos - x_spec = get_padded_spec(arg_infos[1]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2) - tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) - return (out_sharding, tranposed_out_sharding, amax_sharding) - - @staticmethod - def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos): - del result_infos - x_spec = get_padded_spec(arg_infos[1]) - casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2) - casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) - - amax_sharding = NamedSharding(mesh, 
PartitionSpec(*get_padded_spec(arg_infos[2]))) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding) - - def sharded_impl(dz, x, amax, scale, scale_inv): - local_out, local_t_out, local_amax = DgatedGeluCastTransposePrimitive.impl( - dz, - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) - return local_out, local_t_out, global_updated_amax - - return mesh, sharded_impl, out_shardings, arg_shardings - - -register_primitive(DgatedGeluCastTransposePrimitive) - - -def dgated_gelu_cast_transpose( - dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, - scale_inv: jnp.ndarray, out_dtype: TEDType, - static_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - """ - cast transpose d_gated_gelu fusion wrapper - Return FP8(dgeglu(inputs)) - """ - return DgatedGeluCastTransposePrimitive.outer_primitive.bind( - dz, - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary) - -# Primitives for SwiGLU and SiLU -class SiluPrimitive(BasePrimitive): - """ - Silu Froward Primitive - """ - name = "te_silu" - multiple_results = False - inner_primitive = None - outer_primitive = None - impl_static_args = () - - @staticmethod - def abstract(x_aval): - """ - gated_silu abstract - """ - dtype = dtypes.canonicalize_dtype(x_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - - out_aval = core.raise_to_shaped(x_aval) - return out_aval - - @staticmethod - def lowering(ctx, x): - """ - gated_silu lowering rules - """ - (x_aval,) = ctx.avals_in - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - out_shape = ir_x_shape - - out_types = [ - ir.RankedTensorType.get(out_shape, ir_x_type.element_type), - ] - operands = [x] - operand_shapes = [ir_x_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - hidden_size = ir_x_shape[-1] - batch_size = reduce(operator.mul, ir_x_shape[:-1]) - in_dtype = jax_dtype_to_te_dtype(x_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype, - in_dtype) - - out = custom_caller(SiluPrimitive.name, args, opaque, False) - - return [out] - - @staticmethod - def impl(x): - assert SiluPrimitive.inner_primitive is not None - out = SiluPrimitive.inner_primitive.bind(x) - return out - - @staticmethod - def batcher(batched_args, batch_dims): - """ - gated_silu batcher - """ - _check_valid_batch_dims(batch_dims) - assert SiluPrimitive.outer_primitive is not None - inputs, = batched_args - inputs_bdim, = batch_dims - - out_bdims = inputs_bdim - return SiluPrimitive.outer_primitive.bind(inputs), out_bdims - - @staticmethod - def infer_sharding_from_operands(mesh, arg_infos, result_infos): - """ - gated_silu infer_sharding_from_operands - """ - del result_infos # Unused. 
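
The renamed act_lu_fp8 wrapper above documents its (N, 1, H) / (N, 2, H) input convention only in a docstring; a minimal usage sketch for the non-gated case follows. The import path, the (1,)-shaped amax/scale/scale_inv buffers and the jnp.float8_e4m3fn output dtype are assumptions for illustration, not taken from this patch.

import jax.numpy as jnp
# Assumed import path for the wrappers defined in this file.
from transformer_engine.jax.cpp_extensions import act_lu_fp8

x = jnp.ones((16, 1, 1024), jnp.bfloat16)   # (N, 1, H): non-gated layout
amax = jnp.zeros((1,), jnp.float32)
scale = jnp.ones((1,), jnp.float32)
scale_inv = jnp.ones((1,), jnp.float32)

casted, updated_amax = act_lu_fp8(x, amax, scale, scale_inv,
                                  out_dtype=jnp.float8_e4m3fn,
                                  activation_type=('gelu',))
# casted has shape (16, 1024): the size-1 activation axis is folded away,
# matching out_shape = batch_shape + (hidden_size,) in the abstract rule.
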
- x_spec = get_padded_spec(arg_infos[0]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - return out_sharding - - @staticmethod - def partition(mesh, arg_infos, result_infos): - """ - gated_silu partitioning - """ - del result_infos - x_spec = get_padded_spec(arg_infos[0]) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - impl = SiluPrimitive.impl - return mesh, impl, out_sharding, arg_shardings - - -register_primitive(SiluPrimitive) - - -def silu(inputs: jnp.ndarray) -> jnp.ndarray: - """ - silu wrapper - Return geglu(inputs) - Assume inputs has two dimensions shape and the memory layout is (N..., H) - """ - return SiluPrimitive.outer_primitive.bind(inputs) - - -class DSiluPrimitive(BasePrimitive): - """ - Dgated Silu Primitive - """ - name = "te_dsilu" - multiple_results = False - inner_primitive = None - outer_primitive = None - impl_static_args = () - - @staticmethod - def abstract(dz_aval, x_aval): - """ - dsilu abstract - """ - dtype = dtypes.canonicalize_dtype(dz_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert x_aval.dtype == dtype - assert dz_aval.shape == x_aval.shape - - out_aval = core.raise_to_shaped(x_aval) - return out_aval - - @staticmethod - def lowering(ctx, dz, x): - """ - dsilu lowering rules - """ - in_aval, gi_aval = ctx.avals_in - assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert gi_aval.dtype == in_aval.dtype - ir_in_type = ir.RankedTensorType(dz.type) - ir_in_shape = ir_in_type.shape - gi_type = ir.RankedTensorType(x.type) - gi_shape = gi_type.shape - assert ir_in_shape == gi_shape - - ir_batch_size = reduce(operator.mul, ir_in_shape[:-1]) - i_hidden_size = ir_in_shape[-1] - out_dtype = ir_in_type.element_type - out_shape = gi_shape - - out_types = [ - ir.RankedTensorType.get(out_shape, out_dtype), - ] - operands = [dz, x] - operand_shapes = [ir_in_shape, gi_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size), - in_dtype, in_dtype) - - out = custom_caller(DSiluPrimitive.name, args, opaque, False) - - return [out] - - @staticmethod - def impl(dz, x): - """ - dsilu implementation - """ - assert DSiluPrimitive.inner_primitive is not None - dx = DSiluPrimitive.inner_primitive.bind(dz, x) - return dx - - @staticmethod - def batcher(batched_args, batch_dims): - """ - dsilu batcher - """ - _check_valid_batch_dims(batch_dims) - assert DSiluPrimitive.outer_primitive is not None - dz, x = batched_args - _, x_bdim = batch_dims - - out_bdims = x_bdim - return DSiluPrimitive.outer_primitive.bind(dz, x), out_bdims - - @staticmethod - def infer_sharding_from_operands(mesh, arg_infos, result_infos): - """ - dsilu infer_sharding_from_operands - """ - del result_infos # Unused. 
- silu_out_spec = get_padded_spec(arg_infos[1]) - dx_sharding = NamedSharding(mesh, PartitionSpec(*silu_out_spec)) - return dx_sharding - - @staticmethod - def partition(mesh, arg_infos, result_infos): - """ - dsilu partition - """ - del result_infos - dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = dx_sharding - impl = DSiluPrimitive.impl - return mesh, impl, out_shardings, arg_shardings - - -register_primitive(DSiluPrimitive) - - -def dsilu(inputs: jnp.ndarray, silu_inputs: jnp.ndarray) -> jnp.ndarray: - """ - dsilu fusion wrapper - Return dgeglu(inputs) - """ - return DSiluPrimitive.outer_primitive.bind(inputs, silu_inputs) - - -class GatedSiluPrimitive(BasePrimitive): - """ - Gated Silu Froward Primitive - """ - name = "te_gated_silu" - multiple_results = False - inner_primitive = None - outer_primitive = None - impl_static_args = () - - @staticmethod - def abstract(x_aval): - """ - gated_silu abstract - """ - dtype = dtypes.canonicalize_dtype(x_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - x_shape = x_aval.shape - assert x_shape[-2] == 2 # Assume x in (....., 2, hidden) - hidden_size = x_shape[-1] - batch_shapes = x_shape[:-2] - x_shape = x_aval.shape - out_aval = core.raise_to_shaped(x_aval) - out_shape = (batch_shapes) + (hidden_size,) - out_aval = out_aval.update(shape=out_shape, dtype=dtype) - - return out_aval - - @staticmethod - def lowering(ctx, x): - """ - gated_silu lowering rules - """ - (x_aval,) = ctx.avals_in - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]] - - out_types = [ - ir.RankedTensorType.get(out_shape, ir_x_type.element_type), - ] - operands = [x] - operand_shapes = [ir_x_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - hidden_size = ir_x_shape[-1] - batch_size = reduce(operator.mul, ir_x_shape[:-2]) - in_dtype = jax_dtype_to_te_dtype(x_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype, - in_dtype) - - out = custom_caller(GatedSiluPrimitive.name, args, opaque, False) - - return [out] - - @staticmethod - def impl(x): - assert GatedSiluPrimitive.inner_primitive is not None - out = GatedSiluPrimitive.inner_primitive.bind(x) - return out - - @staticmethod - def batcher(batched_args, batch_dims): - """ - gated_silu batcher - """ - _check_valid_batch_dims(batch_dims) - assert GatedSiluPrimitive.outer_primitive is not None - inputs, = batched_args - inputs_bdim, = batch_dims - - out_bdims = inputs_bdim - return GatedSiluPrimitive.outer_primitive.bind(inputs), out_bdims - - @staticmethod - def infer_sharding_from_operands(mesh, arg_infos, result_infos): - """ - gated_silu infer_sharding_from_operands - """ - del result_infos # Unused. 
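
Several of the gated primitives in this hunk assert x_aval.shape[-2] == 2. A pure-JAX sketch of how that (N, 2, H) layout is typically built: the two GEMM output halves of a gated MLP are stacked along a new axis just before the hidden dimension. The names below are illustrative; which half carries the activation branch and which carries the linear branch follows the kernel's convention, not this sketch.

import jax.numpy as jnp

n, h = 16, 1024
half_a = jnp.ones((n, h), jnp.bfloat16)   # first projection output
half_b = jnp.ones((n, h), jnp.bfloat16)   # second projection output
x_gated = jnp.stack([half_a, half_b], axis=-2)   # shape (N, 2, H)
assert x_gated.shape[-2] == 2   # the shape check used by the gated primitives
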
- x_spec = get_padded_spec(arg_infos[0]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) - return out_sharding - - @staticmethod - def partition(mesh, arg_infos, result_infos): - """ - gated_silu partitioning - """ - del result_infos - x_spec = get_padded_spec(arg_infos[0]) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) - impl = GatedSiluPrimitive.impl - return mesh, impl, out_sharding, arg_shardings - - -register_primitive(GatedSiluPrimitive) - - -def gated_silu(inputs: jnp.ndarray) -> jnp.ndarray: - """ - gated silu wrapper - Return FP8(geglu(inputs)) - Assume inputs has two dimensions shape and the memory layout is (N, 2, H) - """ - return GatedSiluPrimitive.outer_primitive.bind(inputs) - - -class DgatedSiluPrimitive(BasePrimitive): - """ - Dgated Silu Primitive - """ - name = "te_dgated_silu" - multiple_results = False - inner_primitive = None - outer_primitive = None - impl_static_args = () - - @staticmethod - def abstract(dz_aval, x_aval): - """ - dgated_silu abstract - """ - dtype = dtypes.canonicalize_dtype(dz_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert x_aval.dtype == dtype - for axis in range(len(dz_aval.shape) - 1): - assert dz_aval.shape[axis] == x_aval.shape[axis] - - assert x_aval.shape[-2] == 2 # Assume x in (....., 2, hidden) - - i_hidden_size = dz_aval.shape[-1] - g_hidden_size = x_aval.shape[-1] - assert i_hidden_size == g_hidden_size - out_aval = core.raise_to_shaped(x_aval) - return out_aval - - @staticmethod - def lowering(ctx, dz, x): - """ - dgated_silu lowering rules - """ - in_aval, gi_aval = ctx.avals_in - assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert gi_aval.dtype == in_aval.dtype - ir_in_type = ir.RankedTensorType(dz.type) - ir_in_shape = ir_in_type.shape - gi_type = ir.RankedTensorType(x.type) - gi_shape = gi_type.shape - for axis in range(len(ir_in_shape) - 1): - assert ir_in_shape[axis] == gi_shape[axis] - - ir_batch_size = reduce(operator.mul, ir_in_shape[:-1]) - i_hidden_size = ir_in_shape[-1] - g_hidden_size = gi_shape[-1] - assert i_hidden_size == g_hidden_size - out_dtype = ir_in_type.element_type - out_shape = gi_shape - - out_types = [ - ir.RankedTensorType.get(out_shape, out_dtype), - ] - operands = [dz, x] - operand_shapes = [ir_in_shape, gi_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - in_dtype = jax_dtype_to_te_dtype(in_aval.dtype) - opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size), - in_dtype, in_dtype) - - out = custom_caller(DgatedSiluPrimitive.name, args, opaque, False) - - return [out] - - @staticmethod - def impl(dz, x): - """ - dgated_silu implementation - """ - assert DgatedSiluPrimitive.inner_primitive is not None - dx = DgatedSiluPrimitive.inner_primitive.bind(dz, x) - return dx - - @staticmethod - def batcher(batched_args, batch_dims): - """ - dgated_silu batcher - """ - _check_valid_batch_dims(batch_dims) - assert DgatedSiluPrimitive.outer_primitive is not None - dz, x = batched_args - _, x_bdim = batch_dims - - out_bdims = x_bdim - return DgatedSiluPrimitive.outer_primitive.bind(dz, x), out_bdims - - @staticmethod - def infer_sharding_from_operands(mesh, arg_infos, result_infos): - """ - dgated_silu infer_sharding_from_operands - """ - del result_infos # Unused. 
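
The fused backward wrapper dact_lu_dbias_cast_transpose introduced above replaces the per-activation dgelu/dsilu variants and, per its docstring, only supports non-gated activations. A usage sketch under assumed shapes follows; the import path, buffer shapes and FP8 dtype are illustrative, and the valid activation_type spellings are whatever ActivationEnum defines earlier in this patch.

import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions import dact_lu_dbias_cast_transpose  # assumed path

n, h = 16, 1024
dz = jnp.ones((n, h), jnp.bfloat16)       # incoming gradient, (N, H)
x = jnp.ones((n, 1, h), jnp.bfloat16)     # saved forward input, (N, 1, H)
amax = jnp.zeros((1,), jnp.float32)
scale = jnp.ones((1,), jnp.float32)
scale_inv = jnp.ones((1,), jnp.float32)

dact, dact_t, dbias, updated_amax = dact_lu_dbias_cast_transpose(
    dz, x, amax, scale, scale_inv,
    out_dtype=jnp.float8_e4m3fn,
    static_axis_boundary=-1,        # no static (vmapped) leading axes
    transpose_axis_boundary=-1,
    activation_type=('silu',))      # any non-gated ActivationEnum key
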
- silu_out_spec = get_padded_spec(arg_infos[1]) - dx_sharding = NamedSharding(mesh, PartitionSpec(*silu_out_spec)) - return dx_sharding - - @staticmethod - def partition(mesh, arg_infos, result_infos): - """ - dgated_silu partition - """ - del result_infos - dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = dx_sharding - impl = DgatedSiluPrimitive.impl - return mesh, impl, out_shardings, arg_shardings - - -register_primitive(DgatedSiluPrimitive) - - -def dgated_silu(inputs: jnp.ndarray, silu_inputs: jnp.ndarray) -> jnp.ndarray: - """ - dgated_silu fusion wrapper - Return dgeglu(inputs) - """ - return DgatedSiluPrimitive.outer_primitive.bind(inputs, silu_inputs) - - -class SiluFp8Primitive(BasePrimitive): - """ - Silu FP8 Primitive - """ - name = "te_silu_fp8" - multiple_results = True - impl_static_args = (4,) #out_dtype - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype): - """ - te_silu_p abstract - """ - dtype = dtypes.canonicalize_dtype(x_aval.dtype) - # Currently only support casting to E4M3 only in C side. - assert out_dtype == jnp.float8_e4m3fn - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - - out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) - updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - - return out_aval, updated_amax_aval - - @staticmethod - def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): - """ - te_gated_silu_p lowering rules - """ - x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - - hidden_size = ir_x_shape[-1] - batch_size = reduce(operator.mul, ir_x_shape[:-1]) - out_shape = ir_x_shape - out_types = [ - ir.RankedTensorType.get(out_shape, ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ] - operands = [x, amax, scale, scale_inv] - operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(out_dtype)) - - out = custom_caller(SiluFp8Primitive.name, - args, - opaque, - False, - operand_output_aliases={1: 1}) - - return out - - @staticmethod - def impl(x, amax, scale, scale_inv, out_dtype): - """ - to describe implementation - """ - assert SiluFp8Primitive.inner_primitive is not None - out, updated_amax = SiluFp8Primitive.inner_primitive.bind(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype) - return out, updated_amax - - @staticmethod - def batcher(batched_args, batch_dims, *, out_dtype): - """ - to describe batch rules for vmap - """ - _check_valid_batch_dims(batch_dims) - assert 
SiluFp8Primitive.outer_primitive is not None - x, amax, scale, scale_inv = batched_args - x_bdim, amax_bdim, _, _ = batch_dims - - out_bdims = x_bdim, amax_bdim - return SiluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, - out_dtype=out_dtype), out_bdims - - @staticmethod - def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos): - del out_dtype, result_infos - x_spec = get_padded_spec(arg_infos[0]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - return (out_sharding, amax_sharding) - - @staticmethod - def partition(out_dtype, mesh, arg_infos, result_infos): - del result_infos - x_spec = get_padded_spec(arg_infos[0]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (out_sharding, amax_sharding) - - def sharded_impl(x, amax, scale, scale_inv): - local_x, local_amax = SiluFp8Primitive.impl(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) - - return local_x, global_updated_amax - - return mesh, sharded_impl, out_shardings, arg_shardings - - -register_primitive(SiluFp8Primitive) - - -def silu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, - out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - """ - gated silu wrapper - Return FP8(geglu(x)) - """ - return SiluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype) - - -class DSiluDBiasCastTransposePrimitive(BasePrimitive): - """ - DSilu DBias Cast Transpose Primitive - """ - name = "te_dsilu_dbias_cast_transpose" - multiple_results = True - # out_dtype, static_axis_boundary, transpose_axis_boundary - impl_static_args = (5, 6, 7) - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, - static_axis_boundary, transpose_axis_boundary): - """ - te_dsilu_dbais_cast_transpose_p abstract - """ - dtype = dtypes.canonicalize_dtype(dz_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert x_aval.dtype == dtype - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - ir_hidden_szie = dz_aval.shape[-1] - gi_hidden_size = x_aval.shape[-1] - assert ir_hidden_szie == gi_hidden_size - t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary) - out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) - t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) - - dbias_shape = (*x_aval.shape[:static_axis_boundary + 1], gi_hidden_size) - dbias = dz_aval.update(shape=dbias_shape, dtype=dtype) - - updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - - wkspace_info, = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes( - x_aval.size // gi_hidden_size, - gi_hidden_size, - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), - ) - wkspace_aval = x_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) - - return out, t_out, dbias, updated_amax_aval, wkspace_aval - - @staticmethod - def outer_abstract(*args, **kwargs): - """ - te_dsilu_dbais_cast_transpose_p outer abstract - """ - - out, t_out, 
dbias, updated_amax_aval, _ = \ - DSiluDBiasCastTransposePrimitive.abstract(*args, **kwargs) - return out, t_out, dbias, updated_amax_aval - - @staticmethod - def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, - transpose_axis_boundary): - """ - te_dgated_silu_cast_transpose_p lowering rules - """ - dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert x_aval.dtype == dz_aval.dtype - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - ir_dz_type = ir.RankedTensorType(dz.type) - ir_dz_shape = ir_dz_type.shape - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - assert ir_dz_shape == x_shape - - batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) - ir_hidden_szie = ir_dz_shape[-1] - contracted_x_shape = (batch_szie, ir_hidden_szie) - - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary, - transpose_axis_boundary) - dbias_shape = (*x_shape[:static_axis_boundary + 1], ir_hidden_szie) - - wkspace_aval = ctx.avals_out[-1] - - out_types = [ - ir.RankedTensorType.get(x_shape, ir_out_dtype), - ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), - ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), - ] - operands = [dz, x, amax, scale, scale_inv] - operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_common_wk_descriptor( - contracted_x_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype)) - - out = custom_caller(DSiluDBiasCastTransposePrimitive.name, - args, - opaque, - False, - operand_output_aliases={2: 3}) - - return out - - @staticmethod - def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, - transpose_axis_boundary): - """ - to describe implementation - """ - assert DSiluDBiasCastTransposePrimitive.inner_primitive is not None - out, t_out, dbias, updated_amax, _ = DSiluDBiasCastTransposePrimitive.inner_primitive.bind( - dz, - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) - return out, t_out, dbias, updated_amax - - @staticmethod - def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, - transpose_axis_boundary): - """ - to describe batch rules for vmap - """ - del static_axis_boundary - _check_valid_batch_dims(batch_dims) - assert DSiluDBiasCastTransposePrimitive.outer_primitive is not None - dz, x, amax, scale, scale_inv = batched_args - x_bdim, _, amax_bdim, _, _ = batch_dims - - # Minus batch dim. 
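
For the gated path, the dgated_act_lu_cast_transpose wrapper added at the end of this hunk takes the (N, H) gradient together with the saved (N, 2, H) forward input. A sketch under assumed shapes; the import path, the FP8 dtype and the ('silu', 'linear') spelling are hypothetical placeholders for whatever ActivationEnum actually accepts.

import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions import dgated_act_lu_cast_transpose  # assumed path

n, h = 16, 1024
dz = jnp.ones((n, h), jnp.bfloat16)      # gradient of the gated-activation output
x = jnp.ones((n, 2, h), jnp.bfloat16)    # saved forward input, (N, 2, H)
amax = jnp.zeros((1,), jnp.float32)
scale = jnp.ones((1,), jnp.float32)
scale_inv = jnp.ones((1,), jnp.float32)

dx, dx_t, updated_amax = dgated_act_lu_cast_transpose(
    dz, x, amax, scale, scale_inv,
    out_dtype=jnp.float8_e4m3fn,
    static_axis_boundary=-1,             # mirrors the "no static axes" convention
    activation_type=('silu', 'linear'))  # hypothetical SwiGLU spelling
# dx keeps x's (N, 2, H) shape; dx_t is the cast-and-transposed FP8 copy.
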
- transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) - transpose_axis_boundary += 1 # Plus batch dim - - out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim - return DSiluDBiasCastTransposePrimitive.outer_primitive.bind( - dz, - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=x_bdim, - transpose_axis_boundary=transpose_axis_boundary), out_bdims - - @staticmethod - def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, - arg_infos, result_infos): - del out_dtype, result_infos - x_spec = get_padded_spec(arg_infos[1]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) - tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) - dbias_shaprding = NamedSharding( - mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) - return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) - - @staticmethod - def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, - result_infos): - del result_infos - x_spec = get_padded_spec(arg_infos[1]) - casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) - casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) - - dbias_shaprding = NamedSharding( - mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) - - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_shaprding, - amax_sharding) - - def sharded_impl(dz, x, amax, scale, scale_inv): - local_out, local_t_out, local_dbias, local_amax = DSiluDBiasCastTransposePrimitive.impl( - dz, - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) - return local_out, local_t_out, global_dbias, global_updated_amax - - return mesh, sharded_impl, out_shardings, arg_shardings - - -register_primitive(DSiluDBiasCastTransposePrimitive) - - -def dsilu_dbias_cast_transpose( - dz: jnp.ndarray, - x: jnp.ndarray, - amax: jnp.ndarray, - scale: jnp.ndarray, - scale_inv: jnp.ndarray, - out_dtype: TEDType, - static_axis_boundary: int, - transpose_axis_boundary: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - """ - cast transpose dsilu and dbias fusion wrapper - Return FP8(dgeglu(inputs)), dbias - """ - if static_axis_boundary < 0: - static_axis_boundary = -1 # means no static axes - - return DSiluDBiasCastTransposePrimitive.outer_primitive.bind( - dz, - x, - amax, - scale, - scale_inv, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary) - - -class GatedSiluFp8Primitive(BasePrimitive): - """ - Gated Silu FP8 Primitive - """ - name = "te_gated_silu_fp8" - multiple_results = True - impl_static_args = (4,) #out_dtype - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype): - """ - te_gated_silu_p 
abstract - """ - dtype = dtypes.canonicalize_dtype(x_aval.dtype) - # Currently only support casting to E4M3 only in C side. - assert out_dtype == jnp.float8_e4m3fn - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - - assert x_aval.shape[-2] == 2 # Assume x in (....., 2, hidden) - hidden_size = x_aval.shape[-1] - batch_shape = x_aval.shape[:-2] - out_shape = (batch_shape) + (hidden_size,) - out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) - updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - - return out_aval, updated_amax_aval - - @staticmethod - def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): - """ - te_gated_silu_p lowering rules - """ - x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - - hidden_size = ir_x_shape[-1] - batch_shape = ir_x_shape[:-2] - batch_size = reduce(operator.mul, batch_shape) - out_shape = batch_shape + [hidden_size] - out_types = [ - ir.RankedTensorType.get(out_shape, ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ] - operands = [x, amax, scale, scale_inv] - operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - opaque = transformer_engine_jax.pack_common_descriptor((batch_size, out_shape[-1]), - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(out_dtype)) - - out = custom_caller(GatedSiluFp8Primitive.name, - args, - opaque, - False, - operand_output_aliases={1: 1}) - - return out - - @staticmethod - def impl(x, amax, scale, scale_inv, out_dtype): - """ - to describe implementation - """ - assert GatedSiluFp8Primitive.inner_primitive is not None - out, updated_amax = GatedSiluFp8Primitive.inner_primitive.bind(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype) - return out, updated_amax - - @staticmethod - def batcher(batched_args, batch_dims, *, out_dtype): - """ - to describe batch rules for vmap - """ - _check_valid_batch_dims(batch_dims) - assert GatedSiluFp8Primitive.outer_primitive is not None - x, amax, scale, scale_inv = batched_args - x_bdim, amax_bdim, _, _ = batch_dims - - out_bdims = x_bdim, amax_bdim - return GatedSiluFp8Primitive.outer_primitive.bind(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype), out_bdims - - @staticmethod - def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos): - del out_dtype, result_infos - x_spec = get_padded_spec(arg_infos[0]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - return (out_sharding, amax_sharding) - - @staticmethod - def partition(out_dtype, mesh, arg_infos, result_infos): - del result_infos - x_spec = get_padded_spec(arg_infos[0]) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) - amax_sharding = 
NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (out_sharding, amax_sharding) - - def sharded_impl(x, amax, scale, scale_inv): - local_x, local_amax = GatedSiluFp8Primitive.impl(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype) - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) - - return local_x, global_updated_amax - - return mesh, sharded_impl, out_shardings, arg_shardings - - -register_primitive(GatedSiluFp8Primitive) - - -def gated_silu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, - out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - """ - gated silu wrapper - Return FP8(geglu(x)) - """ - return GatedSiluFp8Primitive.outer_primitive.bind(x, - amax, - scale, - scale_inv, - out_dtype=out_dtype) - - -class DgatedSiluCastTransposePrimitive(BasePrimitive): - """ - Dgated Silu Cast Transpose Primitive - """ - name = "te_dgated_silu_cast_transpose" - multiple_results = True - impl_static_args = (5, 6) # out_dtype, static_axis_boundary - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, - static_axis_boundary): - """ - te_dgated_silu_cast_transpose_p abstract - """ - dtype = dtypes.canonicalize_dtype(dz_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert x_aval.dtype == dtype - assert x_aval.shape[-2] == 2 # Linear + GeLU - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - ir_hidden_szie = dz_aval.shape[-1] - gi_hidden_size = x_aval.shape[-1] - assert ir_hidden_szie == gi_hidden_size - t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2) - out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) - t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) - updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) - return out, t_out, updated_amax_aval - - @staticmethod - def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary): - """ - te_dgated_silu_cast_transpose_p lowering rules - """ - dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert x_aval.dtype == dz_aval.dtype - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - ir_dz_type = ir.RankedTensorType(dz.type) - ir_dz_shape = ir_dz_type.shape - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) - x_batch_size = reduce(operator.mul, x_shape[:-2]) - assert dz_batch_szie == x_batch_size - assert x_shape[-2] == 2 # Linear + GeLU - ir_hidden_szie = ir_dz_shape[-1] - gi_hidden_size = x_shape[-1] - assert ir_hidden_szie == gi_hidden_size + ir_dz_type = ir.RankedTensorType(dz.type) + ir_dz_shape = ir_dz_type.shape + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) + x_batch_size = reduce(operator.mul, x_shape[:-2]) + assert dz_batch_szie == x_batch_size + assert x_shape[-2] == 2 # Linear + GeLU + ir_hidden_szie = ir_dz_shape[-1] + gi_hidden_size = x_shape[-1] + assert ir_hidden_szie == gi_hidden_size ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) ir_amax_type = ir.RankedTensorType(amax.type) 
ir_amax_dtype = ir_amax_type.element_type @@ -5922,11 +4444,13 @@ def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_bound operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) contracted_x_shape = (x_batch_size, x_shape[-1]) - opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape, - jax_dtype_to_te_dtype(dz_aval.dtype), - jax_dtype_to_te_dtype(out_dtype)) + opaque = transformer_engine_jax.pack_common_descriptor( + contracted_x_shape, + jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + act_enum) - out = custom_caller(DgatedSiluCastTransposePrimitive.name, + out = custom_caller(DgatedActLuCastTransposePrimitive.name, args, opaque, False, @@ -5935,41 +4459,43 @@ def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_bound return out @staticmethod - def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary): + def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, act_enum): """ to describe implementation """ - assert DgatedSiluCastTransposePrimitive.inner_primitive is not None - out, t_out, updated_amax = DgatedSiluCastTransposePrimitive.inner_primitive.bind( + assert DgatedActLuCastTransposePrimitive.inner_primitive is not None + out, t_out, updated_amax = DgatedActLuCastTransposePrimitive.inner_primitive.bind( dz, x, amax, scale, scale_inv, out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary) + static_axis_boundary=static_axis_boundary, + act_enum=act_enum) return out, t_out, updated_amax @staticmethod - def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary): + def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum): """ to describe batch rules for vmap """ del static_axis_boundary _check_valid_batch_dims(batch_dims) - assert DgatedSiluCastTransposePrimitive.outer_primitive is not None + assert DgatedActLuCastTransposePrimitive.outer_primitive is not None dz, x, amax, scale, scale_inv = batched_args x_bdim, _, amax_bdim, _, _ = batch_dims out_bdims = x_bdim, x_bdim, amax_bdim - return DgatedSiluCastTransposePrimitive.outer_primitive.bind( + return DgatedActLuCastTransposePrimitive.outer_primitive.bind( dz, x, amax, scale, scale_inv, out_dtype=out_dtype, - static_axis_boundary=x_bdim), out_bdims + static_axis_boundary=x_bdim, + act_enum=act_enum), out_bdims @staticmethod - def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_infos, - result_infos): - del out_dtype, result_infos + def infer_sharding_from_operands(out_dtype, static_axis_boundary, act_enum, + mesh, arg_infos, result_infos): + del out_dtype, result_infos, act_enum x_spec = get_padded_spec(arg_infos[1]) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2) @@ -5978,7 +4504,8 @@ def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_info return (out_sharding, tranposed_out_sharding, amax_sharding) @staticmethod - def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos): + def partition(out_dtype, static_axis_boundary, act_enum, + mesh, arg_infos, result_infos): del result_infos x_spec = get_padded_spec(arg_infos[1]) casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) @@ -5990,36 +4517,41 @@ def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos): out_shardings = (casted_x_sharding, 
casted_transposed_x_sharding, amax_sharding) def sharded_impl(dz, x, amax, scale, scale_inv): - local_out, local_t_out, local_amax = DgatedSiluCastTransposePrimitive.impl( + local_out, local_t_out, local_amax = DgatedActLuCastTransposePrimitive.impl( dz, x, amax, scale, scale_inv, out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary) + static_axis_boundary=static_axis_boundary, + act_enum=act_enum) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) return local_out, local_t_out, global_updated_amax return mesh, sharded_impl, out_shardings, arg_shardings -register_primitive(DgatedSiluCastTransposePrimitive) +register_primitive(DgatedActLuCastTransposePrimitive) -def dgated_silu_cast_transpose( - dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, - scale_inv: jnp.ndarray, out_dtype: TEDType, - static_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: +def dgated_act_lu_cast_transpose( + dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, + scale_inv: jnp.ndarray, out_dtype: TEDType, + static_axis_boundary: int, + activation_type: Sequence[Union[str, Callable]] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """ - cast transpose d_gated_silu fusion wrapper - Return FP8(dgeglu(inputs)) + cast transpose d_gated_act_lu fusion wrapper + Return FP8(dgated_act_lu(inputs)) """ - return DgatedSiluCastTransposePrimitive.outer_primitive.bind( + act_type_id = ActivationEnum[activation_type] + return DgatedActLuCastTransposePrimitive.outer_primitive.bind( dz, x, amax, scale, scale_inv, out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary) + static_axis_boundary=static_axis_boundary, + act_enum=act_type_id) diff --git a/transformer_engine/jax/csrc/extensions.cpp b/transformer_engine/jax/csrc/extensions.cpp index 7d3958879a..195665f9b8 100644 --- a/transformer_engine/jax/csrc/extensions.cpp +++ b/transformer_engine/jax/csrc/extensions.cpp @@ -25,25 +25,14 @@ pybind11::dict Registrations() { pybind11::dict dict; dict["te_transpose"] = EncapsulateFunction(Transpose); dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose); - dict["te_gelu"] = EncapsulateFunction(Gelu); - dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8); - dict["te_dgelu"] = EncapsulateFunction(DGelu); - dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose); + + dict["te_act_lu"] = EncapsulateFunction(ActLu); + dict["te_act_lu_fp8"] = EncapsulateFunction(ActLuFP8); + dict["te_dact_lu"] = EncapsulateFunction(DActLu); dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose); - dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu); - dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8); - dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu); - dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose); - // TODO - dict["te_silu"] = EncapsulateFunction(Silu); - dict["te_silu_fp8"] = EncapsulateFunction(SiluFP8); - dict["te_dsilu"] = EncapsulateFunction(DSilu); - dict["te_dsilu_dbias_cast_transpose"] = EncapsulateFunction(DSiluDBiasCastTranspose); - dict["te_gated_silu"] = EncapsulateFunction(GatedSilu); - dict["te_gated_silu_fp8"] = EncapsulateFunction(GatedSiluFP8); - dict["te_dgated_silu"] = EncapsulateFunction(DGatedSilu); - dict["te_dgated_silu_cast_transpose"] = EncapsulateFunction(DGatedSiluCastTranspose); - // + dict["te_dact_lu_dbias_cast_transpose"] = EncapsulateFunction(DActLuDBiasCastTranspose); + 
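For reference, the sketch below shows in plain JAX what the gated act_lu variants wired up above compute, and why their inputs carry a size-2 axis (get_activation_len() returns 2 for GEGLU/SWIGLU and 1 for GELU/SILU). It is illustrative only, not the fused kernels registered here: which half of that axis is the linear gate, and the integer ids, are assumptions made for the example.

import jax
import jax.numpy as jnp

# Gated activations ("geglu"/"swiglu") consume an input with an extra axis of
# size 2: one half goes through the nonlinearity, the other is a linear gate,
# which is why the primitives assert x.shape[-2] == 2.
_ACT_IDS = {("gelu",): 0, ("gelu", "linear"): 1,   # ids assumed to mirror the
            ("silu",): 2, ("silu", "linear"): 3}   # NVTE_Activation_Enum order

def gated_act_reference(x, activation_type=("silu", "linear")):
    """Reference math only: act(x[..., 0, :]) * x[..., 1, :] (gate index assumed)."""
    act = {"gelu": jax.nn.gelu, "silu": jax.nn.silu}[activation_type[0]]
    return act(x[..., 0, :]) * x[..., 1, :]

x = jnp.ones((4, 2, 8), dtype=jnp.float32)      # (batch, 2, hidden)
y = gated_act_reference(x, ("silu", "linear"))  # the "swiglu" flavour
assert y.shape == (4, 8) and _ACT_IDS[("silu", "linear")] == 3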
dict["te_dgated_act_lu_cast_transpose"] = EncapsulateFunction(DGatedActLuCastTranspose); + dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward); dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8); dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward); @@ -67,8 +56,11 @@ pybind11::dict Registrations() { PYBIND11_MODULE(transformer_engine_jax, m) { m.def("registrations", &Registrations); - m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor); - m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor); + m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor, + pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0); + m.def("pack_common_wk_descriptor", &PackCustomCallCommonWkDescriptor, + pybind11::arg(), pybind11::arg(), pybind11::arg(), + pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0); m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor); m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); @@ -109,6 +101,12 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD); + pybind11::enum_(m, "NVTE_Activation_Enum", pybind11::module_local()) + .value("GELU", NVTE_Activation_Enum::GELU) + .value("GEGLU", NVTE_Activation_Enum::GEGLU) + .value("SILU", NVTE_Activation_Enum::SILU) + .value("SWIGLU", NVTE_Activation_Enum::SWIGLU); + pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) diff --git a/transformer_engine/jax/csrc/modules.cpp b/transformer_engine/jax/csrc/modules.cpp index 78e9f60e3f..fb3d21a124 100644 --- a/transformer_engine/jax/csrc/modules.cpp +++ b/transformer_engine/jax/csrc/modules.cpp @@ -37,6 +37,19 @@ std::vector MakeShapeVector(NVTEShape shape) { return std::vector(shape.data, shape.data + shape.ndim); } +size_t get_activation_len(NVTE_Activation_Enum act_enum) { + switch (act_enum) { + case NVTE_Activation_Enum::GELU: return 1; + case NVTE_Activation_Enum::GEGLU: return 2; + case NVTE_Activation_Enum::SILU: return 1; + case NVTE_Activation_Enum::SWIGLU: return 2; + default: + NVTE_ERROR("Unsupported ActivationEnum"); + break; + return -1; + } +} + template pybind11::bytes PackOpaque(const T &descriptor) { auto str = std::string(reinterpret_cast(&descriptor), sizeof(T)); @@ -52,23 +65,26 @@ const T *UnpackOpaque(const char *opaque, size_t opaque_len) { } pybind11::bytes PackCustomCallCommonDescriptor(const std::vector &shape, DType in_dtype, - DType out_dtype) { + DType out_dtype, size_t act_enum) { CustomCallCommonDescriptor desc; desc.shape.from_vector(shape); desc.in_dtype = in_dtype; desc.out_dtype = out_dtype; + desc.act_enum = act_enum; return PackOpaque(desc); } pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector &shape, const std::vector &wkshape, DType in_dtype, - DType out_dtype, DType wk_dtype) { + DType out_dtype, DType wk_dtype, + size_t act_enum) { CustomCallCommonWkDescriptor desc; desc.shape.from_vector(shape); desc.wkshape.from_vector(wkshape); desc.in_dtype = in_dtype; desc.out_dtype = out_dtype; desc.wk_dtype = wk_dtype; + desc.act_enum = act_enum; return PackOpaque(desc); } @@ -170,31 +186,50 @@ void CastTranspose(cudaStream_t stream, 
void **buffers, const char *opaque, size input_cast_trans_tensor.data(), stream); } -void GeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, - cudaStream_t stream, float *scale_inverse, float *amax, void *output) { - auto input_shape = std::vector{m, n}; +void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, + cudaStream_t stream, float *scale_inverse, float *amax, void *output, + NVTE_Activation_Enum act_enum) { + auto act_len = get_activation_len(act_enum); + auto input_shape = std::vector{m, n * act_len}; auto output_shape = std::vector{m, n}; - - auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); - - auto output_tensor = TensorWrapper(output, output_shape, static_cast(out_dtype), amax, + auto input_tensor = TensorWrapper(input, input_shape, + static_cast(in_dtype)); + auto output_tensor = TensorWrapper(output, output_shape, + static_cast(out_dtype), amax, scale, scale_inverse); - - nvte_gelu(input_tensor.data(), output_tensor.data(), stream); + switch (act_enum) { + case NVTE_Activation_Enum::GELU: + nvte_gelu(input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Enum::GEGLU: + nvte_geglu(input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Enum::SILU: + nvte_swish(input_tensor.data(), output_tensor.data(), stream); + break; + case NVTE_Activation_Enum::SWIGLU: + nvte_swiglu(input_tensor.data(), output_tensor.data(), stream); + break; + default: + NVTE_ERROR("Unsupported ActivationEnum"); + break; + } } -void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { +void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; auto *output = buffers[1]; const auto &desc = *UnpackOpaque(opaque, opaque_len); auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; + auto act_enum = static_cast(desc.act_enum);; - GeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output); + ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, + nullptr, nullptr, output, act_enum); } -void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { +void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; float *amax = reinterpret_cast(buffers[1]); float *scale = reinterpret_cast(buffers[2]); @@ -211,107 +246,91 @@ void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opa } auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; + auto act_enum = static_cast(desc.act_enum);; - GeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, - output); + ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, + scale_inv, amax_out, output, act_enum); } -void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { +void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; - auto *gelu_input = buffers[1]; + auto *act_input = buffers[1]; auto *output = buffers[2]; const auto &desc = *UnpackOpaque(opaque, opaque_len); auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; + auto act_enum = static_cast(desc.act_enum);; + + auto act_len = get_activation_len(act_enum); auto input_shape = std::vector{m, n}; - auto gelu_input_shape = std::vector{m, n}; - auto output_shape = 
std::vector{m, n}; + auto act_input_shape = std::vector{m, n * act_len}; + auto output_shape = std::vector{m, n * act_len}; auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype); + auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype); - nvte_dgelu(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), stream); -} - -void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { - auto *input = buffers[0]; - auto *gelu_input = buffers[1]; - float *amax = reinterpret_cast(buffers[2]); - float *scale = reinterpret_cast(buffers[3]); - float *scale_inv = reinterpret_cast(buffers[4]); - auto *output = buffers[5]; - auto *output_trans = buffers[6]; - auto *dbias = buffers[7]; - float *amax_out = reinterpret_cast(buffers[8]); - void *workspace_ptr = buffers[9]; - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - assert(amax == amax_out); - if (!use_fp8(desc.out_dtype)) { - scale = nullptr; - scale_inv = nullptr; - amax_out = nullptr; + switch (act_enum) { + case NVTE_Activation_Enum::GELU: + nvte_dgelu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), stream); + break; + case NVTE_Activation_Enum::GEGLU: + nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), stream); + break; + case NVTE_Activation_Enum::SILU: + nvte_dswish(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), stream); + break; + case NVTE_Activation_Enum::SWIGLU: + nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), stream); + break; + default: + NVTE_ERROR("Unsupported ActivationEnum"); + break; } - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - auto input_shape = std::vector{m, n}; - auto gelu_input_shape = std::vector{m, n}; - auto output_shape = std::vector{m, n}; - auto output_trans_shape = std::vector{n, m}; - auto dbias_shape = std::vector{n}; - - auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype); - auto output_tensor = - TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); - - auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); - - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), gelu_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); } -// HERE -pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, +pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype) { auto input_shape = std::vector{batch_size, hidden_size}; + auto dact_input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; auto output_trans_shape = std::vector{hidden_size, batch_size}; auto dbias_shape = std::vector{hidden_size}; auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype); auto output_tensor = 
TensorWrapper(nullptr, output_shape, out_dtype); auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); TensorWrapper dummy_workspace; - nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), dbias_tensor.data(), - dummy_workspace.data(), nullptr); + // For now, all dbias_dact(-s) have the same workspace size + nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), dummy_workspace.data(), nullptr); auto work_shape = MakeShapeVector(dummy_workspace.shape()); return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); } -void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, +void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; - float *amax = reinterpret_cast(buffers[1]); - float *scale = reinterpret_cast(buffers[2]); - float *scale_inv = reinterpret_cast(buffers[3]); - auto *output = buffers[4]; - auto *output_trans = buffers[5]; - auto *dbias = buffers[6]; - float *amax_out = reinterpret_cast(buffers[7]); - void *workspace_ptr = buffers[8]; + auto *act_input = buffers[1]; + float *amax = reinterpret_cast(buffers[2]); + float *scale = reinterpret_cast(buffers[3]); + float *scale_inv = reinterpret_cast(buffers[4]); + auto *output = buffers[5]; + auto *output_trans = buffers[6]; + auto *dbias = buffers[7]; + float *amax_out = reinterpret_cast(buffers[8]); + void *workspace_ptr = buffers[9]; const auto &desc = *UnpackOpaque(opaque, opaque_len); assert(amax == amax_out); @@ -322,12 +341,15 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, } auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; + auto act_enum = static_cast(desc.act_enum);; auto input_shape = std::vector{m, n}; + auto act_input_shape = std::vector{m, n}; auto output_shape = std::vector{m, n}; auto output_trans_shape = std::vector{n, m}; auto dbias_shape = std::vector{n}; auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); + auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); auto output_trans_tensor = @@ -336,81 +358,27 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); - nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), dbias_tensor.data(), - workspace.data(), stream); -} - -void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, - cudaStream_t stream, float *scale_inverse, float *amax, void *output) { - auto input_shape = std::vector{m, n * 2}; - auto output_shape = std::vector{m, n}; - - auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); - - auto output_tensor = TensorWrapper(output, output_shape, static_cast(out_dtype), amax, - scale, scale_inverse); - - nvte_geglu(input_tensor.data(), output_tensor.data(), stream); -} - -void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *output = buffers[1]; - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto m = 
desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - - GatedGeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, - output); -} - -void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - float *amax = reinterpret_cast(buffers[1]); - float *scale = reinterpret_cast(buffers[2]); - float *scale_inv = reinterpret_cast(buffers[3]); - auto *output = buffers[4]; - float *amax_out = reinterpret_cast(buffers[5]); - assert(amax == amax_out); - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - if (!use_fp8(desc.out_dtype)) { - scale = nullptr; - scale_inv = nullptr; - amax_out = nullptr; + switch (act_enum) { + case NVTE_Activation_Enum::GELU: + nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); + break; + case NVTE_Activation_Enum::SILU: + nvte_cast_transpose_dbias_dswish(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); + break; + default: + throw std::runtime_error("Activation Type is not Implemented in DActLuDBiasCastTranspose"); + break; } - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - - GatedGeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, - output); -} - -void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *gelu_input = buffers[1]; - auto *output = buffers[2]; - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - auto input_shape = std::vector{m, n}; - auto gelu_input_shape = std::vector{m, n * 2}; - auto output_shape = std::vector{m, n * 2}; - - auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype); - auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype); - - nvte_dgeglu(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), stream); } -void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, +void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; - auto *gelu_input = buffers[1]; + auto *act_input = buffers[1]; float *amax = reinterpret_cast(buffers[2]); float *scale = reinterpret_cast(buffers[3]); float *scale_inv = reinterpret_cast(buffers[4]); @@ -427,124 +395,69 @@ void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *op } auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; + auto act_enum = static_cast(desc.act_enum);; auto input_shape = desc.shape.to_vector(); - auto gelu_input_shape = std::vector{m, n * 2}; + auto act_input_shape = std::vector{m, n * 2}; auto output_shape = std::vector{m, n * 2}; auto output_trans_shape = std::vector{n * 2, m}; auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype); + auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); auto output_trans_tensor = TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, 
scale_inv); - nvte_dgeglu_cast_transpose(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); -} - -void SiluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, - cudaStream_t stream, float *scale_inverse, float *amax, void *output) { - auto input_shape = std::vector{m, n}; - auto output_shape = std::vector{m, n}; - - auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); - - auto output_tensor = TensorWrapper(output, output_shape, static_cast(out_dtype), amax, - scale, scale_inverse); - - nvte_swish(input_tensor.data(), output_tensor.data(), stream); -} - -void Silu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *output = buffers[1]; - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - - SiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output); -} - -void SiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - float *amax = reinterpret_cast(buffers[1]); - float *scale = reinterpret_cast(buffers[2]); - float *scale_inv = reinterpret_cast(buffers[3]); - auto *output = buffers[4]; - float *amax_out = reinterpret_cast(buffers[5]); - assert(amax == amax_out); - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - if (!use_fp8(desc.out_dtype)) { - scale = nullptr; - scale_inv = nullptr; - amax_out = nullptr; + switch (act_enum) { + case NVTE_Activation_Enum::GEGLU: + nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + stream); + break; + case NVTE_Activation_Enum::SWIGLU: + nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + stream); + break; + default: + NVTE_ERROR("Unsupported ActivationEnum"); + break; } - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - - SiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, - output); -} - -void DSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *silu_input = buffers[1]; - auto *output = buffers[2]; - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - auto input_shape = std::vector{m, n}; - auto silu_input_shape = std::vector{m, n}; - auto output_shape = std::vector{m, n}; - - auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype); - auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype); - - nvte_dswish(input_tensor.data(), silu_input_tensor.data(), output_tensor.data(), stream); } -pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, +pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype) { auto input_shape = std::vector{batch_size, hidden_size}; - auto dact_input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; auto output_trans_shape = std::vector{hidden_size, batch_size}; auto dbias_shape = std::vector{hidden_size}; auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto 
dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype); auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); TensorWrapper dummy_workspace; - // For now, all dbias_dact(-s) have the same workspace size - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), dummy_workspace.data(), nullptr); + nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), + output_trans_tensor.data(), dbias_tensor.data(), + dummy_workspace.data(), nullptr); auto work_shape = MakeShapeVector(dummy_workspace.shape()); return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); } -void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, +void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; - auto *silu_input = buffers[1]; - float *amax = reinterpret_cast(buffers[2]); - float *scale = reinterpret_cast(buffers[3]); - float *scale_inv = reinterpret_cast(buffers[4]); - auto *output = buffers[5]; - auto *output_trans = buffers[6]; - auto *dbias = buffers[7]; - float *amax_out = reinterpret_cast(buffers[8]); - void *workspace_ptr = buffers[9]; + float *amax = reinterpret_cast(buffers[1]); + float *scale = reinterpret_cast(buffers[2]); + float *scale_inv = reinterpret_cast(buffers[3]); + auto *output = buffers[4]; + auto *output_trans = buffers[5]; + auto *dbias = buffers[6]; + float *amax_out = reinterpret_cast(buffers[7]); + void *workspace_ptr = buffers[8]; const auto &desc = *UnpackOpaque(opaque, opaque_len); assert(amax == amax_out); @@ -556,13 +469,11 @@ void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; auto input_shape = std::vector{m, n}; - auto silu_input_shape = std::vector{m, n}; auto output_shape = std::vector{m, n}; auto output_trans_shape = std::vector{n, m}; auto dbias_shape = std::vector{n}; auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); auto output_trans_tensor = @@ -571,111 +482,9 @@ void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); - nvte_cast_transpose_dbias_dswish(input_tensor.data(), silu_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); -} - -void GatedSiluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, - cudaStream_t stream, float *scale_inverse, float *amax, void *output) { - auto input_shape = std::vector{m, n * 2}; - auto output_shape = std::vector{m, n}; - - auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); - - auto output_tensor = TensorWrapper(output, output_shape, static_cast(out_dtype), amax, - scale, scale_inverse); - - nvte_swiglu(input_tensor.data(), output_tensor.data(), stream); -} - -void GatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *output = 
buffers[1]; - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - - GatedSiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, - output); -} - -void GatedSiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - float *amax = reinterpret_cast(buffers[1]); - float *scale = reinterpret_cast(buffers[2]); - float *scale_inv = reinterpret_cast(buffers[3]); - auto *output = buffers[4]; - float *amax_out = reinterpret_cast(buffers[5]); - assert(amax == amax_out); - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - if (!use_fp8(desc.out_dtype)) { - scale = nullptr; - scale_inv = nullptr; - amax_out = nullptr; - } - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - - GatedSiluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, - output); -} - -void DGatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - auto *input = buffers[0]; - auto *silu_input = buffers[1]; - auto *output = buffers[2]; - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - auto input_shape = std::vector{m, n}; - auto silu_input_shape = std::vector{m, n * 2}; - auto output_shape = std::vector{m, n * 2}; - - auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype); - auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype); - - nvte_dswiglu(input_tensor.data(), silu_input_tensor.data(), output_tensor.data(), stream); -} - -void DGatedSiluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { - auto *input = buffers[0]; - auto *silu_input = buffers[1]; - float *amax = reinterpret_cast(buffers[2]); - float *scale = reinterpret_cast(buffers[3]); - float *scale_inv = reinterpret_cast(buffers[4]); - auto *output = buffers[5]; - auto *output_trans = buffers[6]; - float *amax_out = reinterpret_cast(buffers[7]); - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - assert(amax == amax_out); - if (!use_fp8(desc.out_dtype)) { - scale = nullptr; - scale_inv = nullptr; - amax_out = nullptr; - } - auto m = desc.shape.dims[0]; - auto n = desc.shape.dims[1]; - auto input_shape = desc.shape.to_vector(); - auto silu_input_shape = std::vector{m, n * 2}; - auto output_shape = std::vector{m, n * 2}; - auto output_trans_shape = std::vector{n * 2, m}; - - auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto silu_input_tensor = TensorWrapper(silu_input, silu_input_shape, desc.in_dtype); - auto output_tensor = - TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); - - nvte_dswiglu_cast_transpose(input_tensor.data(), silu_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); + nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), + output_trans_tensor.data(), dbias_tensor.data(), + workspace.data(), stream); } pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, diff --git a/transformer_engine/jax/csrc/modules.h b/transformer_engine/jax/csrc/modules.h index ac14a54e90..c3b950fbda 100644 --- a/transformer_engine/jax/csrc/modules.h +++ 
b/transformer_engine/jax/csrc/modules.h @@ -43,14 +43,24 @@ struct Shape { } }; +enum class NVTE_Activation_Enum { + GELU, + GEGLU, + SILU, + SWIGLU, +}; + +size_t get_activation_len(NVTE_Activation_Enum act_enum); + struct CustomCallCommonDescriptor { Shape shape; DType in_dtype; DType out_dtype; + size_t act_enum; }; pybind11::bytes PackCustomCallCommonDescriptor(const std::vector &shape, DType in_dtype, - DType out_dtype); + DType out_dtype, size_t act_enum = 0); struct CustomCallCommonWkDescriptor { Shape shape; @@ -58,11 +68,13 @@ struct CustomCallCommonWkDescriptor { DType in_dtype; DType out_dtype; DType wk_dtype; + size_t act_enum; }; pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector &shape, - const std::vector &wkshape, DType in_dtype, - DType out_dtype, DType wk_dtype); + const std::vector &wkshape, + DType in_dtype, DType out_dtype, DType wk_dtype, + size_t act_enum = 0); struct CustomCallNormDescriptor { size_t batch_size; @@ -140,17 +152,16 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); -// TODO (Phuong): Templating these 9x2 rountines before adding ReGLU, QuickGeLU, Squared ReLu -void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); -void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); -void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype); -void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, +void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, @@ -159,31 +170,7 @@ pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hi void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); -void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); - -void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); - -void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); - -void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len); - -void Silu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); - -void SiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); - -void DSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); - -void DSiluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len); - -void GatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); - -void GatedSiluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); - -void DGatedSilu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); - -void DGatedSiluCastTranspose(cudaStream_t stream, void 
**buffers, const char *opaque, +void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 66cf91c3de..19424b9f58 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -955,7 +955,6 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: normalize_acts = tuple(reversed(normalize_acts) if normalize_acts[0] == 'linear' else normalize_acts) - is_gated = normalize_acts in gated_act_pool is_act_implemented = normalize_acts in (gated_act_pool + act_pool) use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and\ @@ -1052,8 +1051,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): axes=self.bias_axes_2) bias_2 = bias_2.astype(self.dtype) else: - bias_1 = jnp.empty(0, self.dtype) - bias_2 = jnp.empty(0, self.dtype) + bias_1 = None + bias_2 = None out = fused_layernorm_fp8_mlp(y, scale, @@ -1134,7 +1133,6 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): x += jnp.reshape(bias_1, bias_1_shape) x = checkpoint_name(x, ffn1_ckpt_name) - activations = [] if is_act_implemented: z = activation_lu(x, normalize_acts) @@ -1144,8 +1142,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): x_i = _convert_to_activation_function(act_fn)(x[idx]) activations.append(x_i) z = functools.reduce(operator.mul, activations) - if not is_gated: - z = jnp.reshape(z, (*z.shape[:-2], -1)) + if num_activations == 1: + z = jnp.reshape(z, (*z.shape[:-2], -1)) z = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=self.intermediate_hidden_dropout_dims, diff --git a/transformer_engine/jax/mlp.py b/transformer_engine/jax/mlp.py index a9761499c0..468e51dc79 100644 --- a/transformer_engine/jax/mlp.py +++ b/transformer_engine/jax/mlp.py @@ -11,14 +11,8 @@ from jax.ad_checkpoint import checkpoint_name from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose -from .cpp_extensions import gelu -from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose -from .cpp_extensions import gated_gelu, gated_gelu_fp8 -from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose -from .cpp_extensions import silu, silu_fp8 -from .cpp_extensions import dsilu, dsilu_dbias_cast_transpose -from .cpp_extensions import gated_silu, gated_silu_fp8 -from .cpp_extensions import dgated_silu, dgated_silu_cast_transpose +from .cpp_extensions import act_lu, act_lu_fp8, dact_lu +from .cpp_extensions import dact_lu_dbias_cast_transpose, dgated_act_lu_cast_transpose from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize, dequantize @@ -26,44 +20,6 @@ from .fp8 import FP8Helper, FP8MetaPackage from .sharding import with_sharding_constraint_by_logical_axes -activation_dict = { - ('gelu',): { - 'fwd': gelu, - "bwd": dgelu - }, - ('gelu', 'linear'): { - 'fwd': gated_gelu, - 'bwd': dgated_gelu - }, - ('silu',): { - 'fwd': silu, - "bwd": dsilu - }, - ('silu', 'linear'): { - 'fwd': gated_silu, - 'bwd': dgated_silu - } -} - -activation_fp8_dict = { - ('gelu',): { - 'fwd': gelu_fp8, - 'bwd': dgelu_dbias_cast_transpose - }, - ('gelu', 'linear'): { - 'fwd': gated_gelu_fp8, - 'bwd': dgated_gelu_cast_transpose - }, - ('silu',): { - 'fwd': silu_fp8, - 
'bwd': dsilu_dbias_cast_transpose - }, - ('silu', 'linear'): { - 'fwd': gated_silu_fp8, - 'bwd': dgated_silu_cast_transpose - } -} - def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]): """ @@ -84,7 +40,7 @@ def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable def _activation_lu_fwd_rule(x, activation_type): - fwd_output = activation_dict[activation_type]["fwd"](x) + fwd_output = act_lu(x, activation_type) return fwd_output, (x,) @@ -92,7 +48,7 @@ def _activation_lu_bwd_rule(activation_type, ctx, g): x, = ctx assert x.dtype == g.dtype - dx = activation_dict[activation_type]["bwd"](g, x) + dx = dact_lu(g, x, activation_type) dx = jnp.reshape(dx, x.shape) return (dx,) @@ -106,7 +62,7 @@ def activation_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, sca """ Activation Unit """ - transpose_indices = (1, 2, 0) if len(activation_type) > 1 else (2, 0, 1) + transpose_indices = (1, 2, 0) dx_trans_no_use = jnp.empty([x.shape[i] for i in transpose_indices], dtype=x.dtype) dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype) @@ -127,19 +83,15 @@ def _activation_lu_fp8(x: jnp.ndarray, dx_trans_no_use: jnp.ndarray, dbias_no_us return output -def _activation_lu_fp8_fwd_rule( - x, - dx_trans_no_use, # pylint: disable=unused-argument - dbias_no_use, # pylint: disable=unused-argument - amax, - scale, - scale_inv, - fwd_dtype, - bwd_dtype, # pylint: disable=unused-argument - activation_type): - activation_lu_out, _ = activation_fp8_dict[activation_type]["fwd"](x, amax, scale, scale_inv, - fwd_dtype) - +def _activation_lu_fp8_fwd_rule(x, + dx_trans_no_use, # pylint: disable=unused-argument + dbias_no_use, # pylint: disable=unused-argument + amax, + scale, scale_inv, + fwd_dtype, bwd_dtype, # pylint: disable=unused-argument + activation_type): + activation_lu_out, _ = act_lu_fp8(x, amax, scale, scale_inv, fwd_dtype, + activation_type) activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv) ctx = (x, amax, scale, scale_inv) return activation_lu_out, ctx @@ -153,14 +105,14 @@ def _activation_lu_fp8_bwd_rule( g): x, amax, scale, scale_inv = ctx - activation_lu_fp8_bwd = activation_fp8_dict[activation_type]["bwd"] - if len(activation_type) > 1: #gated, no bias + if len(activation_type) > 1: #gated, no bias dactivation_lu, dactivation_lu_trans, amax_out = \ - activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1) + dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype, -1, activation_type) dbias = jnp.empty(x.shape[-1], x.dtype) - else: + else: #not gated, with bias dactivation_lu, dactivation_lu_trans, dbias, amax_out = \ - activation_lu_fp8_bwd(g, x, amax, scale, scale_inv, bwd_dtype, -1) + dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, bwd_dtype, + -1, -2, activation_type) dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv) dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv) ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out, scale, scale_inv) @@ -262,7 +214,6 @@ def _fused_layernorm_fp8_mlp_fwd_rule( activation_type, use_bias): - is_gated = len(activation_type) > 1 # x should be in shape of (batch..., hidden) # Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out) # Kernel_2 should be in shape of (Hidden_in, Hidden_out) @@ -276,15 +227,9 @@ def _fused_layernorm_fp8_mlp_fwd_rule( assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0] assert kernel_1.shape[-1] == kernel_2.shape[0] - # Squeeze act axis - # (hidden_in, 1, hidden_out) -> 
(hidden_in, hidden_out) - if not is_gated: - kernel_1 = jnp.squeeze(kernel_1, axis=-2) - maybe_fm32_to_fp32, maybe_fp32_to_fm32 = \ FP8Helper.generate_fp8_meta_dtype_converter_pair(fp8_max, amax, scale, scale_inv) fp8_max, amax, scale, scale_inv = maybe_fm32_to_fp32(fp8_max, amax, scale, scale_inv) - scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) amax = FP8Helper.update_amax_history(amax) @@ -337,8 +282,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule( (x_contracting_dims, (0,)), get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) if use_bias: - bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape - dot_1_output += jnp.reshape(bias_1, bias_1_shape) + bias_1_shape = bias_1.shape + bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape + dot_1_output += jnp.reshape(bias_1, bias_1_new_shape) + else: + bias_1_shape = None dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1) @@ -347,12 +295,11 @@ def _fused_layernorm_fp8_mlp_fwd_rule( activation_lu_out_scale = scale[gemm2_x_idx] activation_lu_out_scale_inv = scale_inv[gemm2_x_idx] - activation_lu_fwd_fp8 = activation_fp8_dict[activation_type]["fwd"] # (batch..., hidden_in) -> (batch..., hidden) casted_activation_lu_out, updated_activation_lu_amax = \ - activation_lu_fwd_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale, - activation_lu_out_scale_inv, fwd_dtype) + act_lu_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale, + activation_lu_out_scale_inv, fwd_dtype, activation_type) casted_activation_lu_out = with_sharding_constraint_by_logical_axes( casted_activation_lu_out, dot_2_input_axes) @@ -370,15 +317,18 @@ def _fused_layernorm_fp8_mlp_fwd_rule( get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) if use_bias: - bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape - dot_2_output += jnp.reshape(bias_2, bias_2_shape) + bias_2_shape = bias_2.shape + bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape + dot_2_output += jnp.reshape(bias_2, bias_2_new_shape) + else: + bias_2_shape = None dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, - x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape, maybe_fp32_to_fm32) + x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32) return dot_2_output, ctx @@ -403,8 +353,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule( updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape, maybe_fp32_to_fm32 = ctx - is_gated = len(activation_type) > 1 - gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1) grad_amax = amax[gemm2_grad_idx, 0:1] @@ -413,7 +361,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule( # Since the sharding of outputs should be the same as dot_1's input grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) - if use_bias: casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \ dbias_cast_transpose(grad, grad_amax, grad_scale, @@ -427,7 +374,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule( grad_scale_inv, bwd_dtype, static_axis_boundary=-1, transpose_axis_boundary=-1) - dbias_2 = jnp.empty(bias_2_shape, grad.dtype) 
+ dbias_2 = None casted_activation_lu_out_t = transpose(casted_activation_lu_out, static_axis_boundary=-1, @@ -453,11 +400,9 @@ def _fused_layernorm_fp8_mlp_bwd_rule( dactivation_lu_scale = scale[gemm1_grad_idx] dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx] - dactivation_lu_cast_transpose = activation_fp8_dict[activation_type]["bwd"] - dactivation_lu = activation_dict[activation_type]["bwd"](dgrad_2, dot_1_output) - - if is_gated: + if len(activation_type) > 1: # if gated if use_bias: + dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type) casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \ dbias_cast_transpose( dactivation_lu, @@ -470,19 +415,20 @@ def _fused_layernorm_fp8_mlp_bwd_rule( dbias_1 = jnp.reshape(dbias_1, bias_1_shape) else: casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \ - dactivation_lu_cast_transpose( + dgated_act_lu_cast_transpose( dgrad_2, dot_1_output, dactivation_lu_amax, dactivation_lu_scale, dactivation_lu_scale_inv, bwd_dtype, - static_axis_boundary=-1) - dbias_1 = jnp.empty(bias_1_shape, bwd_dtype) + static_axis_boundary=-1, + activation_type=activation_type) + dbias_1 = None else: if use_bias: - casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \ - dactivation_lu_cast_transpose( + casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax=\ + dact_lu_dbias_cast_transpose( dgrad_2, dot_1_output, dactivation_lu_amax, @@ -490,9 +436,11 @@ def _fused_layernorm_fp8_mlp_bwd_rule( dactivation_lu_scale_inv, bwd_dtype, static_axis_boundary=-1, - transpose_axis_boundary=-1) + transpose_axis_boundary=-2, + activation_type=activation_type) dbias_1 = jnp.reshape(dbias_1, bias_1_shape) else: + dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type) casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \ cast_transpose( dactivation_lu, @@ -501,28 +449,20 @@ def _fused_layernorm_fp8_mlp_bwd_rule( dactivation_lu_scale_inv, bwd_dtype, static_axis_boundary=-1, - transpose_axis_boundary=-1) - dbias_1 = jnp.empty(bias_1_shape, bwd_dtype) + transpose_axis_boundary=-2) + dbias_1 = None ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1) # (hidden, batch...) x (hidden, batch...) 
gemm1_x_scale_inv = scale_inv[gemm1_x_idx] - xt_batch_dims_2 = xt_batch_dims if not is_gated \ - else tuple(i + 1 for i in xt_batch_dims) + xt_batch_dims_2 = tuple(i + 1 for i in xt_batch_dims) wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv, dactivation_lu_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims_2), get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) - # Expand act axis to match the shape with the given kernel_1 - if not is_gated: - wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2) - # (batch..., hidden_out) x (hidden_in, hidden_out) - if is_gated: - x_contracting_dims = ((min(x_contracting_dims),) + tuple(i + 1 for i in x_contracting_dims), - (1, 2)) - else: - x_contracting_dims = (x_contracting_dims, (1,)) + x_contracting_dims = ((min(x_contracting_dims),) + tuple( + i + 1 for i in x_contracting_dims), (1,2)) kernel_1_scale_inv = scale_inv[gemm1_kernel_idx] dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, dactivation_lu_scale_inv, kernel_1_scale_inv, grad.dtype, x_contracting_dims, From 8e75d91368b1ec15f2b8ebd7153de11712e8f522 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 7 May 2024 10:44:29 -0700 Subject: [PATCH 050/244] [PyTorch] Update FP8 recipe test to handle recipe changes (#834) Update FP8 recipe test to handle recipe changes Signed-off-by: Tim Moon Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_recipe.py | 128 +++++++++++++++++------------------ 1 file changed, 63 insertions(+), 65 deletions(-) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 92c7f26f59..2de849fdf2 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -29,7 +29,7 @@ def setup_class(cls) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) - @pytest.mark.parametrize("amax_history_len", [1, 31, 1024]) + @pytest.mark.parametrize("amax_history_len", [31, 1024]) @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"]) @pytest.mark.parametrize("is_first_microbatch", [None, True, False]) def test_amax_and_scale_update( @@ -51,7 +51,10 @@ def test_amax_and_scale_update( ) with te.fp8_autocast(enabled=True, fp8_recipe=recipe): module = te.Linear(16, 16) - y = module(torch.zeros([16, 16], device="cuda")) + y = module( + torch.randn([16, 16], device="cuda"), + is_first_microbatch=True, + ) y.backward(torch.zeros_like(y)) # Get amax history and scaling factors @@ -67,101 +70,96 @@ def test_amax_and_scale_update( # Tweak amax history and scaling factors amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5) - if amax_history_len > 1: - amax_history_forward[1, 0].fill_(3) + amax_history_forward[0, :].zero_() scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5) scale_inv_forward.copy_(torch.reciprocal(scale_forward)) - amax_history_backward.copy_(2 * torch.rand_like(amax_history_backward) + 0.5) - scale_backward.copy_(2 * torch.rand_like(scale_backward) + 0.5) - scale_inv_backward.copy_(torch.reciprocal(scale_backward)) + amax_history_backward[0, :].zero_() # Expected amax history after update - ref_amax_history_forward = torch.roll(amax_history_forward, -1, dims=0) - ref_amax_history_forward[0].zero_() - ref_amax_history_backward = torch.roll(amax_history_backward, -1, dims=0) - ref_amax_history_backward[0].zero_() + # Note: amax history is only updated when amax is updated + update_weight_amax = is_first_microbatch is None or is_first_microbatch + ref_amax_history_forward = amax_history_forward.clone() + 
ref_amax_history_forward[:, 0].copy_(torch.roll(amax_history_forward[:, 0], -1)) + if update_weight_amax: + ref_amax_history_forward[:, 1].copy_(torch.roll(amax_history_forward[:, 1], -1)) + ref_amax_history_forward[0, :].zero_() + ref_amax_history_backward = amax_history_backward.clone() + ref_amax_history_backward[:, 0].copy_(torch.roll(amax_history_backward[:, 0], -1)) + ref_amax_history_backward[0, :].zero_() # Expected scale and scale inverse if amax_compute_algo == "max": ref_amax_forward = amax_history_forward.max(dim=0).values ref_amax_backward = amax_history_backward.max(dim=0).values elif amax_compute_algo == "most_recent": - ref_amax_forward = amax_history_forward[0] - ref_amax_backward = amax_history_backward[0] + ref_amax_forward = amax_history_forward[-1] + ref_amax_backward = amax_history_backward[-1] else: raise ValueError(f"{amax_compute_algo=} is not supported") ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin) ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin) ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) - update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch - if not update_weight_scale_inv: + update_weight_amax = is_first_microbatch is None or is_first_microbatch + if not update_weight_amax: ref_scale_inv_forward[1].copy_(scale_inv_forward[1]) ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) - # Make sure we are not trivially passing tests - if amax_history_len > 1: - with pytest.raises(AssertionError): - torch.testing.assert_close( - amax_history_forward[1:], - ref_amax_history_forward[1:], - ) - with pytest.raises(AssertionError): - torch.testing.assert_close( - scale_forward, - ref_scale_forward, - ) - with pytest.raises(AssertionError): - torch.testing.assert_close( - scale_inv_forward, - ref_scale_inv_forward, - ) - if amax_history_len > 1: - with pytest.raises(AssertionError): - torch.testing.assert_close( - fp8_meta[backward_key].amax_history[1:], - ref_amax_history_backward[1:], - ) - with pytest.raises(AssertionError): - torch.testing.assert_close( - fp8_meta[backward_key].scale, - ref_scale_backward, - ) - with pytest.raises(AssertionError): - torch.testing.assert_close( - fp8_meta[backward_key].scale_inv, - ref_scale_inv_backward, - ) - - # Perform forward and backward pass to update fp8_meta + # Perform forward, backward, and optimizer steps to update fp8_meta with te.fp8_autocast(enabled=True, fp8_recipe=recipe): - x = torch.zeros([16, 16], device="cuda") + x = torch.randn([16, 16], device="cuda") y = module(x, is_first_microbatch=is_first_microbatch) - y.backward(torch.zeros_like(y)) + y.backward(torch.randn_like(y)) - # Check that fp8_meta matches expected values + # Check that amax history matches expected values torch.testing.assert_close( - fp8_meta[forward_key].amax_history[1:], - ref_amax_history_forward[1:], + amax_history_forward[:-1], + ref_amax_history_forward[:-1], ) torch.testing.assert_close( - fp8_meta[forward_key].scale, - ref_scale_forward, + amax_history_backward[:-1], + ref_amax_history_backward[:-1], ) + + # Expected scale and scale inverse + if amax_compute_algo == "max": + ref_amax_forward = amax_history_forward.max(dim=0).values + ref_amax_backward = amax_history_backward.max(dim=0).values + elif amax_compute_algo == "most_recent": + ref_amax_forward = amax_history_forward[-1] + ref_amax_backward = amax_history_backward[-1] + else: + raise ValueError(f"{amax_compute_algo=} is not supported") + ref_scale_forward = 
(fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin) + ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin) + ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) + ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) + + # Check that scale and scale inverse match expected values + # Note: scale and scale inverse are only updated when amax is updated torch.testing.assert_close( - fp8_meta[forward_key].scale_inv, - ref_scale_inv_forward, + scale_forward[0], + ref_scale_forward[0], ) torch.testing.assert_close( - fp8_meta[backward_key].amax_history[1:], - ref_amax_history_backward[1:], + scale_inv_forward[0], + ref_scale_inv_forward[0], ) + if update_weight_amax: + torch.testing.assert_close( + scale_forward[1], + ref_scale_forward[1], + ) + torch.testing.assert_close( + scale_inv_forward[1], + ref_scale_inv_forward[1], + ) torch.testing.assert_close( - fp8_meta[backward_key].scale, - ref_scale_backward, + scale_backward[0], + ref_scale_backward[0], ) torch.testing.assert_close( - fp8_meta[backward_key].scale_inv, - ref_scale_inv_backward, + scale_inv_backward[0], + ref_scale_inv_backward[0], ) @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"]) From 4af821b31c8d1612c648a97da4c420a04a42c948 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 9 May 2024 09:37:02 -0700 Subject: [PATCH 051/244] Update FA version (#838) Bump FA version to 2.5.8 Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- setup.py | 2 +- transformer_engine/pytorch/attention.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d442aec872..769d62a25b 100644 --- a/setup.py +++ b/setup.py @@ -265,7 +265,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None: # Framework-specific requirements if "pytorch" in frameworks(): - add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"]) + add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"]) add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"]) if "jax" in frameworks(): if not found_pybind11(): diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 2f5a6aa671..af1797ee54 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -69,6 +69,7 @@ _flash_attn_version = packaging.version.Version(version("flash-attn")) _flash_attn_version_required = packaging.version.Version("2.0.6") +_flash_attn_max_version = packaging.version.Version("2.5.8") _flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1") _flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3") _flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4") @@ -1931,6 +1932,9 @@ def __init__( assert ( _flash_attn_version >= _flash_attn_version_required ), f"FlashAttention minimum version {_flash_attn_version_required} is required." + assert ( + _flash_attn_version <= _flash_attn_max_version + ), f"FlashAttention maximum version {_flash_attn_max_version} is supported." 
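Returning briefly to the recipe test updated in the previous patch: the expected values it constructs reduce to a single update rule. The sketch below is a minimal plain-PyTorch restatement of that rule, assuming a standalone amax-history tensor (compute_scale is a made-up helper, not TE code); 448 and 57344 are the E4M3 and E5M2 maxima used by the HYBRID format.

import torch

def compute_scale(amax_history, fp8_max, margin=0, amax_compute_algo="max"):
    # Reduce the amax window, then scale = fp8_max / amax / 2**margin.
    if amax_compute_algo == "max":
        amax = amax_history.max(dim=0).values
    else:  # "most_recent"; which row holds the newest amax is recipe-defined
        amax = amax_history[-1]
    scale = (fp8_max / amax) / (2 ** margin)
    return scale, torch.reciprocal(scale)   # (scale, scale_inv)

fwd_history = torch.rand(31, 2) + 0.5       # (amax_history_len, num_fp8_tensors)
bwd_history = torch.rand(31, 1) + 0.5
scale_fwd, scale_inv_fwd = compute_scale(fwd_history, fp8_max=448.0)     # E4M3
scale_bwd, scale_inv_bwd = compute_scale(bwd_history, fp8_max=57344.0)   # E5M2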
self.norm_factor = norm_factor self.attention_dropout_ctx = attention_dropout_ctx From 9607e9565606d8605b0c385e3dc0929ca8587826 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Date: Thu, 9 May 2024 10:59:20 -0700 Subject: [PATCH 052/244] [JAX] Fixes for the issue with ActLuPrimitive in PAXML (#837) * fixes for ActLuPrimitive in PAXML * changed indices for arg_infos in sharding func in dbias_cast_transpose primitive --------- Signed-off-by: Phuong Nguyen Signed-off-by: Pawel Gadzinski --- transformer_engine/jax/cpp_extensions.py | 27 +++++++++++++++--------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index 8f4ed045d0..1cdfb6f930 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -2655,16 +2655,20 @@ def partition(act_enum, mesh, arg_infos, result_infos): """ act_lu partitioning """ - del result_infos, act_enum + del result_infos x_spec = get_padded_spec(arg_infos[0]) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1])) - impl = ActLuPrimitive.impl - return mesh, impl, out_sharding, arg_shardings + + def sharded_impl(x): + return ActLuPrimitive.impl(x, act_enum=act_enum) + + return mesh, sharded_impl, out_sharding, arg_shardings register_primitive(ActLuPrimitive) + def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray: """ act_lu wrapper @@ -2779,12 +2783,15 @@ def partition(act_enum, mesh, arg_infos, result_infos): """ dact_lu partition """ - del result_infos, act_enum + del result_infos dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = dx_sharding - impl = DActLuPrimitive.impl - return mesh, impl, out_shardings, arg_shardings + + def sharded_impl(dz, x): + return DActLuPrimitive.impl(dz, x, act_enum=act_enum) + + return mesh, sharded_impl, out_shardings, arg_shardings register_primitive(DActLuPrimitive) @@ -4304,20 +4311,20 @@ def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos): del out_dtype, result_infos - x_spec = get_padded_spec(arg_infos[1]) + x_spec = get_padded_spec(arg_infos[0]) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) dbias_shaprding = NamedSharding( mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) @staticmethod def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos): del result_infos - x_spec = get_padded_spec(arg_infos[1]) + x_spec = get_padded_spec(arg_infos[0]) casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) @@ -4325,7 +4332,7 @@ def partition(out_dtype, 
static_axis_boundary, transpose_axis_boundary, mesh, ar dbias_shaprding = NamedSharding( mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) - amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_shaprding, amax_sharding) From e0f3157a3132cc3861930071f26de188530dd3bc Mon Sep 17 00:00:00 2001 From: root Date: Thu, 21 Mar 2024 22:54:35 +0000 Subject: [PATCH 053/244] Not completely done gemma Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 172 +++++++++++++++++++++++++ docs/examples/te_gemma/utils.py | 194 +++++++++++++++++++++++++++++ 2 files changed, 366 insertions(+) create mode 100755 docs/examples/te_gemma/te_gemma.py create mode 100755 docs/examples/te_gemma/utils.py diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py new file mode 100755 index 0000000000..c8551570d4 --- /dev/null +++ b/docs/examples/te_gemma/te_gemma.py @@ -0,0 +1,172 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import re +import gc +from contextlib import contextmanager + +import torch +from torch import nn + +import transformer_engine as te +from transformer_engine.pytorch.attention import RotaryPositionEmbedding +from transformer_engine.pytorch.fp8 import fp8_model_init + +import transformers +from transformers.models.gemma.modeling_gemma import GemmaModel, GemmaForCausalLM, GemmaRMSNorm, GemmaConfig +from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model +from transformers.utils import WEIGHTS_INDEX_NAME +from transformers.utils.hub import get_checkpoint_shard_files + +@contextmanager +def replace_decoder(te_decoder_cls): + """ + Replace `GemmaDecoderLayer` with custom `TEGemmaDecoderLayer`. + """ + original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls + try: + yield + finally: + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls + + +class TEGemmaDecoderLayer(te.pytorch.TransformerLayer): + """ + Wrapper class over TE's `TransformerLayer`. This makes the wrapper very + similar to HF's `GemmaDecoderLayer` and easier to replace it in the code. + + Args: + config: GemmaConfig + args: positional args (for compatibility with `GemmaDecoderLayer`) + kwargs: keyword args (for compatibility with `GemmaDecoderLayer`) + """ + def __init__(self, config, *args, **kwargs): + super().__init__( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + bias=False, + layernorm_epsilon=config.rms_norm_eps, + hidden_dropout=0, + attention_dropout=0, + fuse_qkv_params=False, + normalization="RMSNorm", + activation="geglu", + attn_input_format="bshd", + num_gqa_groups=16, + kv_channels=1000000000000000 + ) + te_rope = RotaryPositionEmbedding(256) + self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() + + def forward(self, + hidden_states, + *args, + attention_mask, + **kwargs): + """ + Custom forward to make sure we only pass relevant arguments to the + forward pass of the `TransformerLayer`. 
Also, make sure the output + format matches the output of the HF's `GemmaDecoderLayer`. + """ + return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),) + + +class TEGemmaForCausalLM: + """ + Causal LM created with `GemmaModel`. The underlying `GemmaDecoderLayer` + class is monkey-patched with `TEGemmaDecoderLayer` class before + initializing the causal LM with `GemmaForCausalLM`. + + Args: + config: GemmaConfig + """ + + def __new__(cls, config: GemmaConfig): + with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): + gemma_for_causal_lm = GemmaForCausalLM(config) + return gemma_for_causal_lm + + @classmethod + def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **kwargs): + """ + Custom method adapted from `from_pretrained` method in HuggingFace + Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 + """ + vanilla_model = cls(config).to(kwargs['torch_dtype']) + is_local = os.path.isdir(pretrained_model_name_or_path) + subfolder = "" + variant = None + if os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant)) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant) + ) + is_sharded = True + print(archive_file) + + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + archive_file, + ) + + # If the checkpoint is not sharded, it's a trivial sharding case + if not is_sharded: + assert not isinstance(resolved_archive_file, list) + resolved_archive_file = [resolved_archive_file] + + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + replaces_params = replace_params(state_dict, vanilla_model.state_dict()) + #_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") + + # Force mem release. Taken from huggingface code + del state_dict + gc.collect() + + return vanilla_model + +def replace_params(hf_state_dict, te_state_dict): + # collect all layer prefixes to update + all_layer_prefixes = set() + for param_key in hf_state_dict.keys(): + layer_prefix_pat = 'model.layers.\d+.' 
+ m = re.match(layer_prefix_pat, param_key) + if m is not None: + all_layer_prefixes.add(m.group()) + + [(print(x, " ", te_state_dict[x].shape if type(te_state_dict[x]) is torch.Tensor else " ") if x.startswith(list(all_layer_prefixes)[0]) else "") for x in te_state_dict.keys()] + + for layer_prefix in all_layer_prefixes: + # When loading weights into models with less number of layers, skip the + # copy if the corresponding layer doesn't exist in HF model + if layer_prefix + 'input_layernorm.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:] + + if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:] + + if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:] + + if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:] + + if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:] + + if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:] + + if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict and 'mlp.up_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:] = torch.cat((hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data[:], hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data[:]), dim=0) + + if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:] + + return all_layer_prefixes \ No newline at end of file diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py new file mode 100755 index 0000000000..d29b094821 --- /dev/null +++ b/docs/examples/te_gemma/utils.py @@ -0,0 +1,194 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
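# [Editor's note - illustrative sketch, not part of the files in this patch]
# replace_params() in te_gemma.py above copies Hugging Face Gemma weights into TE
# parameter names by string prefix. A quick shape check like the one below can catch
# a wrong HF -> TE pairing early; the helper and the pair list are assumptions that
# mirror the mapping above, not TE or Transformers API.
def check_copied_shapes(te_state_dict, hf_state_dict, layer_prefix):
    pairs = [
        ("self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight"),
        ("self_attention.proj.weight", "self_attn.o_proj.weight"),
        ("layernorm_mlp.fc2_weight", "mlp.down_proj.weight"),
    ]
    for te_name, hf_name in pairs:
        te_w = te_state_dict[layer_prefix + te_name]
        hf_w = hf_state_dict[layer_prefix + hf_name]
        assert te_w.shape == hf_w.shape, f"{te_name}: {tuple(te_w.shape)} vs {tuple(hf_w.shape)}"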
+ +import time +import sys +import IPython + +import torch +from torch.optim import AdamW +from torch.utils.data import DataLoader + +from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup, AutoConfig +from transformers import DataCollatorForLanguageModeling +from datasets import load_dataset +from accelerate import Accelerator +from accelerate.utils.dataclasses import FP8RecipeKwargs + +class HyperParameters: + def __init__(self): + self.mixed_precision = "bf16" + #self.model_name = "" # <== Add model weight location here + self.dataset_name = "timdettmers/openassistant-guanaco" + self.dataset_text_field = "text" + self.learning_rate = 1.41e-5 + self.batch_size = 16 + self.max_seq_length = 256 + self.gradient_accumulation_steps = 1 + self.num_warmup_steps=5 + self.num_training_steps=10 + + +hyperparams = HyperParameters() + +def get_dataloaders(accelerator:Accelerator, hyperparams): + dataset = load_dataset(hyperparams.dataset_name, split="train") + tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + def tokenize(element): + outputs = tokenizer( + element["text"], + truncation=True, + padding=False, + max_length=hyperparams.max_seq_length, + return_overflowing_tokens=False, + return_length=False + ) + return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + + with accelerator.main_process_first(): + dataset = dataset.map( + tokenize, + batched=True, + remove_columns=dataset.column_names + ) + + # Simply pad to the multiple of 16 for both FP8 and BF16 precision + pad_to_multiple_of = 16 + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + pad_to_multiple_of=pad_to_multiple_of, + ) + + dataloader_params = { + "batch_size": hyperparams.batch_size, + "collate_fn": data_collator, + "drop_last": True, + } + train_dataloader = DataLoader(dataset, **dataloader_params) + return train_dataloader + +def init_baseline_model(hyperparams): + # Init the model + config = AutoConfig.from_pretrained(hyperparams.model_name) + # make sure to use flash_attention to do iso comparison with TEGemmaModel + config._attn_implementation = "flash_attention_2" + model = AutoModelForCausalLM.from_pretrained( + hyperparams.model_name, + config=config, + torch_dtype=torch.bfloat16, + ) + # Needed for the cases when using TEGemmaForCausalLM. 
So adding here for 1:1 comparison + model.config.use_cache=False + + return model + +def init_te_gemma_model(hyperparams): + # Init the model + from te_gemma import TEGemmaForCausalLM + config = AutoConfig.from_pretrained(hyperparams.model_name) + config._attn_implementation = "flash_attention_2" + model = TEGemmaForCausalLM.from_pretrained_local( + hyperparams.model_name, + config=config, + torch_dtype=torch.bfloat16, + ) + # Needed for the cases when using TEGemmaForCausalLM + model.config.use_cache=False + + return model + +def wrap_with_accelerator(model, hyperparams): + # Create FP8 kwarg handler if required + fp8_kwarg_handler = [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None + + # Init HF accelerator that's used for training + accelerator = Accelerator( + log_with="wandb", + gradient_accumulation_steps=hyperparams.gradient_accumulation_steps, + mixed_precision=hyperparams.mixed_precision, + kwargs_handlers=fp8_kwarg_handler + ) + #accelerator.print(f'State: {accelerator.state}') + train_dataloader = get_dataloaders(accelerator, hyperparams) + + # Wrap model, optimizer/scheduler, dataloaders in accelerate + optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=True) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=100, + num_training_steps=hyperparams.num_training_steps, + ) + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + return accelerator, model, optimizer, train_dataloader, lr_scheduler + +def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler): + model.train() + total_loss = 0 + optimizer.zero_grad() + train_dataloader = enumerate(train_dataloader) + + # Warmup iters + for _ in range(hyperparams.num_warmup_steps): + step, batch = next(train_dataloader) + with accelerator.accumulate(model): + outputs = model(**batch) + loss = outputs.loss + total_loss += loss.detach().float() + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Get the timers ready + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + + start.record() + # Training iters + for _ in range(hyperparams.num_training_steps): + step, batch = next(train_dataloader) + with accelerator.accumulate(model): + outputs = model(**batch) + loss = outputs.loss + total_loss += loss.detach().float() + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + torch.cuda.synchronize() + end.record() + accelerator.end_training() + + print(f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step: {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} milliseconds") + +def restart_jupyter_notebook(): + # Try restarting the Jupyter kernel + IPython.Application.instance().kernel.do_shutdown(True) + + # Check whether the device memory has been flushed + if torch.cuda.memory_allocated() != 0: + import warnings + warnings.warn("The device memory hasn't been flushed, trying with a second method!") + + # Try restarting the Jupyter kernel another way + # Restart the kernel + from IPython.core.display import HTML + HTML("") + + if torch.cuda.memory_allocated() != 0: + print("The device memory hasn't been flushed, try manually restarting the Jupyter kernel!") + + # Suppress the warnings + if not sys.warnoptions: + import warnings + 
warnings.simplefilter("ignore") + torch.set_warn_always(False) From 746deba9a7e10cc06ea44715c06d21c6e5403968 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 21 Mar 2024 23:36:21 +0000 Subject: [PATCH 054/244] something Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index c8551570d4..06eedb33c0 100755 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -56,7 +56,7 @@ def __init__(self, config, *args, **kwargs): normalization="RMSNorm", activation="geglu", attn_input_format="bshd", - num_gqa_groups=16, + num_gqa_groups=2, kv_channels=1000000000000000 ) te_rope = RotaryPositionEmbedding(256) @@ -140,8 +140,12 @@ def replace_params(hf_state_dict, te_state_dict): if m is not None: all_layer_prefixes.add(m.group()) - [(print(x, " ", te_state_dict[x].shape if type(te_state_dict[x]) is torch.Tensor else " ") if x.startswith(list(all_layer_prefixes)[0]) else "") for x in te_state_dict.keys()] + print('-' * 50) + [(print(x, " ", te_state_dict[x].shape if type(te_state_dict[x]) is torch.Tensor else " ") if x.startswith(list(all_layer_prefixes)[1]) else "") for x in te_state_dict.keys()] + print('-' * 50) + [(print(x, " ", hf_state_dict[x].shape if type(hf_state_dict[x]) is torch.Tensor else " ") if x.startswith(list(all_layer_prefixes)[1]) else "") for x in hf_state_dict.keys()] + for layer_prefix in all_layer_prefixes: # When loading weights into models with less number of layers, skip the # copy if the corresponding layer doesn't exist in HF model From e582840e448b1a065367d787a2bb2c95e1a1fc2b Mon Sep 17 00:00:00 2001 From: root Date: Fri, 22 Mar 2024 22:15:37 +0000 Subject: [PATCH 055/244] Version which works Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 06eedb33c0..616be8fbec 100755 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -56,8 +56,8 @@ def __init__(self, config, *args, **kwargs): normalization="RMSNorm", activation="geglu", attn_input_format="bshd", - num_gqa_groups=2, - kv_channels=1000000000000000 + num_gqa_groups=config.num_key_value_heads, + attention_hidden_size=4096 ) te_rope = RotaryPositionEmbedding(256) self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() @@ -140,12 +140,6 @@ def replace_params(hf_state_dict, te_state_dict): if m is not None: all_layer_prefixes.add(m.group()) - print('-' * 50) - [(print(x, " ", te_state_dict[x].shape if type(te_state_dict[x]) is torch.Tensor else " ") if x.startswith(list(all_layer_prefixes)[1]) else "") for x in te_state_dict.keys()] - - print('-' * 50) - [(print(x, " ", hf_state_dict[x].shape if type(hf_state_dict[x]) is torch.Tensor else " ") if x.startswith(list(all_layer_prefixes)[1]) else "") for x in hf_state_dict.keys()] - for layer_prefix in all_layer_prefixes: # When loading weights into models with less number of layers, skip the # copy if the corresponding layer doesn't exist in HF model From 59eb22d0f3379820b0e34b7def4c66649d9f2e73 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 22 Mar 2024 22:41:55 +0000 Subject: [PATCH 056/244] Fixed kv_channels Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 43 ++++++++++++----------- transformer_engine/pytorch/transformer.py 
| 12 +++---- 2 files changed, 26 insertions(+), 29 deletions(-) mode change 100644 => 100755 transformer_engine/pytorch/attention.py mode change 100644 => 100755 transformer_engine/pytorch/transformer.py diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py old mode 100644 new mode 100755 index af1797ee54..a1f97a4fe0 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3125,8 +3125,8 @@ class DotProductAttention(torch.nn.Module): ---------- num_attention_heads : int number of attention heads in the transformer layer. - kv_channels : int - number of key-value channels. + channels : int + number of key-query-value channels. num_gqa_groups : Optional[int] = None number of GQA groups in the transformer layer. Grouped Query Attention is described in @@ -3196,7 +3196,7 @@ class DotProductAttention(torch.nn.Module): def __init__( self, num_attention_heads: int, - kv_channels: int, + channels: int, num_gqa_groups: Optional[int] = None, attention_dropout: float = 0.0, qkv_format: str = "sbhd", @@ -3230,7 +3230,8 @@ def __init__( self.cp_global_ranks = cp_global_ranks self.cp_stream = cp_stream - self.hidden_size_per_attention_head = kv_channels + + self.hidden_size_per_attention_head = channels // num_attention_heads self.num_gqa_groups = ( num_attention_heads if num_gqa_groups is None else num_gqa_groups ) @@ -3380,9 +3381,9 @@ def forward( Input tensors :attr:`query_layer`, :attr:`key_layer`, and :attr:`value_layer` must each be of shape (:attr:`sequence_length`, :attr:`batch_size`, - :attr:`num_attention_heads`, :attr:`kv_channels`). Output of shape + :attr:`num_attention_heads`, :attr:`channels`). Output of shape (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads` - * :attr:`kv_channels`) is returned. + * :attr:`channels`) is returned. .. note:: @@ -3903,8 +3904,8 @@ class MultiheadAttention(torch.nn.Module): size of each input sample. num_attention_heads : int number of attention heads in the transformer layer. - kv_channels: int, default = `None` - number of key-value channels. defaults to + attention_hidden_size: int, default = `None` + number of key-query-value channels. defaults to :attr:`hidden_size` / :attr:`num_attention_heads` if `None`. attention_dropout: float, default = 0.1 dropout probability for the dropout op during multi-head attention. 
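# [Editor's note - illustrative example, not part of the diff]
# This patch switches the attention modules from a per-head size ("kv_channels") to
# the total projection width ("channels" in DotProductAttention,
# "attention_hidden_size" in MultiheadAttention); the per-head size is now derived
# internally. Worked numbers below are illustrative, not taken from a real config.
num_attention_heads = 16
attention_hidden_size = 4096                     # total Q/K/V projection width
hidden_size_per_attention_head = attention_hidden_size // num_attention_heads
assert hidden_size_per_attention_head == 256     # what the module now computes itself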
@@ -4027,7 +4028,7 @@ def __init__( self, hidden_size: int, num_attention_heads: int, - kv_channels: Optional[int] = None, + attention_hidden_size: Optional[int] = None, attention_dropout: float = 0.1, layernorm_epsilon: float = 1e-5, init_method: Optional[Callable] = None, @@ -4076,7 +4077,7 @@ def __init__( self.num_attention_heads = num_attention_heads self.return_bias = return_bias - kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads) + self.attention_hidden_size = attention_hidden_size if attention_hidden_size else (hidden_size // num_attention_heads) if init_method is None: init_method = get_default_init_method() @@ -4095,7 +4096,7 @@ def __init__( self.tp_size = tp_size self.sequence_parallel = (tp_size > 1) and sequence_parallel - self.hidden_size_per_attention_head = kv_channels + self.hidden_size_per_attention_head = attention_hidden_size // num_attention_heads self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size) self.num_gqa_groups = ( num_attention_heads if num_gqa_groups is None else num_gqa_groups @@ -4123,14 +4124,14 @@ def __init__( parameters_split = None if not fuse_qkv_params: parameters_split = collections.OrderedDict([ - ("query", hidden_size), - ("key", self.hidden_size_kv), - ("value", self.hidden_size_kv), + ("query", attention_hidden_size), + ("key", attention_hidden_size), + ("value", attention_hidden_size), ]) if self.input_layernorm: self.layernorm_qkv = LayerNormLinear( hidden_size, - hidden_size + 2 * self.hidden_size_kv, + 3 * attention_hidden_size, eps=layernorm_epsilon, init_method=init_method, bias=bias, @@ -4150,7 +4151,7 @@ def __init__( else: self.qkv = Linear( hidden_size, - hidden_size + 2 * self.hidden_size_kv, + 3 * attention_hidden_size, init_method=init_method, bias=bias, return_bias=False, @@ -4162,7 +4163,7 @@ def __init__( if self.input_layernorm: self.layernorm_query = LayerNormLinear( hidden_size, - hidden_size, + attention_hidden_size, eps=layernorm_epsilon, init_method=init_method, bias=bias, @@ -4182,7 +4183,7 @@ def __init__( else: self.query_layer = Linear( hidden_size, - hidden_size, + attention_hidden_size, init_method=init_method, bias=bias, return_bias=False, @@ -4191,7 +4192,7 @@ def __init__( ) self.key_value = Linear( hidden_size, - 2 * self.hidden_size_kv, + 2 * attention_hidden_size, init_method=init_method, bias=bias, return_bias=False, @@ -4203,7 +4204,7 @@ def __init__( # Attention. self.core_attention = DotProductAttention( num_attention_heads, - kv_channels, + attention_hidden_size, num_gqa_groups=self.num_gqa_groups, attention_dropout=attention_dropout, qkv_format=self.qkv_format, @@ -4217,7 +4218,7 @@ def __init__( # Linear self.proj = Linear( - hidden_size, + attention_hidden_size, hidden_size, init_method=output_layer_init_method, bias=bias, diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py old mode 100644 new mode 100755 index 5b6fc1e5c3..b59c1ce346 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -128,8 +128,8 @@ class TransformerLayer(torch.nn.Module): if set to `decoder`, an additional cross-attn block is added after self-attn. This can be used for structures like `T5` Transformer in conjunction with the `encoder` option. - kv_channels: int, default = `None` - number of key-value channels. defaults to + attention_hidden_size: int, default = `None` + number of channels of queue/key/value. defaults to :attr:`hidden_size` / :attr:`num_attention_heads` if `None`. 
self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'}, default = `causal` @@ -236,7 +236,7 @@ def __init__( init_method: Optional[Callable] = None, output_layer_init_method: Optional[Callable] = None, layer_number: Optional[int] = None, - kv_channels: Optional[int] = None, + attention_hidden_size: Optional[int] = None, self_attn_mask_type: str = "causal", window_size: Optional[Tuple[int, int]] = None, tp_group: Optional[dist_group_type] = None, @@ -315,10 +315,6 @@ def __init__( if not fuse_qkv_params: qkv_weight_interleaved = False - self.kv_channels = ( - kv_channels if kv_channels else (hidden_size // num_attention_heads) - ) - if init_method is None: init_method = get_default_init_method() if output_layer_init_method is None: @@ -335,7 +331,7 @@ def __init__( attention_args = ( hidden_size, num_attention_heads, - self.kv_channels, + attention_hidden_size, attention_dropout, layernorm_epsilon, init_method, From 7a7fe6f5b0f3f4651af3f1bdb38d199c2353a345 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 27 Mar 2024 20:51:29 +0000 Subject: [PATCH 057/244] Fixed potential bug with fc1 loading Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 38 +++++++++++++++++++----------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 616be8fbec..54a81c05b9 100755 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -7,6 +7,11 @@ import gc from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +from transformers.generation import * +from transformers.generation.utils import * + import torch from torch import nn @@ -96,7 +101,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k Custom method adapted from `from_pretrained` method in HuggingFace Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 """ - vanilla_model = cls(config).to(kwargs['torch_dtype']) + vanilla_model = cls(config) is_local = os.path.isdir(pretrained_model_name_or_path) subfolder = "" variant = None @@ -108,7 +113,6 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant) ) is_sharded = True - print(archive_file) resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( pretrained_model_name_or_path, @@ -122,8 +126,8 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file) - replaces_params = replace_params(state_dict, vanilla_model.state_dict()) - #_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") + replace_params(state_dict, vanilla_model.state_dict()) + _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") # Force mem release. 
Taken from huggingface code del state_dict @@ -131,6 +135,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k return vanilla_model + def replace_params(hf_state_dict, te_state_dict): # collect all layer prefixes to update all_layer_prefixes = set() @@ -139,32 +144,37 @@ def replace_params(hf_state_dict, te_state_dict): m = re.match(layer_prefix_pat, param_key) if m is not None: all_layer_prefixes.add(m.group()) + + GATE_PROJ_SIZE=24576 for layer_prefix in all_layer_prefixes: # When loading weights into models with less number of layers, skip the # copy if the corresponding layer doesn't exist in HF model if layer_prefix + 'input_layernorm.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:] - + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].copy_(1 + hf_state_dict[layer_prefix + 'input_layernorm.weight']) + if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:] + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.q_proj.weight']) if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:] + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.k_proj.weight']) if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:] + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.v_proj.weight']) if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:] + te_state_dict[layer_prefix + 'self_attention.proj.weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.o_proj.weight']) if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:] + te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:] + 1 - if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict and 'mlp.up_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:] = torch.cat((hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data[:], hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data[:]), dim=0) + if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:GATE_PROJ_SIZE] = hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data[:] + + if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[GATE_PROJ_SIZE:] = hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data[:] if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict: - 
te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:] + te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].copy_(hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:]) return all_layer_prefixes \ No newline at end of file From 64718a16468391557c582a488ce679a9b24e32f6 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 4 Apr 2024 15:38:22 -0700 Subject: [PATCH 058/244] Gemma generation Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 75 +++- .../tutorial_generation_gemma_with_te.ipynb | 372 ++++++++++++++++++ docs/examples/te_gemma/utils.py | 4 +- 3 files changed, 445 insertions(+), 6 deletions(-) create mode 100755 docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 54a81c05b9..113534828c 100755 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -25,6 +25,7 @@ from transformers.utils import WEIGHTS_INDEX_NAME from transformers.utils.hub import get_checkpoint_shard_files + @contextmanager def replace_decoder(te_decoder_cls): """ @@ -48,7 +49,7 @@ class TEGemmaDecoderLayer(te.pytorch.TransformerLayer): args: positional args (for compatibility with `GemmaDecoderLayer`) kwargs: keyword args (for compatibility with `GemmaDecoderLayer`) """ - def __init__(self, config, *args, **kwargs): + def __init__(self, config, layer_idx, *args, **kwargs): super().__init__( hidden_size=config.hidden_size, ffn_hidden_size=config.intermediate_size, @@ -62,7 +63,8 @@ def __init__(self, config, *args, **kwargs): activation="geglu", attn_input_format="bshd", num_gqa_groups=config.num_key_value_heads, - attention_hidden_size=4096 + attention_hidden_size=4096, + layer_number=(layer_idx+1) ) te_rope = RotaryPositionEmbedding(256) self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() @@ -71,13 +73,15 @@ def forward(self, hidden_states, *args, attention_mask, + inference_params, + self_attn_mask_type='causal', **kwargs): """ Custom forward to make sure we only pass relevant arguments to the forward pass of the `TransformerLayer`. Also, make sure the output format matches the output of the HF's `GemmaDecoderLayer`. 
""" - return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),) + return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb, inference_params=inference_params, self_attn_mask_type=self_attn_mask_type),) class TEGemmaForCausalLM: @@ -92,7 +96,11 @@ class is monkey-patched with `TEGemmaDecoderLayer` class before def __new__(cls, config: GemmaConfig): with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): + # trzeba wstawis layer number do tego czegos w jakis sposob gemma_for_causal_lm = GemmaForCausalLM(config) + + gemma_for_causal_lm.generate = TEGemmaForCausalLM.generate.__get__(gemma_for_causal_lm, GemmaForCausalLM) + return gemma_for_causal_lm @classmethod @@ -101,6 +109,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k Custom method adapted from `from_pretrained` method in HuggingFace Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 """ + vanilla_model = cls(config) is_local = os.path.isdir(pretrained_model_name_or_path) subfolder = "" @@ -134,6 +143,66 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k gc.collect() return vanilla_model + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + max_new_tokens = 0, + **kwargs, + ): + num_heads = self.model.config.num_attention_heads + batch_size, seq_len = input_ids.shape + max_seq_len = seq_len + max_new_tokens + generation_config, _ = self._prepare_generation_config(generation_config, **kwargs) + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + + # inference_params object is a cache, where keys and values of previous tokens are stored + inference_params = te.pytorch.InferenceParams( + max_batch_size=batch_size, + max_sequence_length=seq_len+max_new_tokens+1) + + # mask has shape [batch_size, num_heads, 1, max_seq_len] and contains False + # when coressponding token is padding and True otherwise. + pad_attention_mask = input_ids.ne(generation_config.pad_token_id) + mask = torch.ones((batch_size, num_heads, 1, max_seq_len), dtype=torch.bool).cuda() + mask[..., :seq_len] = mask[..., :seq_len] & pad_attention_mask.unsqueeze(1).unsqueeze(2).expand(-1, num_heads, -1, -1) + + hidden_states = self.model.embed_tokens(input_ids) + output_tokens = [] + for i in range(max_new_tokens): + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + for decoder_layer in self.model.layers: + hidden_states = decoder_layer( + hidden_states, + # In the case of arbiutrary mask, the meaning of True and False is switched, so negation is needed. + attention_mask=pad_attention_mask if i == 0 else ~mask[..., :seq_len], + self_attn_mask_type="padding_causal" if i == 0 else "arbitrary", + inference_params=inference_params + )[0] + + # inference_params.sequence_len_offset should contain position of the current token in the sequence. + inference_params.sequence_len_offset += hidden_states.shape[1] + + hidden_states = self.model.norm(hidden_states) + logits = self.lm_head(hidden_states) + logits = logits.float() + logits = logits[:, -1, :] + next_tokens = torch.argmax(logits, dim=-1) + + # Sequences, which are finished should contain padding - taken from huggingface transformers. 
+ next_tokens = next_tokens * unfinished_sequences + generation_config.pad_token_id * (1 - unfinished_sequences) + output_tokens.append(next_tokens) + + unfinished_sequences = unfinished_sequences & ~(next_tokens == generation_config.eos_token_id) + + hidden_states = self.model.embed_tokens(next_tokens).unsqueeze(1) + seq_len += 1 + + result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) + return result def replace_params(hf_state_dict, te_state_dict): diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb new file mode 100755 index 0000000000..9fb353b8ea --- /dev/null +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -0,0 +1,372 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2cac9d39", + "metadata": {}, + "source": [ + "# Accelerating a Hugging Face Gemma model generation with Transformer Engine\n", + "\n", + "
\n", + "\n", + "Goal\n", + "\n", + "This tutorial showcases how to accelerate generation done by a full Gemma model from [Hugging Face](https://huggingface.co/google/gemma-7b-it) by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` precision.\n", + "\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "id": "401f7fb1", + "metadata": {}, + "source": [ + "## Dependencies for this tutorial\n", + "\n", + "Following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_gemma.py`\n", + " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `LlamaDecoderLayer`. Also it contains the logic of the generation using TransformerEngine. \n", + "2. `utils.py`\n", + " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", + "3. `media/`\n", + " - This directory contains the images used in the following tutorial." + ] + }, + { + "cell_type": "markdown", + "id": "b564503c", + "metadata": {}, + "source": [ + "## Baseline HuggingFace Gemma generation" + ] + }, + { + "cell_type": "markdown", + "id": "24a8d0a5", + "metadata": {}, + "source": [ + "
\n", + "\n", + "Note\n", + " \n", + "This tutorial loads and trains a Llama 2 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", + "\n", + "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n", + "\n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e36ff380", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.60it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generation time: 26.482454538345337 seconds\n", + "I like the new look of the app. I like the new features. I like the new look of \n", + "==============================\n", + "I do not like the way the new version of the app is set up. I do not like the fa\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "#restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "import torch\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", + "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "hyperparams.mixed_precision = \"no\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_baseline_model(hyperparams).cuda()\n", + "model = model.to(torch.bfloat16)\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)\n", + "inputs = tokenizer([\"I like\", \"I do not like\"] * 32, return_tensors=\"pt\", padding=True)\n", + "\n", + "inputs['input_ids'] = inputs['input_ids'].cuda()\n", + "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n", + "\n", + "\n", + "start_time = time.time()\n", + "\n", + "outputs = model.generate(\n", + " **inputs,\n", + " max_new_tokens=400\n", + ")\n", + "\n", + "end_time = time.time()\n", + "duration = end_time - start_time\n", + "print(f\"Generation time: {duration} seconds\")\n", + "\n", + "\n", + "# Decode the output tensor to text\n", + "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", + "\n", + "# Display the first two samples of the generated text\n", + "print(generated_texts[0][:80])\n", + "print(30 * \"=\")\n", + "print(generated_texts[1][:80])" + ] + }, + { + "cell_type": "markdown", + "id": "a64f0f33", + "metadata": {}, + "source": [ + "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", + "\n", + "| Models | Precision | Generation time | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 26.48 | 1 |" + ] + }, + { + "cell_type": "markdown", + "id": "e2fb88e9", + "metadata": {}, + "source": [ + "## [Improvement] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` and use generation within TE\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "6f7fefac", + "metadata": {}, + "source": [ + "```\n", + "@torch.no_grad()\n", + " def generate(\n", + " self,\n", + " input_ids: Optional[torch.Tensor] = None,\n", + " generation_config: Optional[GenerationConfig] = None,\n", + " max_new_tokens = 0,\n", + " **kwargs,\n", + " ):\n", + " num_heads = 
self.model.config.num_attention_heads\n", + " batch_size, seq_len = input_ids.shape\n", + " max_seq_len = seq_len + max_new_tokens\n", + " generation_config, _ = self._prepare_generation_config(generation_config, **kwargs)\n", + " unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)\n", + "\n", + " # inference_params object is a cache, where keys and values of previous tokens are stored\n", + " inference_params = te.pytorch.InferenceParams(\n", + " max_batch_size=batch_size, \n", + " max_sequence_length=seq_len+max_new_tokens+1) \n", + "\n", + " # mask has shape [batch_size, num_heads, 1, max_seq_len] and contains False \n", + " # when coressponding token is padding and True otherwise.\n", + " pad_attention_mask = input_ids.ne(generation_config.pad_token_id)\n", + " mask = torch.ones((batch_size, num_heads, 1, max_seq_len), dtype=torch.bool).cuda()\n", + " mask[..., :seq_len] = mask[..., :seq_len] & pad_attention_mask.unsqueeze(1).unsqueeze(2).expand(-1, num_heads, -1, -1)\n", + "\n", + " hidden_states = self.model.embed_tokens(input_ids)\n", + " output_tokens = []\n", + " for i in range(max_new_tokens):\n", + " normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)\n", + " hidden_states = hidden_states * normalizer\n", + " for decoder_layer in self.model.layers:\n", + " hidden_states = decoder_layer(\n", + " hidden_states,\n", + " # In the case of arbiutrary mask, the meaning of True and False is switched, so negation is needed.\n", + " attention_mask=pad_attention_mask if i == 0 else ~mask[..., :seq_len],\n", + " self_attn_mask_type=\"padding_causal\" if i == 0 else \"arbitrary\",\n", + " inference_params=inference_params\n", + " )[0]\n", + "\n", + " # inference_params.sequence_len_offset should contain position of the current token in the sequence.\n", + " inference_params.sequence_len_offset += hidden_states.shape[1]\n", + "\n", + " hidden_states = self.model.norm(hidden_states)\n", + " logits = self.lm_head(hidden_states)\n", + " logits = logits.float()\n", + " logits = logits[:, -1, :]\n", + " next_tokens = torch.argmax(logits, dim=-1)\n", + "\n", + " # Sequences, which are finished should contain padding - taken from huggingface transformers.\n", + " next_tokens = next_tokens * unfinished_sequences + generation_config.pad_token_id * (1 - unfinished_sequences)\n", + " output_tokens.append(next_tokens)\n", + "\n", + " unfinished_sequences = unfinished_sequences & ~(next_tokens == generation_config.eos_token_id)\n", + "\n", + " hidden_states = self.model.embed_tokens(next_tokens).unsqueeze(1)\n", + " seq_len += 1\n", + "\n", + " result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1)\n", + " return result\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8f2b752e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function 
operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in 
torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generation time: 16.87099289894104 seconds\n", + "I like the idea of a \"re-do\" of the original \"The Man from U.N.C.L.E.\" movie. I \n", + "==============================\n", + "I do not like the way the \"new\" (2011) version of the 1099-MISC is set up. I ha\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "#restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "import accelerate\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", + "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_gemma_model(hyperparams)\n", + "#accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "model = model.to(torch.bfloat16).cuda()\n", + "\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)\n", + "inputs = tokenizer([\"I like\", \"I do not like\"] * 32, return_tensors=\"pt\", padding=True)\n", + "\n", + "inputs['input_ids'] = inputs['input_ids'].cuda()\n", + "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n", + "\n", + "import time\n", + "\n", + "start_time = time.time()\n", + "\n", + "outputs = model.generate(\n", + " **inputs,\n", + " max_new_tokens=400\n", + ")\n", + "\n", + "end_time = time.time()\n", + "duration = end_time - start_time\n", + "print(f\"Generation time: {duration} seconds\")\n", + "\n", + "\n", + "# Decode the output tensor to text\n", + "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", + "\n", + "# Display the first two samples of the generated text\n", + "print(generated_texts[0][:80])\n", + "print(30 * \"=\")\n", + "print(generated_texts[1][:80])" + ] + }, + { + "cell_type": "markdown", + "id": "67ec126c", + "metadata": {}, + "source": [ + "| Models | Precision | Generation time | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 26.48 | 1 |\n", + "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 16.87 | 1.56 |\n", + "\n", + "\n", + "\n", + "After converting to TE Transformer Layers, we obtained the speedup of **56%**!" + ] + }, + { + "cell_type": "markdown", + "id": "41b80b0f", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides a speedup over Hugging Face's native Gemma generation implementation. 
`TransformerLayer` provides a speedup over the baseline implementation" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index d29b094821..35bd0421d9 100755 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -83,7 +83,6 @@ def init_baseline_model(hyperparams): torch_dtype=torch.bfloat16, ) # Needed for the cases when using TEGemmaForCausalLM. So adding here for 1:1 comparison - model.config.use_cache=False return model @@ -98,7 +97,6 @@ def init_te_gemma_model(hyperparams): torch_dtype=torch.bfloat16, ) # Needed for the cases when using TEGemmaForCausalLM - model.config.use_cache=False return model @@ -117,7 +115,7 @@ def wrap_with_accelerator(model, hyperparams): train_dataloader = get_dataloaders(accelerator, hyperparams) # Wrap model, optimizer/scheduler, dataloaders in accelerate - optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=True) + optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=False) lr_scheduler = get_linear_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=100, From 219bb077be1fe983cd9a0e7042bab32d7a6fcf19 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 12 Apr 2024 20:43:51 +0000 Subject: [PATCH 059/244] Fp8 generation and evaluation Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/eval_bf16.py | 42 ++++++++++++++ docs/examples/te_gemma/eval_fp8.py | 64 ++++++++++++++++++++++ docs/examples/te_gemma/generate_convert.py | 59 ++++++++++++++++++++ docs/examples/te_gemma/generate_fp8.py | 54 ++++++++++++++++++ 4 files changed, 219 insertions(+) create mode 100644 docs/examples/te_gemma/eval_bf16.py create mode 100644 docs/examples/te_gemma/eval_fp8.py create mode 100644 docs/examples/te_gemma/generate_convert.py create mode 100755 docs/examples/te_gemma/generate_fp8.py diff --git a/docs/examples/te_gemma/eval_bf16.py b/docs/examples/te_gemma/eval_bf16.py new file mode 100644 index 0000000000..bfeeb8fa45 --- /dev/null +++ b/docs/examples/te_gemma/eval_bf16.py @@ -0,0 +1,42 @@ +from utils import * +import torch +from tqdm import tqdm # For progress bar + +# Default hyperparams, also defined in `utils.py` in class `Hyperparameters` +## !!! `model_name` attr must point to the location of the model weights !!! +## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ +hyperparams.model_name = "../../../../gemma-weights" # <== Add model weight location here e.g. 
"/path/to/downloaded/llama/weights" +hyperparams.fuse_qkv_params = True + +# Init the model and accelerator wrapper +model = init_te_gemma_model(hyperparams, fp8_model_init=False).cuda() + +dataset = load_dataset(hyperparams.dataset_name, split="train") +tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) +accelerator = Accelerator( + log_with="wandb", + gradient_accumulation_steps=hyperparams.gradient_accumulation_steps, + mixed_precision=hyperparams.mixed_precision, + kwargs_handlers=[FP8RecipeKwargs(backend="te")] + ) +train_dataloader = enumerate(get_dataloaders(accelerator, hyperparams)) + +model.eval() # Set the model to evaluation mode +total_correct = 0 +total_samples = 0 + +with torch.no_grad(): # No need to compute gradients during evaluation + for _, batch in tqdm(train_dataloader, desc="Evaluating"): + input_ids = batch["input_ids"].cuda() + + labels = input_ids[:, 1:].contiguous() + input_ids = input_ids[:, :-1].contiguous() + outputs = model(input_ids=input_ids, labels=labels, use_cache=False) + + predictions = torch.argmax(outputs.logits, dim=-1) + + total_correct += (predictions == labels).sum().item() + total_samples += labels.numel() + +accuracy = total_correct / total_samples +print(f"Accuraccy = {accuracy}") \ No newline at end of file diff --git a/docs/examples/te_gemma/eval_fp8.py b/docs/examples/te_gemma/eval_fp8.py new file mode 100644 index 0000000000..99948c2be9 --- /dev/null +++ b/docs/examples/te_gemma/eval_fp8.py @@ -0,0 +1,64 @@ +from utils import * +import torch +from tqdm import tqdm # For progress bar +import transformer_engine.pytorch as te + + +# Import necessary packages and methods +from utils import * +import accelerate + +from transformer_engine.pytorch import fp8_model_init +from transformer_engine.common.recipe import Format, DelayedScaling + +# Default hyperparams, also defined in `utils.py` in class `Hyperparameters` +## !!! `model_name` attr must point to the location of the model weights !!! 
+## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ + +hyperparams.model_name = "../../../../gemma-weights" +hyperparams.fuse_qkv_params = True +model = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda() + + +print("Loading model") +model_state_dict = torch.load('model_fp8_state_dict.pth') +model.load_state_dict(model_state_dict) +print("Model loaded") + + +dataset = load_dataset(hyperparams.dataset_name, split="train") +tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) + +accelerator = Accelerator( + log_with="wandb", + gradient_accumulation_steps=hyperparams.gradient_accumulation_steps, + mixed_precision=hyperparams.mixed_precision, + kwargs_handlers=[FP8RecipeKwargs(backend="te")] + ) +train_dataloader = enumerate(get_dataloaders(accelerator, hyperparams)) + + +model.eval() # Set the model to evaluation mode +total_correct = 0 +total_samples = 0 + +fp8_format = Format.HYBRID +fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") +with torch.no_grad(): # No need to compute gradients during evaluation + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + for _, batch in tqdm(train_dataloader, desc="Evaluating"): + input_ids = batch["input_ids"].cuda() + + labels = input_ids[:, 1:].contiguous() + input_ids = input_ids[:, :-1].contiguous() + outputs = model(input_ids=input_ids, labels=labels, use_cache=False) + + predictions = torch.argmax(outputs.logits, dim=-1) + + total_correct += (predictions == labels).sum().item() + total_samples += labels.numel() + +accuracy = total_correct / total_samples +print(f"Accuracy = {accuracy}") + + diff --git a/docs/examples/te_gemma/generate_convert.py b/docs/examples/te_gemma/generate_convert.py new file mode 100644 index 0000000000..66338c64a0 --- /dev/null +++ b/docs/examples/te_gemma/generate_convert.py @@ -0,0 +1,59 @@ +# Import necessary packages and methods +import transformer_engine.pytorch as te +from utils import * +import accelerate +from transformer_engine.pytorch import fp8_model_init +from transformer_engine.common.recipe import Format, DelayedScaling +import torch + + +hyperparams.model_name = "../../../../gemma-weights" +hyperparams.fuse_qkv_params = True +model = init_te_gemma_model(hyperparams, fp8_model_init=False).cuda() +model = model.to(torch.bfloat16) + + +accelerator = Accelerator( + log_with="wandb", + gradient_accumulation_steps=hyperparams.gradient_accumulation_steps, + mixed_precision=hyperparams.mixed_precision, + kwargs_handlers=[FP8RecipeKwargs(backend="te")] + ) +train_dataloader = get_dataloaders(accelerator, hyperparams) + +tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) + +print("Calibration started") +with te.fp8_autocast(enabled=False, calibrating=True): + model.train() + train_dataloader = enumerate(train_dataloader) + + for i in range(100): + step, batch = next(train_dataloader) + batch["input_ids"] = batch["input_ids"].cuda() + outputs = model.generate( + **batch, + max_new_tokens=1 + ) +print("Calibration finished") + +print("scale_fwd computation started") +with te.fp8_autocast(enabled=True): + for i in range(10): + step, batch = next(train_dataloader) + batch["input_ids"] = batch["input_ids"].cuda() + outputs = model.generate( + **batch, + max_new_tokens=1 + ) +print("scale_fwd computation ended") + +print("Casting weights...") +model_fp8 = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda() +model_fp8.load_state_dict(model.state_dict()) +print("Weights cast") + + +print("Saving 
model...") +torch.save(model_fp8.state_dict(), 'model_fp8_state_dict.pth') +print("Model saved!") \ No newline at end of file diff --git a/docs/examples/te_gemma/generate_fp8.py b/docs/examples/te_gemma/generate_fp8.py new file mode 100755 index 0000000000..4a6bc1853e --- /dev/null +++ b/docs/examples/te_gemma/generate_fp8.py @@ -0,0 +1,54 @@ +# Restart the notebook (to flush the GPU memory) +from utils import restart_jupyter_notebook +#restart_jupyter_notebook() +import transformer_engine.pytorch as te + + +# Import necessary packages and methods +from utils import * + +from transformer_engine.pytorch import fp8_model_init +from transformer_engine.common.recipe import Format, DelayedScaling + +hyperparams.model_name = "../../../../gemma-weights" +hyperparams.fuse_qkv_params = True +model = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda() + +print("Loading model") +model_state_dict = torch.load('model_fp8_state_dict.pth') +model.load_state_dict(model_state_dict) +print("Model loaded") + +tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) +inputs = tokenizer(["I love when", "I love when"] * 32, return_tensors="pt", padding=True) + +inputs['input_ids'] = inputs['input_ids'].cuda() +inputs['attention_mask'] = inputs['attention_mask'].cuda() + +import time + + + +start_time = time.time() + +fp8_format = Format.HYBRID +fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") +torch.manual_seed(1234) +with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with torch.no_grad(): + model.eval() + outputs = model.generate( + **inputs, + max_new_tokens=40 + ) + + +end_time = time.time() +duration = end_time - start_time + +generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) +for text in generated_texts[:2]: + print("-" * 50) + print(text) + +print(f"Duration = {duration}") From 8a5ba9b3b8c4552a3ef8c21fbacbe8e1574b63e3 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 12 Apr 2024 20:44:51 +0000 Subject: [PATCH 060/244] Fp8 generation and evaluation Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 77 +++++++++++++++++++++--------- 1 file changed, 55 insertions(+), 22 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 113534828c..27c079338d 100755 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -58,7 +58,7 @@ def __init__(self, config, layer_idx, *args, **kwargs): layernorm_epsilon=config.rms_norm_eps, hidden_dropout=0, attention_dropout=0, - fuse_qkv_params=False, + fuse_qkv_params=config.fuse_qkv_params, normalization="RMSNorm", activation="geglu", attn_input_format="bshd", @@ -73,7 +73,7 @@ def forward(self, hidden_states, *args, attention_mask, - inference_params, + inference_params=None, self_attn_mask_type='causal', **kwargs): """ @@ -104,13 +104,14 @@ def __new__(cls, config: GemmaConfig): return gemma_for_causal_lm @classmethod - def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **kwargs): + def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, fp8_init=False, **kwargs): """ Custom method adapted from `from_pretrained` method in HuggingFace Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 """ - vanilla_model = cls(config) + with fp8_model_init(fp8_init): + vanilla_model = cls(config) is_local = 
os.path.isdir(pretrained_model_name_or_path) subfolder = "" variant = None @@ -135,13 +136,15 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file) - replace_params(state_dict, vanilla_model.state_dict()) + replace_params(state_dict, vanilla_model.state_dict(), config, fp8_init=config.fuse_qkv_params) _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") # Force mem release. Taken from huggingface code del state_dict gc.collect() + + return vanilla_model @torch.no_grad() @@ -152,7 +155,6 @@ def generate( max_new_tokens = 0, **kwargs, ): - num_heads = self.model.config.num_attention_heads batch_size, seq_len = input_ids.shape max_seq_len = seq_len + max_new_tokens generation_config, _ = self._prepare_generation_config(generation_config, **kwargs) @@ -165,9 +167,10 @@ def generate( # mask has shape [batch_size, num_heads, 1, max_seq_len] and contains False # when coressponding token is padding and True otherwise. - pad_attention_mask = input_ids.ne(generation_config.pad_token_id) - mask = torch.ones((batch_size, num_heads, 1, max_seq_len), dtype=torch.bool).cuda() - mask[..., :seq_len] = mask[..., :seq_len] & pad_attention_mask.unsqueeze(1).unsqueeze(2).expand(-1, num_heads, -1, -1) + pad_attention_mask = input_ids.ne(generation_config.pad_token_id).unsqueeze(1).unsqueeze(2) + mask = torch.ones((batch_size, 1, 1, max_seq_len), dtype=torch.bool).cuda() + mask[..., :seq_len] = mask[..., :seq_len] & pad_attention_mask.expand(-1, 1, -1, -1) + hidden_states = self.model.embed_tokens(input_ids) output_tokens = [] @@ -179,10 +182,10 @@ def generate( hidden_states, # In the case of arbiutrary mask, the meaning of True and False is switched, so negation is needed. attention_mask=pad_attention_mask if i == 0 else ~mask[..., :seq_len], - self_attn_mask_type="padding_causal" if i == 0 else "arbitrary", + self_attn_mask_type="causal" if i == 0 else "arbitrary", inference_params=inference_params )[0] - + # inference_params.sequence_len_offset should contain position of the current token in the sequence. 
inference_params.sequence_len_offset += hidden_states.shape[1] @@ -205,7 +208,7 @@ def generate( return result -def replace_params(hf_state_dict, te_state_dict): +def replace_params(hf_state_dict, te_state_dict, config, fp8_init=False): # collect all layer prefixes to update all_layer_prefixes = set() for param_key in hf_state_dict.keys(): @@ -215,35 +218,65 @@ def replace_params(hf_state_dict, te_state_dict): all_layer_prefixes.add(m.group()) GATE_PROJ_SIZE=24576 - + + for layer_prefix in all_layer_prefixes: # When loading weights into models with less number of layers, skip the # copy if the corresponding layer doesn't exist in HF model if layer_prefix + 'input_layernorm.weight' in hf_state_dict: te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].copy_(1 + hf_state_dict[layer_prefix + 'input_layernorm.weight']) + + if fp8_init: + dst = te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.weight'] + + if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict: + q = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'] + for head_nr in range(config.num_attention_heads): + dst_offset = head_nr * config.head_dim * 3 + # copy query + dst[dst_offset:(dst_offset + config.head_dim), :] = \ + q[(head_nr * config.head_dim):(head_nr * config.head_dim + config.head_dim), :] + + if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict: + k = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'] + for head_nr in range(config.num_attention_heads): + dst_offset = head_nr * config.head_dim * 3 + # copy query + dst[( dst_offset + config.head_dim):(dst_offset + 2 * config.head_dim), :] = \ + k[(head_nr * config.head_dim):(head_nr * config.head_dim + config.head_dim), :] - if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.q_proj.weight']) + if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict: + v = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'] + for head_nr in range(config.num_attention_heads): + dst_offset = head_nr * config.head_dim * 3 + dst[(dst_offset + 2 * config.head_dim):(dst_offset + 3 * config.head_dim), :] = \ + v[(head_nr * config.head_dim):(head_nr * config.head_dim + config.head_dim), :] + else: + + if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.q_proj.weight']) - if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.k_proj.weight']) + if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.k_proj.weight']) - if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.v_proj.weight']) + if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.v_proj.weight']) if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'self_attention.proj.weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.o_proj.weight']) if layer_prefix + 
'post_attention_layernorm.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:] + 1 + te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].copy_(1 + hf_state_dict[layer_prefix + 'post_attention_layernorm.weight']) if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:GATE_PROJ_SIZE] = hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data[:] + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:GATE_PROJ_SIZE].copy_(hf_state_dict[layer_prefix + 'mlp.gate_proj.weight']) if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[GATE_PROJ_SIZE:] = hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data[:] + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[GATE_PROJ_SIZE:].copy_(hf_state_dict[layer_prefix + 'mlp.up_proj.weight']) if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].copy_(hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:]) + + return all_layer_prefixes \ No newline at end of file From 39de0e89a1e7700d24841fdd12ef48c5df38d7aa Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Apr 2024 19:45:03 +0000 Subject: [PATCH 061/244] changes Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 63 ++++++++++++++----- .../pytorch/cpp_extensions/fused_attn.py | 18 ++---- .../pytorch/cpp_extensions/normalization.py | 0 .../pytorch/csrc/comm_gemm_overlap.h | 0 transformer_engine/pytorch/csrc/extensions.h | 0 .../pytorch/csrc/extensions/attention.cu | 0 .../pytorch/csrc/extensions/normalization.cu | 0 .../pytorch/csrc/userbuffers/CMakeLists.txt | 0 .../csrc/userbuffers/userbuffers-host.cpp | 0 .../pytorch/csrc/userbuffers/userbuffers.cu | 0 .../pytorch/csrc/userbuffers/userbuffers.h | 0 transformer_engine/pytorch/distributed.py | 0 transformer_engine/pytorch/float8_tensor.py | 0 transformer_engine/pytorch/fp8.py | 0 transformer_engine/pytorch/module/_common.py | 0 transformer_engine/pytorch/module/base.py | 0 .../pytorch/module/layernorm.py | 0 .../pytorch/module/layernorm_linear.py | 0 .../pytorch/module/layernorm_mlp.py | 0 transformer_engine/pytorch/module/linear.py | 0 transformer_engine/pytorch/module/rmsnorm.py | 0 transformer_engine/pytorch/utils.py | 0 22 files changed, 53 insertions(+), 28 deletions(-) mode change 100644 => 100755 transformer_engine/pytorch/cpp_extensions/fused_attn.py mode change 100644 => 100755 transformer_engine/pytorch/cpp_extensions/normalization.py mode change 100644 => 100755 transformer_engine/pytorch/csrc/comm_gemm_overlap.h mode change 100644 => 100755 transformer_engine/pytorch/csrc/extensions.h mode change 100644 => 100755 transformer_engine/pytorch/csrc/extensions/attention.cu mode change 100644 => 100755 transformer_engine/pytorch/csrc/extensions/normalization.cu mode change 100644 => 100755 transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt mode change 100644 => 100755 transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp mode change 100644 => 100755 transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu mode change 100644 => 100755 transformer_engine/pytorch/csrc/userbuffers/userbuffers.h mode change 100644 => 100755 transformer_engine/pytorch/distributed.py mode change 100644 => 100755 transformer_engine/pytorch/float8_tensor.py mode change 100644 
=> 100755 transformer_engine/pytorch/fp8.py mode change 100644 => 100755 transformer_engine/pytorch/module/_common.py mode change 100644 => 100755 transformer_engine/pytorch/module/base.py mode change 100644 => 100755 transformer_engine/pytorch/module/layernorm.py mode change 100644 => 100755 transformer_engine/pytorch/module/layernorm_linear.py mode change 100644 => 100755 transformer_engine/pytorch/module/layernorm_mlp.py mode change 100644 => 100755 transformer_engine/pytorch/module/linear.py mode change 100644 => 100755 transformer_engine/pytorch/module/rmsnorm.py mode change 100644 => 100755 transformer_engine/pytorch/utils.py diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index a1f97a4fe0..31c32a9f93 100755 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -16,6 +16,7 @@ import torch import torch.nn.functional as F +from torch.utils.cpp_extension import load import transformer_engine_extensions as tex from transformer_engine.pytorch.cpp_extensions import ( @@ -102,6 +103,13 @@ __all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"] +cuda = load( + name='attention_copy', + sources=['attention_copy.cu'], + verbose=True +) + + class InferenceParams: # pylint: disable=too-few-public-methods """ Inference parameters that are passed to the main model in order @@ -121,6 +129,7 @@ def __init__(self, max_batch_size, max_sequence_length): self.sequence_len_offset = 0 self.batch_size_offset = 0 self.key_value_memory_dict = {} + self.thd = False def swap_key_value_dict(self, batch_indices): """ @@ -3229,6 +3238,7 @@ def __init__( self.cp_group = cp_group self.cp_global_ranks = cp_global_ranks self.cp_stream = cp_stream + self.channels = channels self.hidden_size_per_attention_head = channels // num_attention_heads @@ -3486,6 +3496,7 @@ def forward( produced) """ + assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), 'DotProductAttention only supports CUDA tensors.' @@ -3529,21 +3540,44 @@ def forward( (inference_key_memory, inference_value_memory, ) = inference_params.key_value_memory_dict[self.layer_number] - batch_start = inference_params.batch_size_offset - batch_end = batch_start + key_layer.size(1) - assert batch_end <= inference_key_memory.size(1) - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + key_layer.size(0) - assert sequence_end <= inference_key_memory.size(0) + if not inference_params.thd: + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + + # Copy keys and values into KV-cache + inference_key_memory[ + sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer + inference_value_memory[ + sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer + key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] + value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] 
+ else: + cuda.attention_copy(inference_key_memory, inference_params.seq_len + 1, key_layer, inference_params.max_batch_size, self.channels) + cuda.attention_copy(inference_value_memory, inference_params.seq_len + 1, value_layer, inference_params.max_batch_size, self.channels) + + q = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]) + k = inference_key_memory.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]) + v = inference_value_memory.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]) + + q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16), + + out, _, _ = fused_attn_fwd( + False, 1, key_layer.shape[1], inference_params.seq_len, inference_params.seq_len, + q, k, v, + TE_DType[q.dtype], FusedAttnBackend["F16_max512_seqlen"], + qkv_layout="t3hd", attn_bias_type=core_attention_bias_type, + attn_bias=core_attention_bias, fast_zero_fill=fast_zero_fill + ) + print("xd") + exit() + return out - # Copy keys and values into KV-cache - inference_key_memory[ - sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer - inference_value_memory[ - sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer - key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] - value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] if qkv_format == "bshd": key_layer = key_layer.transpose(0, 1) @@ -4359,6 +4393,8 @@ def forward( """ # hidden_states: [sq, b, h] + + if attn_mask_type is not None: window_size = check_set_window_size(attn_mask_type, window_size) if attn_mask_type is None: @@ -4420,7 +4456,6 @@ def forward( is_first_microbatch=is_first_microbatch, is_first_module_in_mha=True, # specific to FP8 MHA ) - num_queries_per_key_value = (self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition) if self.qkv_weight_interleaved: diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py old mode 100644 new mode 100755 index 574627ac5d..74030ba809 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -835,20 +835,7 @@ def fused_attn_fwd( if fused_attention_backend == FusedAttnBackend["FP8"]: rng_elts_per_thread = (max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1)//BACKEND_F16m512_FP8_THREADS_PER_CTA - - assert (d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert (d_scale_s is not None - ), "q_scale_s is required as an input for FP8 fused attention." - assert (q_scale_s is not None - ), "q_scale_s is required as an input for FP8 fused attention." - assert (q_scale_o is not None - ), "q_scale_o is required as an input for FP8 fused attention." - assert (amax_s is not None - ), "amax_s is required as an input for FP8 fused attention." - assert (amax_o is not None - ), "amax_o is required as an input for FP8 fused attention." - + # execute kernel output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, @@ -994,6 +981,9 @@ def fused_attn_bwd( ), "aux_ctx_tensors must contain rng_state as its last element." if fused_attention_backend == FusedAttnBackend["FP8"]: + print("rr") + print(d_scale_qkv) + exit() assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention." assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention." 
assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention." diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cu b/transformer_engine/pytorch/csrc/extensions/normalization.cu old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt b/transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py old mode 100644 new mode 100755 diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py old mode 100644 new mode 100755 From 1d3105cdbdbfcb33033d207c17a2012f8c73c5f6 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 21 Mar 2024 22:42:42 +0000 Subject: [PATCH 062/244] Fixed Llama tutorial. Changed batch size and added fused=True. 
Signed-off-by: Pawel Gadzinski Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_llama/te_llama.py | 23 +-- ...tutorial_accelerate_hf_llama_with_te.ipynb | 155 +++++++++++++++--- docs/examples/te_llama/utils.py | 4 +- 3 files changed, 143 insertions(+), 39 deletions(-) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index aa23b638f0..d6dbac4ebd 100755 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -56,7 +56,8 @@ def __init__(self, config, *args, **kwargs): normalization="RMSNorm", activation="swiglu", attn_input_format="bshd", - num_gqa_groups=config.num_key_value_heads + num_gqa_groups=config.num_key_value_heads, + kv_channels=16 ) te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads) self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() @@ -123,10 +124,8 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file) - # replace_params copies parameters relevant only to TransformerEngine - replace_params(state_dict, vanilla_model.state_dict(), config) - # _load_state_dict_into_model copies parameters other than those in TransformerEngine - _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") + replaces_params = replace_params(state_dict, vanilla_model.state_dict()) + #_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") # Force mem release. Taken from huggingface code del state_dict @@ -143,8 +142,6 @@ def replace_params(hf_state_dict, te_state_dict, config): if m is not None: all_layer_prefixes.add(m.group()) - - for layer_prefix in all_layer_prefixes: # When loading weights into models with less number of layers, skip the # copy if the corresponding layer doesn't exist in HF model @@ -165,16 +162,8 @@ def replace_params(hf_state_dict, te_state_dict, config): if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict: te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:] - - # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to - # load them separately. 
- if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.intermediate_size] = \ - hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data - - if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.intermediate_size:] = \ - hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data + if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict and 'mlp.up_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:] = torch.cat((hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data[:], hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data[:]), dim=0) if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:] diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb index cc77b484f9..59bd322729 100755 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -206,23 +206,31 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "e36ff380", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00, 1.25it/s]\n", + "Repo card metadata block was not found. Setting CardData to empty.\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ "10 finetuning steps complete!\n", - "Average time taken per step: 315 milliseconds\n" + "Average time taken per step: 764 milliseconds\n" ] } ], "source": [ "# Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", - "restart_jupyter_notebook()\n", + "#restart_jupyter_notebook()\n", "\n", "\n", "# Import necessary packages and methods\n", @@ -231,14 +239,13 @@ "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then coverted to the HuggingFace format.\n", - "## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", + "hyperparams.model_name = \"../../../../llama-hf-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "\n", "# Init the model and accelerator wrapper\n", - "model = init_baseline_model(hyperparams)\n", + "model = init_baseline_model(hyperparams).cuda()\n", "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", "\n", "\n", @@ -536,19 +543,65 @@ "id": "4974b738", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. 
Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use 
_jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:387: UserWarning: `log_with=wandb` was passed but no supported trackers are currently installed.\n", + " warnings.warn(f\"`log_with={log_with}` was passed but no supported trackers are currently installed.\")\n", + "Repo card metadata block was not found. Setting CardData to empty.\n", + "You set `add_prefix_space`. 
The tokenizer needs to be converted from the slow tokenizers\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ "10 finetuning steps complete!\n", - "Average time taken per step: 252 milliseconds\n" + "Average time taken per step: 678 milliseconds\n" ] } ], "source": [ "# Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", - "restart_jupyter_notebook()\n", + "#restart_jupyter_notebook()\n", "\n", "\n", "# Import necessary packages and methods\n", @@ -557,9 +610,8 @@ "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then coverted to the HuggingFace format.\n", - "## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", + "hyperparams.model_name = \"../../../../llama-hf-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "\n", @@ -617,18 +669,82 @@ "metadata": {}, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "10 finetuning steps complete!\n", - "Average time taken per step: 226 milliseconds\n" + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. 
Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use 
_jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:387: UserWarning: `log_with=wandb` was passed but no supported trackers are currently installed.\n", + " warnings.warn(f\"`log_with={log_with}` was passed but no supported trackers are currently installed.\")\n", + "Repo card metadata block was not found. Setting CardData to empty.\n", + "You set `add_prefix_space`. 
The tokenizer needs to be converted from the slow tokenizers\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "shape '[16, 256, 3, 32, 16]' is invalid for input of size 50331648", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[1], line 23\u001b[0m\n\u001b[1;32m 19\u001b[0m accelerator, model, optimizer, train_dataloader, lr_scheduler \u001b[38;5;241m=\u001b[39m wrap_with_accelerator(model, hyperparams)\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m# Finetune the model\u001b[39;00m\n\u001b[0;32m---> 23\u001b[0m \u001b[43mfinetune_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhyperparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccelerator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr_scheduler\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/perfhome/tutorial/TransformerEngine/docs/examples/te_llama/utils.py:142\u001b[0m, in \u001b[0;36mfinetune_model\u001b[0;34m(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)\u001b[0m\n\u001b[1;32m 140\u001b[0m step, batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(train_dataloader)\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m accelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m--> 142\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 143\u001b[0m loss \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mloss\n\u001b[1;32m 144\u001b[0m total_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mfloat()\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m 
\u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m/usr/lib/python3.10/contextlib.py:79\u001b[0m, in \u001b[0;36mContextDecorator.__call__..inner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds):\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_recreate_cm():\n\u001b[0;32m---> 79\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1196\u001b[0m, in \u001b[0;36mLlamaForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)\u001b[0m\n\u001b[1;32m 1193\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1196\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1197\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1198\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1199\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1200\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1201\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1202\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1203\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1204\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1205\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1206\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1207\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1209\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1210\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mpretraining_tp \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1016\u001b[0m, in \u001b[0;36mLlamaModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)\u001b[0m\n\u001b[1;32m 1005\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 1006\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 1007\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1013\u001b[0m cache_position,\n\u001b[1;32m 1014\u001b[0m )\n\u001b[1;32m 1015\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1016\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1017\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1018\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1019\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1020\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1021\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1022\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1023\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1024\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1026\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1028\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, 
\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m/perfhome/tutorial/TransformerEngine/docs/examples/te_llama/te_llama.py:75\u001b[0m, in \u001b[0;36mTELlamaDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, *args, **kwargs)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 66\u001b[0m hidden_states,\n\u001b[1;32m 67\u001b[0m \u001b[38;5;241m*\u001b[39margs,\n\u001b[1;32m 68\u001b[0m attention_mask,\n\u001b[1;32m 69\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 70\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \u001b[38;5;124;03m Custom forward to make sure we only pass relevant arguments to the\u001b[39;00m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;124;03m forward pass of the `TransformerLayer`. 
Also, make sure the output\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;124;03m format matches the output of the HF's `LlamaDecoderLayer`.\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 75\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrotary_pos_emb\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mte_rope_emb\u001b[49m\u001b[43m)\u001b[49m,)\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/transformer.py:625\u001b[0m, in \u001b[0;36mTransformerLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, self_attn_mask_type, window_size, encoder_output, enc_dec_attn_mask, is_first_microbatch, checkpoint_core_attention, inference_params, rotary_pos_emb, core_attention_bias_type, core_attention_bias, fast_zero_fill)\u001b[0m\n\u001b[1;32m 620\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m cast_if_needed(\n\u001b[1;32m 621\u001b[0m hidden_states, torch\u001b[38;5;241m.\u001b[39mget_autocast_gpu_dtype()\n\u001b[1;32m 622\u001b[0m )\n\u001b[1;32m 624\u001b[0m \u001b[38;5;66;03m# Self attention.\u001b[39;00m\n\u001b[0;32m--> 625\u001b[0m self_attention_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 626\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 627\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 628\u001b[0m \u001b[43m \u001b[49m\u001b[43mattn_mask_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mself_attn_mask_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 629\u001b[0m \u001b[43m \u001b[49m\u001b[43mwindow_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwindow_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 630\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 631\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_first_microbatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_first_microbatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 632\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheckpoint_core_attention\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcheckpoint_core_attention\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 633\u001b[0m \u001b[43m \u001b[49m\u001b[43mrotary_pos_emb\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrotary_pos_emb\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 634\u001b[0m \u001b[43m \u001b[49m\u001b[43mcore_attention_bias_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcore_attention_bias_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 635\u001b[0m \u001b[43m \u001b[49m\u001b[43mcore_attention_bias\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcore_attention_bias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 636\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mfast_zero_fill\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfast_zero_fill\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 637\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 639\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mapply_residual_connection_post_layernorm \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_layernorm:\n\u001b[1;32m 640\u001b[0m attention_output, attention_bias, residual \u001b[38;5;241m=\u001b[39m self_attention_outputs\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py:3333\u001b[0m, in \u001b[0;36mMultiheadAttention.forward\u001b[0;34m(self, hidden_states, attention_mask, encoder_output, attn_mask_type, window_size, is_first_microbatch, 
checkpoint_core_attention, inference_params, rotary_pos_emb, core_attention_bias_type, core_attention_bias, fast_zero_fill)\u001b[0m\n\u001b[1;32m 3330\u001b[0m \u001b[38;5;66;03m# split along third last dimension\u001b[39;00m\n\u001b[1;32m 3331\u001b[0m split_dim \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m3\u001b[39m\n\u001b[0;32m-> 3333\u001b[0m mixed_x_layer \u001b[38;5;241m=\u001b[39m \u001b[43mmixed_x_layer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mview\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnew_tensor_shape\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3335\u001b[0m \u001b[38;5;66;03m# qkv_weight_interleaved:\u001b[39;00m\n\u001b[1;32m 3336\u001b[0m \u001b[38;5;66;03m# [sq, b, ng, (np/ng + 2), hn]\u001b[39;00m\n\u001b[1;32m 3337\u001b[0m \u001b[38;5;66;03m# --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]\u001b[39;00m\n\u001b[1;32m 3338\u001b[0m \u001b[38;5;66;03m# not qkv_weight_interleaved:\u001b[39;00m\n\u001b[1;32m 3339\u001b[0m \u001b[38;5;66;03m# [sq, b, (np/ng + 2), ng, hn]\u001b[39;00m\n\u001b[1;32m 3340\u001b[0m \u001b[38;5;66;03m# --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]\u001b[39;00m\n\u001b[1;32m 3341\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_in_onnx_export_mode():\n", + "\u001b[0;31mRuntimeError\u001b[0m: shape '[16, 256, 3, 32, 16]' is invalid for input of size 50331648" ] } ], "source": [ "# Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", - "restart_jupyter_notebook()\n", + "#restart_jupyter_notebook()\n", "\n", "\n", "# Import necessary packages and methods\n", @@ -637,14 +753,13 @@ "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then coverted to the HuggingFace format.\n", - "## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", + "hyperparams.model_name = \"../../../../llama-hf-weights\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/llama/weights\"\n", "hyperparams.mixed_precision = \"fp8\"\n", "\n", "\n", "# Init the model and accelerator wrapper\n", - "model = init_te_llama_model(hyperparams)\n", + "model = init_te_llama_model(hyperparams).cuda()\n", "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", "\n", "\n", diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py index 9c36e5bd17..4782813c62 100755 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -23,7 +23,7 @@ def __init__(self): self.dataset_name = "timdettmers/openassistant-guanaco" self.dataset_text_field = "text" self.learning_rate = 1.41e-5 - self.batch_size = 8 + self.batch_size = 16 self.max_seq_length = 256 self.gradient_accumulation_steps = 1 self.num_warmup_steps=5 @@ -117,7 +117,7 @@ def wrap_with_accelerator(model, hyperparams): train_dataloader = get_dataloaders(accelerator, hyperparams) # Wrap model, optimizer/scheduler, dataloaders in accelerate - optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate) + optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=True) lr_scheduler = get_linear_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=100, From 70aa1f3b95eb13ae29e6b88854cae2057410429f Mon Sep 17 00:00:00 2001 From: root Date: Fri, 22 Mar 2024 16:38:02 +0000 Subject: [PATCH 063/244] Tutorial updated but not complete yet. Signed-off-by: Pawel Gadzinski Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_llama/te_llama.py | 3 +- ...tutorial_accelerate_hf_llama_with_te.ipynb | 162 +++++++++++++++--- docs/examples/te_llama/utils.py | 4 +- 3 files changed, 139 insertions(+), 30 deletions(-) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index d6dbac4ebd..1a1d0ca8a6 100755 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -56,8 +56,7 @@ def __init__(self, config, *args, **kwargs): normalization="RMSNorm", activation="swiglu", attn_input_format="bshd", - num_gqa_groups=config.num_key_value_heads, - kv_channels=16 + num_gqa_groups=config.num_key_value_heads ) te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads) self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb index 59bd322729..dc06cab43a 100755 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -240,7 +240,7 @@ "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", - "hyperparams.model_name = \"../../../../llama-hf-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "\n", @@ -611,7 +611,7 @@ "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! 
`model_name` attr must point to the location of the model weights !!!\n", "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", - "hyperparams.model_name = \"../../../../llama-hf-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "\n", @@ -715,29 +715,11 @@ ] }, { - "ename": "RuntimeError", - "evalue": "shape '[16, 256, 3, 32, 16]' is invalid for input of size 50331648", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[1], line 23\u001b[0m\n\u001b[1;32m 19\u001b[0m accelerator, model, optimizer, train_dataloader, lr_scheduler \u001b[38;5;241m=\u001b[39m wrap_with_accelerator(model, hyperparams)\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m# Finetune the model\u001b[39;00m\n\u001b[0;32m---> 23\u001b[0m \u001b[43mfinetune_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhyperparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccelerator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr_scheduler\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/perfhome/tutorial/TransformerEngine/docs/examples/te_llama/utils.py:142\u001b[0m, in \u001b[0;36mfinetune_model\u001b[0;34m(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)\u001b[0m\n\u001b[1;32m 140\u001b[0m step, batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(train_dataloader)\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m accelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m--> 142\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 143\u001b[0m loss \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mloss\n\u001b[1;32m 144\u001b[0m total_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mfloat()\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File 
\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m/usr/lib/python3.10/contextlib.py:79\u001b[0m, in \u001b[0;36mContextDecorator.__call__..inner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds):\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_recreate_cm():\n\u001b[0;32m---> 79\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1196\u001b[0m, in \u001b[0;36mLlamaForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)\u001b[0m\n\u001b[1;32m 1193\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1196\u001b[0m outputs \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1197\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1198\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1199\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1200\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1201\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1202\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1203\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1204\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1205\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1206\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1207\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1209\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1210\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mpretraining_tp \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just 
call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py:1016\u001b[0m, in \u001b[0;36mLlamaModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)\u001b[0m\n\u001b[1;32m 1005\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 1006\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 1007\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1013\u001b[0m cache_position,\n\u001b[1;32m 1014\u001b[0m )\n\u001b[1;32m 1015\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1016\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1017\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1018\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1019\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1020\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1021\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1022\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1023\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1024\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1026\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m 
layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1028\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m/perfhome/tutorial/TransformerEngine/docs/examples/te_llama/te_llama.py:75\u001b[0m, in \u001b[0;36mTELlamaDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, *args, **kwargs)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 66\u001b[0m hidden_states,\n\u001b[1;32m 67\u001b[0m \u001b[38;5;241m*\u001b[39margs,\n\u001b[1;32m 68\u001b[0m attention_mask,\n\u001b[1;32m 69\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 70\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \u001b[38;5;124;03m Custom forward to make sure we only pass relevant arguments to 
the\u001b[39;00m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;124;03m forward pass of the `TransformerLayer`. Also, make sure the output\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;124;03m format matches the output of the HF's `LlamaDecoderLayer`.\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 75\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrotary_pos_emb\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mte_rope_emb\u001b[49m\u001b[43m)\u001b[49m,)\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/transformer.py:625\u001b[0m, in \u001b[0;36mTransformerLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, self_attn_mask_type, window_size, encoder_output, enc_dec_attn_mask, is_first_microbatch, checkpoint_core_attention, inference_params, rotary_pos_emb, core_attention_bias_type, core_attention_bias, fast_zero_fill)\u001b[0m\n\u001b[1;32m 620\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m cast_if_needed(\n\u001b[1;32m 621\u001b[0m hidden_states, torch\u001b[38;5;241m.\u001b[39mget_autocast_gpu_dtype()\n\u001b[1;32m 622\u001b[0m )\n\u001b[1;32m 624\u001b[0m \u001b[38;5;66;03m# Self attention.\u001b[39;00m\n\u001b[0;32m--> 625\u001b[0m self_attention_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 626\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 627\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 628\u001b[0m \u001b[43m \u001b[49m\u001b[43mattn_mask_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mself_attn_mask_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 629\u001b[0m \u001b[43m \u001b[49m\u001b[43mwindow_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwindow_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 630\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 631\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_first_microbatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_first_microbatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 632\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheckpoint_core_attention\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcheckpoint_core_attention\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 633\u001b[0m \u001b[43m \u001b[49m\u001b[43mrotary_pos_emb\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrotary_pos_emb\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 634\u001b[0m \u001b[43m \u001b[49m\u001b[43mcore_attention_bias_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcore_attention_bias_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 635\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mcore_attention_bias\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcore_attention_bias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 636\u001b[0m \u001b[43m \u001b[49m\u001b[43mfast_zero_fill\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfast_zero_fill\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 637\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 639\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mapply_residual_connection_post_layernorm \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_layernorm:\n\u001b[1;32m 640\u001b[0m attention_output, attention_bias, residual \u001b[38;5;241m=\u001b[39m self_attention_outputs\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File 
\u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py:3333\u001b[0m, in \u001b[0;36mMultiheadAttention.forward\u001b[0;34m(self, hidden_states, attention_mask, encoder_output, attn_mask_type, window_size, is_first_microbatch, checkpoint_core_attention, inference_params, rotary_pos_emb, core_attention_bias_type, core_attention_bias, fast_zero_fill)\u001b[0m\n\u001b[1;32m 3330\u001b[0m \u001b[38;5;66;03m# split along third last dimension\u001b[39;00m\n\u001b[1;32m 3331\u001b[0m split_dim \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m3\u001b[39m\n\u001b[0;32m-> 3333\u001b[0m mixed_x_layer \u001b[38;5;241m=\u001b[39m \u001b[43mmixed_x_layer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mview\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mnew_tensor_shape\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3335\u001b[0m \u001b[38;5;66;03m# qkv_weight_interleaved:\u001b[39;00m\n\u001b[1;32m 3336\u001b[0m \u001b[38;5;66;03m# [sq, b, ng, (np/ng + 2), hn]\u001b[39;00m\n\u001b[1;32m 3337\u001b[0m \u001b[38;5;66;03m# --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn]\u001b[39;00m\n\u001b[1;32m 3338\u001b[0m \u001b[38;5;66;03m# not qkv_weight_interleaved:\u001b[39;00m\n\u001b[1;32m 3339\u001b[0m \u001b[38;5;66;03m# [sq, b, (np/ng + 2), ng, hn]\u001b[39;00m\n\u001b[1;32m 3340\u001b[0m \u001b[38;5;66;03m# --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn]\u001b[39;00m\n\u001b[1;32m 3341\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_in_onnx_export_mode():\n", - "\u001b[0;31mRuntimeError\u001b[0m: shape '[16, 256, 3, 32, 16]' is invalid for input of size 50331648" + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "Average time taken per step: 498 milliseconds\n" ] } ], @@ -754,7 +736,7 @@ "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", - "hyperparams.model_name = \"../../../../llama-hf-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.mixed_precision = \"fp8\"\n", "\n", "\n", @@ -782,6 +764,134 @@ "After turning on FP8 precision, we get even more speedup of almost **40%**!" ] }, + { + "cell_type": "markdown", + "id": "4933825e", + "metadata": {}, + "source": [ + "# [Improvement 3] Using AdamW with fused=True. " + ] + }, + { + "cell_type": "markdown", + "id": "cad9e4a7", + "metadata": {}, + "source": [ + "We can obtain even bigger speedup, when running the optimizer in the speedup mode. The change in code is simple - we change the line:\n", + "\n", + "```\n", + "optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate)\n", + "```\n", + "into \n", + "\n", + "```\n", + "optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=True)\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7f5d3f79", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated 
and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:387: UserWarning: `log_with=wandb` was passed but no supported trackers are currently installed.\n", + " warnings.warn(f\"`log_with={log_with}` was passed but no supported trackers are currently installed.\")\n", + "Repo card metadata block was not found. Setting CardData to empty.\n", + "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "Average time taken per step: 487 milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "#restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! 
`model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "hyperparams.mixed_precision = \"fp8\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_llama_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams, fused_optizer=True)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "id": "73ed7b79", + "metadata": {}, + "source": [ + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 315 | 1 |\n", + "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 252 | 1.25 |\n", + "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 226 | 1.39 |\n", + "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`, use fused optimizer) | FP8 | ? | 1.49 |\n", + "\n", + "\n", + "Using option fused=True in the optimizer resulred in **1.49** speedup!" + ] + }, { "cell_type": "markdown", "id": "41b80b0f", diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py index 4782813c62..a43e6fa079 100755 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -102,7 +102,7 @@ def init_te_llama_model(hyperparams): return model -def wrap_with_accelerator(model, hyperparams): +def wrap_with_accelerator(model, hyperparams, fused_optizer=False): # Create FP8 kwarg handler if required fp8_kwarg_handler = [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None @@ -117,7 +117,7 @@ def wrap_with_accelerator(model, hyperparams): train_dataloader = get_dataloaders(accelerator, hyperparams) # Wrap model, optimizer/scheduler, dataloaders in accelerate - optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=True) + optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=fused_optizer) lr_scheduler = get_linear_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=100, From b52a73378b369ab60093ae69330ea3b800beb269 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 22 Mar 2024 21:38:45 +0000 Subject: [PATCH 064/244] Tutorial notebook reseted - removed fuse=true Signed-off-by: Pawel Gadzinski Signed-off-by: root Signed-off-by: Pawel Gadzinski --- ...tutorial_accelerate_hf_llama_with_te.ipynb | 248 +----------------- 1 file changed, 10 insertions(+), 238 deletions(-) mode change 100755 => 100644 docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb old mode 100755 new mode 100644 index dc06cab43a..178922c9d2 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -206,31 +206,23 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "e36ff380", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - 
"Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00, 1.25it/s]\n", - "Repo card metadata block was not found. Setting CardData to empty.\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ "10 finetuning steps complete!\n", - "Average time taken per step: 764 milliseconds\n" + "Average time taken per step: 315 milliseconds\n" ] } ], "source": [ "# Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", - "#restart_jupyter_notebook()\n", + "restart_jupyter_notebook()\n", "\n", "\n", "# Import necessary packages and methods\n", @@ -245,7 +237,7 @@ "\n", "\n", "# Init the model and accelerator wrapper\n", - "model = init_baseline_model(hyperparams).cuda()\n", + "model = init_baseline_model(hyperparams)\n", "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", "\n", "\n", @@ -543,65 +535,19 @@ "id": "4974b738", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", - "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. 
Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use 
_jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:387: UserWarning: `log_with=wandb` was passed but no supported trackers are currently installed.\n", - " warnings.warn(f\"`log_with={log_with}` was passed but no supported trackers are currently installed.\")\n", - "Repo card metadata block was not found. Setting CardData to empty.\n", - "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ "10 finetuning steps complete!\n", - "Average time taken per step: 678 milliseconds\n" + "Average time taken per step: 252 milliseconds\n" ] } ], "source": [ "# Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", - "#restart_jupyter_notebook()\n", + "restart_jupyter_notebook()\n", "\n", "\n", "# Import necessary packages and methods\n", @@ -668,65 +614,19 @@ "id": "8f2b752e", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", - "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaModel is torch.float32. 
You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported 
in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:387: UserWarning: `log_with=wandb` was passed but no supported trackers are currently installed.\n", - " warnings.warn(f\"`log_with={log_with}` was passed but no supported trackers are currently installed.\")\n", - "Repo card metadata block was not found. Setting CardData to empty.\n", - "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ "10 finetuning steps complete!\n", - "Average time taken per step: 498 milliseconds\n" + "Average time taken per step: 226 milliseconds\n" ] } ], "source": [ "# Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", - "#restart_jupyter_notebook()\n", + "restart_jupyter_notebook()\n", "\n", "\n", "# Import necessary packages and methods\n", @@ -736,12 +636,12 @@ "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/llama/weights\"\n", "hyperparams.mixed_precision = \"fp8\"\n", "\n", "\n", "# Init the model and accelerator wrapper\n", - "model = init_te_llama_model(hyperparams).cuda()\n", + "model = init_te_llama_model(hyperparams)\n", "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", "\n", "\n", @@ -764,134 +664,6 @@ "After turning on FP8 precision, we get even more speedup of almost **40%**!" ] }, - { - "cell_type": "markdown", - "id": "4933825e", - "metadata": {}, - "source": [ - "# [Improvement 3] Using AdamW with fused=True. " - ] - }, - { - "cell_type": "markdown", - "id": "cad9e4a7", - "metadata": {}, - "source": [ - "We can obtain even bigger speedup, when running the optimizer in the speedup mode. The change in code is simple - we change the line:\n", - "\n", - "```\n", - "optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate)\n", - "```\n", - "into \n", - "\n", - "```\n", - "optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=True)\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "7f5d3f79", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", - "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in LlamaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. 
Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use 
_jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:387: UserWarning: `log_with=wandb` was passed but no supported trackers are currently installed.\n", - " warnings.warn(f\"`log_with={log_with}` was passed but no supported trackers are currently installed.\")\n", - "Repo card metadata block was not found. Setting CardData to empty.\n", - "You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "10 finetuning steps complete!\n", - "Average time taken per step: 487 milliseconds\n" - ] - } - ], - "source": [ - "# Restart the notebook (to flush the GPU memory)\n", - "from utils import restart_jupyter_notebook\n", - "#restart_jupyter_notebook()\n", - "\n", - "\n", - "# Import necessary packages and methods\n", - "from utils import *\n", - "\n", - "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/llama/weights\"\n", - "hyperparams.mixed_precision = \"fp8\"\n", - "\n", - "\n", - "# Init the model and accelerator wrapper\n", - "model = init_te_llama_model(hyperparams).cuda()\n", - "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams, fused_optizer=True)\n", - "\n", - "\n", - "# Finetune the model\n", - "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" - ] - }, - { - "cell_type": "markdown", - "id": "73ed7b79", - "metadata": {}, - "source": [ - "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", - "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 315 | 1 |\n", - "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 252 | 1.25 |\n", - "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 226 | 1.39 |\n", - "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`, use fused optimizer) | FP8 | ? | 1.49 |\n", - "\n", - "\n", - "Using option fused=True in the optimizer resulred in **1.49** speedup!" - ] - }, { "cell_type": "markdown", "id": "41b80b0f", From bd6aa42e6417d9d13736982e0d42f5a3d0b194a9 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 22 Mar 2024 21:40:17 +0000 Subject: [PATCH 065/244] Removed fused=true Signed-off-by: Pawel Gadzinski Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_llama/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py index a43e6fa079..2b80415113 100755 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -102,7 +102,7 @@ def init_te_llama_model(hyperparams): return model -def wrap_with_accelerator(model, hyperparams, fused_optizer=False): +def wrap_with_accelerator(model, hyperparams): # Create FP8 kwarg handler if required fp8_kwarg_handler = [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None @@ -117,7 +117,7 @@ def wrap_with_accelerator(model, hyperparams, fused_optizer=False): train_dataloader = get_dataloaders(accelerator, hyperparams) # Wrap model, optimizer/scheduler, dataloaders in accelerate - optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=fused_optizer) + optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate) lr_scheduler = get_linear_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=100, From 91dd83ee1228894c6031c53eb5c2c1d8d62bad0f Mon Sep 17 00:00:00 2001 From: root Date: Fri, 22 Mar 2024 21:44:16 +0000 Subject: [PATCH 066/244] Batch size back to 8 Signed-off-by: Pawel Gadzinski Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_llama/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py index 2b80415113..9c36e5bd17 100755 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -23,7 +23,7 @@ def __init__(self): self.dataset_name = "timdettmers/openassistant-guanaco" self.dataset_text_field = "text" self.learning_rate = 1.41e-5 - self.batch_size = 16 + self.batch_size = 8 self.max_seq_length = 256 self.gradient_accumulation_steps = 1 self.num_warmup_steps=5 From 7edce8e2917755e6318da9b8dcd5b4e7b1c3d396 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 22 Mar 2024 23:32:05 
+0000 Subject: [PATCH 067/244] Typo and commented out line Signed-off-by: Pawel Gadzinski Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_llama/te_llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index 1a1d0ca8a6..690fd9f707 100755 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -123,8 +123,8 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file) - replaces_params = replace_params(state_dict, vanilla_model.state_dict()) - #_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") + replace_params(state_dict, vanilla_model.state_dict()) + _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") # Force mem release. Taken from huggingface code del state_dict From ef9db44b66c2954533739b273b69af70e58efcee Mon Sep 17 00:00:00 2001 From: root Date: Wed, 27 Mar 2024 00:49:51 +0000 Subject: [PATCH 068/244] fixed whitespace Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_llama/te_llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index 690fd9f707..d6ad6dffbd 100755 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -166,4 +166,5 @@ def replace_params(hf_state_dict, te_state_dict, config): if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:] - return all_layer_prefixes \ No newline at end of file + + return all_layer_prefixes From ccb7f2619031e62bf9c3a4de34262718b5c09467 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 27 Mar 2024 00:52:23 +0000 Subject: [PATCH 069/244] fixed whitespace Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_llama/te_llama.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index d6ad6dffbd..24f9610ec0 100755 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -140,7 +140,7 @@ def replace_params(hf_state_dict, te_state_dict, config): m = re.match(layer_prefix_pat, param_key) if m is not None: all_layer_prefixes.add(m.group()) - + for layer_prefix in all_layer_prefixes: # When loading weights into models with less number of layers, skip the # copy if the corresponding layer doesn't exist in HF model @@ -166,5 +166,4 @@ def replace_params(hf_state_dict, te_state_dict, config): if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:] - - return all_layer_prefixes + return all_layer_prefixes \ No newline at end of file From 187d7fc89ebe6afd41dd48e06498e50eb9a7e4c1 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 27 Mar 2024 18:17:36 +0000 Subject: [PATCH 070/244] Added comment to attention line. Fixed potential bug with loading weights - now loading works correctly, confirmed by the generation code. 
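The gist of this weight-loading fix, as a standalone sketch: TE's `LayerNormMLP` keeps `fc1_weight` as the row-wise concatenation `[gate_proj; up_proj]`, and the two HF tensors can land in different checkpoint shards, so each half has to be copied independently whenever its source tensor appears. The helper name and the hard-coded LLaMA-7B split point (11008) below are illustrative; later commits in this series read the split point from `config.intermediate_size`.

```python
# Minimal sketch of the two-slice copy this patch introduces; assumes the TE
# LayerNormMLP fc1 weight is laid out as [gate_proj; up_proj] along dim 0.
import torch

INTERMEDIATE_SIZE = 11008  # LLaMA-7B; later patches take this from config


def copy_fc1_halves(hf_shard_state_dict, te_fc1_weight, layer_prefix):
    # gate_proj and up_proj may arrive in different shard files, so copy each
    # half only when its source tensor is present in the current shard.
    gate_key = layer_prefix + "mlp.gate_proj.weight"
    up_key = layer_prefix + "mlp.up_proj.weight"
    if gate_key in hf_shard_state_dict:
        te_fc1_weight.data[:INTERMEDIATE_SIZE].copy_(hf_shard_state_dict[gate_key])
    if up_key in hf_shard_state_dict:
        te_fc1_weight.data[INTERMEDIATE_SIZE:].copy_(hf_shard_state_dict[up_key])
```

Copying shard by shard like this also keeps memory bounded, which is why the surrounding loop frees each shard's state dict before moving on.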
Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_llama/te_llama.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index 24f9610ec0..e405f6a937 100755 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -95,7 +95,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k Custom method adapted from `from_pretrained` method in HuggingFace Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 """ - vanilla_model = cls(config).to(kwargs['torch_dtype']) + vanilla_model = cls(config) is_local = os.path.isdir(pretrained_model_name_or_path) subfolder = "" variant = None @@ -140,6 +140,8 @@ def replace_params(hf_state_dict, te_state_dict, config): m = re.match(layer_prefix_pat, param_key) if m is not None: all_layer_prefixes.add(m.group()) + + GATE_PROJ_SIZE = 11008 for layer_prefix in all_layer_prefixes: # When loading weights into models with less number of layers, skip the @@ -161,8 +163,14 @@ def replace_params(hf_state_dict, te_state_dict, config): if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict: te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:] - if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict and 'mlp.up_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:] = torch.cat((hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data[:], hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data[:]), dim=0) + + if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:GATE_PROJ_SIZE] = \ + hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data + + if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict: + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[GATE_PROJ_SIZE:] = \ + hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:] From 59eaf7cf576bb86bfdc2a1d7e23c30809f75e878 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 27 Mar 2024 19:01:19 +0000 Subject: [PATCH 071/244] Comments Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_llama/te_llama.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index e405f6a937..2e6dfe4855 100755 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -123,7 +123,9 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file) + # replace_params copies parameters relevant only to TransformerEngine replace_params(state_dict, vanilla_model.state_dict()) + # _load_state_dict_into_model copies parameters other than those in TransformerEngine _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") # Force mem release. 
Taken from huggingface code @@ -164,6 +166,8 @@ def replace_params(hf_state_dict, te_state_dict, config): if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict: te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:] + # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to + # load them separately. if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:GATE_PROJ_SIZE] = \ hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data From 72e5017cca31d93c50afff4be418e571b768ee58 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 27 Mar 2024 20:05:24 +0000 Subject: [PATCH 072/244] Models cast added again Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_llama/te_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index 2e6dfe4855..cf03fc4c9e 100755 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -95,7 +95,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k Custom method adapted from `from_pretrained` method in HuggingFace Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 """ - vanilla_model = cls(config) + vanilla_model = cls(config).to(kwargs['torch_dtype']) is_local = os.path.isdir(pretrained_model_name_or_path) subfolder = "" variant = None From 12edbcff25b747088829ac92cd83b66c13d5f4e2 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 27 Mar 2024 20:12:20 +0000 Subject: [PATCH 073/244] Weight download info Signed-off-by: Pawel Gadzinski --- .../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) mode change 100644 => 100755 docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb old mode 100644 new mode 100755 index 178922c9d2..cc77b484f9 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -231,7 +231,8 @@ "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", + "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then coverted to the HuggingFace format.\n", + "## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n", "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", @@ -556,7 +557,8 @@ "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! 
`model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", + "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then coverted to the HuggingFace format.\n", + "## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n", "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", @@ -635,7 +637,8 @@ "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", + "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then coverted to the HuggingFace format.\n", + "## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n", "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.mixed_precision = \"fp8\"\n", "\n", From 3e77434cb30dcdf4a35720f47240bcd735853fe5 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 27 Mar 2024 22:15:15 +0000 Subject: [PATCH 074/244] Moved parameter gate_proj_size to config Signed-off-by: Pawel Gadzinski --- docs/examples/te_llama/te_llama.py | 8 ++++---- docs/examples/te_llama/utils.py | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index cf03fc4c9e..c6d29a39c7 100755 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -124,7 +124,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file) # replace_params copies parameters relevant only to TransformerEngine - replace_params(state_dict, vanilla_model.state_dict()) + replace_params(state_dict, vanilla_model.state_dict(), config) # _load_state_dict_into_model copies parameters other than those in TransformerEngine _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") @@ -143,7 +143,7 @@ def replace_params(hf_state_dict, te_state_dict, config): if m is not None: all_layer_prefixes.add(m.group()) - GATE_PROJ_SIZE = 11008 + for layer_prefix in all_layer_prefixes: # When loading weights into models with less number of layers, skip the @@ -169,11 +169,11 @@ def replace_params(hf_state_dict, te_state_dict, config): # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to # load them separately. 
if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:GATE_PROJ_SIZE] = \ + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.gate_proj_size] = \ hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[GATE_PROJ_SIZE:] = \ + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.gate_proj_size:] = \ hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict: diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py index 9c36e5bd17..28664e09c3 100755 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -77,6 +77,7 @@ def init_baseline_model(hyperparams): config = AutoConfig.from_pretrained(hyperparams.model_name) # make sure to use flash_attention to do iso comparison with TELlamaModel config._attn_implementation = "flash_attention_2" + config.gate_proj_size = 11008 model = AutoModelForCausalLM.from_pretrained( hyperparams.model_name, config=config, @@ -92,6 +93,7 @@ def init_te_llama_model(hyperparams): from te_llama import TELlamaForCausalLM config = AutoConfig.from_pretrained(hyperparams.model_name) config._attn_implementation = "flash_attention_2" + config.gate_proj_size = 11008 model = TELlamaForCausalLM.from_pretrained_local( hyperparams.model_name, config=config, From 42235da3bc37f65e7e3014823ac0c782e0866e25 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 28 Mar 2024 00:19:55 +0000 Subject: [PATCH 075/244] gate_proj_size removed and put immediate_size instead Signed-off-by: Pawel Gadzinski --- docs/examples/te_llama/te_llama.py | 4 ++-- docs/examples/te_llama/utils.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index c6d29a39c7..aa23b638f0 100755 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -169,11 +169,11 @@ def replace_params(hf_state_dict, te_state_dict, config): # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to # load them separately. 
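After every shard has been processed, the fused weight can be sanity-checked against the concatenation it is meant to hold; a minimal illustrative check (the helper is not part of the patch and assumes complete TE and HF state dicts with matching dtype and device):

```python
# Illustrative check: the fused TE fc1 weight should equal the row-wise
# concatenation [gate_proj; up_proj] once all shards have been copied in.
import torch


def check_fc1_split(te_state_dict, hf_state_dict, layer_prefix, intermediate_size):
    fused = te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data
    gate = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"]
    up = hf_state_dict[layer_prefix + "mlp.up_proj.weight"]
    assert fused.shape[0] == 2 * intermediate_size
    torch.testing.assert_close(fused[:intermediate_size], gate)
    torch.testing.assert_close(fused[intermediate_size:], up)
```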
if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.gate_proj_size] = \ + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.intermediate_size] = \ hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.gate_proj_size:] = \ + te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.intermediate_size:] = \ hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict: diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py index 28664e09c3..9c36e5bd17 100755 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -77,7 +77,6 @@ def init_baseline_model(hyperparams): config = AutoConfig.from_pretrained(hyperparams.model_name) # make sure to use flash_attention to do iso comparison with TELlamaModel config._attn_implementation = "flash_attention_2" - config.gate_proj_size = 11008 model = AutoModelForCausalLM.from_pretrained( hyperparams.model_name, config=config, @@ -93,7 +92,6 @@ def init_te_llama_model(hyperparams): from te_llama import TELlamaForCausalLM config = AutoConfig.from_pretrained(hyperparams.model_name) config._attn_implementation = "flash_attention_2" - config.gate_proj_size = 11008 model = TELlamaForCausalLM.from_pretrained_local( hyperparams.model_name, config=config, From 18ff64583ba049f790dccc020b4839af5a4384b9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 15 Apr 2024 23:41:55 +0000 Subject: [PATCH 076/244] add THD support for arbitrary_seqlen backend Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Pawel Gadzinski --- tests/pytorch/fused_attn/test_fused_attn.py | 93 ++- .../common/fused_attn/fused_attn.cpp | 46 +- .../fused_attn_f16_arbitrary_seqlen.cu | 287 +++++-- .../fused_attn_f16_arbitrary_seqlen.h | 24 +- .../include/transformer_engine/fused_attn.h | 36 + transformer_engine/pytorch/attention.py | 745 ++++-------------- .../pytorch/cpp_extensions/fused_attn.py | 82 +- transformer_engine/pytorch/csrc/extensions.h | 18 + .../pytorch/csrc/extensions/attention.cu | 168 ++++ 9 files changed, 824 insertions(+), 675 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index caba385d46..a93fe75b16 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -194,13 +194,17 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool: return False return True - -def _is_unfused_attention_supported(config: ModelConfig) -> bool: +def _is_unfused_attention_supported( + config: ModelConfig, + qkv_format: str, + ) -> bool: """Check if UnfusedDotProductAttention supports a model configuration""" if ("padding" in config.attn_mask_type): return False if ("causal" in config.attn_mask_type and config.attn_type == 'cross'): return False + if qkv_format == 'thd': + return False return True @@ -258,7 +262,8 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace ) # Skip if only unfused backend is supported - unfused_attn_supported = _is_unfused_attention_supported(config) + qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) + unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format) 
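For reference, the `qkv_format` string computed a few lines above is simply the letters of the layout's first underscore-separated group; a small illustration of the expression (the wrapper function exists only for demonstration):

```python
# Demonstration of the qkv_format expression used in the test above: keep only
# the alphabetic characters of the first underscore-separated layout group.
def qkv_format_from_layout(qkv_layout: str) -> str:
    return ''.join(i for i in qkv_layout.split('_')[0] if i.isalpha())


assert qkv_format_from_layout('sb3hd') == 'sbhd'
assert qkv_format_from_layout('bshd_bs2hd') == 'bshd'
assert qkv_format_from_layout('t3hd') == 'thd'
assert qkv_format_from_layout('thd_thd_thd') == 'thd'
```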
if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512: os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" fused_attn_supported, fused_attn_backend = _is_fused_attention_supported( @@ -269,6 +274,8 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace flash_attn_supported = _is_flash_attention_supported(config) if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2: pytest.skip("Less than two backends to compare.") + if (qkv_format == 'thd' and 'padding' not in config.attn_mask_type): + pytest.skip("THD layout requires padding/padding_causal mask type.") # UnfusedDotProductAttention backend if unfused_attn_supported: @@ -318,8 +325,16 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace if fused_attn_supported and flash_attn_supported: if _NVTE_DEBUG: print("[test_dot_product_attention]: fused attn vs flash attn") + print("fused_attn_fwd min {:.8f} max {:.8f}".format( + fused_attn_fwd.min().item(), fused_attn_fwd.max().item())) + print("flash_attn_fwd min {:.8f} max {:.8f}".format( + flash_attn_fwd.min().item(), flash_attn_fwd.max().item())) torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) for i,_ in enumerate(flash_attn_bwd): + print("fused_attn_bwd[{}] min {:.8f} max {:.8f}".format(i, + fused_attn_bwd[i].min().item(), fused_attn_bwd[i].max().item())) + print("flash_attn_bwd[{}] min {:.8f} max {:.8f}".format(i, + flash_attn_bwd[i].min().item(), flash_attn_bwd[i].max().item())) torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols) if fused_attn_supported and len(fused_attn_backend) == 2: if _NVTE_DEBUG: @@ -493,6 +508,41 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with different QKV layouts""" test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False) +qkv_layouts_thd = [ + 't3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd', + ] + +model_configs_layout_thd = { + # test: b, h, hg, d, sq, skv, p, mask, bias + #"layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), #all 5 pass + #"layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), #th3d/thd_t2hd + #"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), #all 5 pass + "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), #th3d/thd_t2hd + #"layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), #all 5 pass + #"layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), #th3d/t3hd/thd_t2hd + #"layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), #all 5 pass + #"layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), #th3d/t3hd/thd_t2hd + #"layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"), #all 5 fail + #"layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), #all 5 pass + #"layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), #all 5 fail + #"layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"), #all 5 skipped + +# Note: all failed tests were due to mismatches (30-50%) except for layout_2_1 tests which were exec errors: +#E RuntimeError: /code/fmha/github3/pr-thd/TransformerEngine/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:633 in function operator(): cuDNN Error: 
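The logging knobs named in that error message can simply be exported before launching the tests; an illustrative Python-side equivalent (they must be in the environment before cuDNN is first initialized, so set them before importing torch, or export them in the shell instead):

```python
# Illustrative: enable the cuDNN error logging suggested by the message above.
import os

os.environ["CUDNN_LOGERR_DBG"] = "1"
os.environ["CUDNN_LOGDEST_DBG"] = "stderr"

import torch  # cuDNN reads these settings when it initializes
```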
CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: Check and Set CUDNN_ATTR_OPERATIONGRAPH_HANDLE cudnn_status: CUDNN_STATUS_BAD_PARAM. For more information, enable cuDNN error logging by setting CUDNN_LOGERR_DBG=1 and CUDNN_LOGDEST_DBG=stderr in the environment. +#E! CuDNN (v8907) function cudnnCreate() called: +#e! Error: CUDNN_STATUS_INTERNAL_ERROR; Reason: cudaStreamCreateWithFlags(&ctx->streamPool[0][i], 0x01) != cudaSuccess +#e! Time: 2024-03-21T03:36:55.887897 (0d+0h+0m+0s since start) +#e! Process=8573; Thread=8678; GPU=NULL; Handle=NULL; StreamId=NULL. +} + +@pytest.mark.skipif(_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.") +@pytest.mark.parametrize("dtype", param_types_lean) +@pytest.mark.parametrize("model_configs", [model_configs_layout_thd]) +@pytest.mark.parametrize("model", model_configs_layout_thd.keys()) +@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd) +def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout): + """Test DotProductAttention module with different QKV layouts""" + test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False) def _run_dot_product_attention( dtype: torch.dtype, @@ -536,6 +586,10 @@ def _run_dot_product_attention( cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0) + #print('seqlens_q',seqlens_q) + #print('seqlens_kv',seqlens_kv) + #print('cu_seqlens_q',cu_seqlens_q) + #print('cu_seqlens_kv',cu_seqlens_kv) # Create attention mask if padding attention_mask = None @@ -616,6 +670,34 @@ def _run_dot_product_attention( for i in range(3): inp[i].requires_grad = True + # Create ragged offsets for q/k/v + seq_offsets_q, seq_offsets_k, seq_offsets_v = None, None, None + qkv_group = ''.join([x for x in qkv_layout if x not in 'bst']) + if qkv_format == 'thd': + if qkv_group == 'hd_hd_hd': + seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q + seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv + seq_offsets_v = config.num_gqa_groups * config.head_dim * cu_seqlens_kv + if qkv_group == '3hd': + seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q + seq_offsets_k = config.num_heads * config.head_dim * 2 * cu_seqlens_q + seq_offsets_v = config.num_heads * config.head_dim * 3 * cu_seqlens_q + if qkv_group == 'h3d': + seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q + seq_offsets_k = config.num_heads * config.head_dim * 2 * cu_seqlens_q + seq_offsets_v = config.num_heads * config.head_dim * 3 * cu_seqlens_q + if qkv_group == 'hd_2hd': + seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q + seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv + seq_offsets_v = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv + if qkv_group == 'hd_h2d': + seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q + seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv + seq_offsets_v = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv + #print('seq_offsets_q',seq_offsets_q) + #print('seq_offsets_k',seq_offsets_k) + #print('seq_offsets_v',seq_offsets_v) + # Create output gradient qkv_format_kv = '_'.join(qkv_format) qkv_format_kv = qkv_format_kv.replace('s', 'sq') @@ -666,6 +748,9 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: qkv_format=qkv_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, + seq_offsets_q=seq_offsets_q, + 
seq_offsets_k=seq_offsets_k, + seq_offsets_v=seq_offsets_v, attn_mask_type=config.attn_mask_type, checkpoint_core_attention=ckpt_attn, core_attention_bias_type=config.attn_bias_type, @@ -715,7 +800,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f qkv_layout="sbh3d" if fused_qkv_params else "sb3hd", ) flash_attn_supported = _is_flash_attention_supported(config) - unfused_attn_supported = _is_unfused_attention_supported(config) + unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format) if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2: pytest.skip("Less than two backends to compare.") diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 2d9759898f..82bc8375e4 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -160,6 +160,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && ((qkv_format == NVTE_QKV_Format::NVTE_SBHD) + || (qkv_format == NVTE_QKV_Format::NVTE_THD) || (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) { flag_arb = true; } @@ -208,6 +209,9 @@ void nvte_fused_attn_fwd_qkvpacked( NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, const NVTETensor rng_state, size_t max_seqlen, bool is_training, float attn_scale, float dropout, @@ -219,6 +223,9 @@ void nvte_fused_attn_fwd_qkvpacked( using namespace transformer_engine; const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); + const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); + const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); + const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); const Tensor *input_rng_state = reinterpret_cast(rng_state); const Tensor *input_QKV = reinterpret_cast(QKV); const Tensor *input_Bias = reinterpret_cast(Bias); @@ -269,6 +276,7 @@ void nvte_fused_attn_fwd_qkvpacked( input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_rng_state, wkspace, stream, handle); #else @@ -303,6 +311,9 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor dQKV, NVTETensor dBias, const NVTETensor cu_seqlens, + const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -313,6 +324,9 @@ void nvte_fused_attn_bwd_qkvpacked( using namespace transformer_engine; const Tensor *input_cu_seqlens = reinterpret_cast(cu_seqlens); + const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); + const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); + const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); const Tensor *input_QKV = reinterpret_cast(QKV); const Tensor *input_O = reinterpret_cast(O); const Tensor *input_dO = reinterpret_cast(dO); @@ -374,7 +388,9 @@ void nvte_fused_attn_bwd_qkvpacked( input_QKV, input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, - input_cu_seqlens, input_rng_state, + input_cu_seqlens, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, + input_rng_state, wkspace, stream, handle); #else const 
char *err_msg = @@ -414,6 +430,9 @@ void nvte_fused_attn_fwd_kvpacked( NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, @@ -425,6 +444,9 @@ void nvte_fused_attn_fwd_kvpacked( using namespace transformer_engine; const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); + const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); + const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); const Tensor *input_rng_state = reinterpret_cast(rng_state); const Tensor *input_Q = reinterpret_cast(Q); const Tensor *input_KV = reinterpret_cast(KV); @@ -479,6 +501,7 @@ void nvte_fused_attn_fwd_kvpacked( input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_rng_state, wkspace, stream, handle); #else @@ -516,6 +539,9 @@ void nvte_fused_attn_bwd_kvpacked( NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -526,6 +552,9 @@ void nvte_fused_attn_bwd_kvpacked( using namespace transformer_engine; const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); + const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); + const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); const Tensor *input_Q = reinterpret_cast(Q); const Tensor *input_KV = reinterpret_cast(KV); const Tensor *input_O = reinterpret_cast(O); @@ -593,6 +622,7 @@ void nvte_fused_attn_bwd_kvpacked( output_S, output_dQ, output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_rng_state, wkspace, stream, handle); #else const char *err_msg = @@ -633,6 +663,9 @@ void nvte_fused_attn_fwd( NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, @@ -644,6 +677,9 @@ void nvte_fused_attn_fwd( using namespace transformer_engine; const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); + const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); + const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); const Tensor *input_rng_state = reinterpret_cast(rng_state); const Tensor *input_Q = reinterpret_cast(Q); const Tensor *input_K = reinterpret_cast(K); @@ -690,6 +726,7 @@ void nvte_fused_attn_fwd( input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, 
input_cu_seqlens_q, input_cu_seqlens_kv, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_rng_state, wkspace, stream, handle); #else @@ -729,6 +766,9 @@ void nvte_fused_attn_bwd( NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -739,6 +779,9 @@ void nvte_fused_attn_bwd( using namespace transformer_engine; const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); + const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); + const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); + const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); const Tensor *input_Q = reinterpret_cast(Q); const Tensor *input_K = reinterpret_cast(K); const Tensor *input_V = reinterpret_cast(V); @@ -799,6 +842,7 @@ void nvte_fused_attn_bwd( output_S, output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_rng_state, wkspace, stream, handle); #else const char *err_msg = diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 180759f327..c40dd327ad 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -57,9 +57,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void *devPtrSoftmaxStats, void *devPtrO, void* devPtrDropoutSeed, void* devPtrDropoutOffset, void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, + void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsK, void* devPtrSeqOffsetsV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) @@ -67,6 +69,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); + bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); + if (is_ragged) { + NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + } try { FADescriptor_v1 descriptor{b, h, @@ -89,6 +95,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // bias std::shared_ptr, // seq_q std::shared_ptr, // seq_kv + std::shared_ptr, // offset_q + std::shared_ptr, // offset_k + std::shared_ptr, // offset_v std::shared_ptr, // dropout_seed std::shared_ptr >; // dropout_offset @@ -113,8 +122,25 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr Q, K, V, attn_scale; std::shared_ptr bias, seq_q, seq_kv; + std::shared_ptr offset_q, offset_k, offset_v; std::shared_ptr dropout_seed, dropout_offset; + offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_q") + .set_dim({b+1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + 
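+    // Each ragged-offset graph tensor (offset_q/k/v) holds b + 1 cumulative INT32
+    // element offsets, i.e. where every sequence starts inside the packed THD buffer;
+    // they are attached to Q/K/V via set_ragged_offset() and bound to the
+    // devPtrSeqOffsetsQ/K/V pointers in the variant pack only when is_ragged is true.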
offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_k") + .set_dim({b+1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_v") + .set_dim({b+1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -124,18 +150,37 @@ void fused_attn_arbitrary_seqlen_fwd_impl( layout, NVTE_QKV_Matrix::NVTE_K_Matrix); generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); - K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); - V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); + + if (is_ragged) { + Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d}) + .set_stride(q_stride) + .set_ragged_offset(offset_q)); + K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride) + .set_ragged_offset(offset_k)); + V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride) + .set_ragged_offset(offset_v)); + } else { + Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d}) + .set_stride(q_stride)); + K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride)); + V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride)); + } attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") @@ -197,7 +242,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); + if (is_ragged) { + O->set_output(true) + .set_dim({b, h, s_q, d}) + .set_stride(o_stride) + .set_ragged_offset(offset_q); + } else { + O->set_output(true) + .set_dim({b, h, s_q, d}) + .set_stride(o_stride); + } Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT) .set_dim({b, h, s_q, 1}) @@ -213,11 +267,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); + auto offset_tuple = is_ragged ? + std::make_tuple(offset_q, offset_k, offset_v) : + std::make_tuple(nullptr, nullptr, nullptr); auto dropout_tuple = is_dropout ? 
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); - auto return_empty_tuple = std::tuple_cat( - std::make_tuple(nullptr), key_tensors_tuple, - Stats_tuple, bias_tuple, padding_tuple, dropout_tuple); NVTE_CHECK_CUDNN_FE(mha_graph->validate()); NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); @@ -227,18 +281,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto return_tuple = std::tuple_cat( std::make_tuple(mha_graph), key_tensors_tuple, - Stats_tuple, bias_tuple, padding_tuple, dropout_tuple); + Stats_tuple, bias_tuple, padding_tuple, offset_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, Q, K, V, attn_scale, O, Stats, - bias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph( + bias, seq_q, seq_kv, offset_q, offset_k, offset_v, + dropout_seed, dropout_offset] = get_graph( sdpa_f16_fprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); - // Exit to request upper level API to allocate memory if needed size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); if (workspace == nullptr) { @@ -277,6 +331,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl( variant_pack[seq_kv] = devActualSeqlenKV; } + if (is_ragged) { + variant_pack[offset_q] = devPtrSeqOffsetsQ; + variant_pack[offset_k] = devPtrSeqOffsetsK; + variant_pack[offset_v] = devPtrSeqOffsetsV; + } + if (is_dropout) { variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; @@ -298,8 +358,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, void* devPtrdBias, void* devPtrDropoutSeed, void* devPtrDropoutOffset, void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, + void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsK, void* devPtrSeqOffsetsV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) @@ -307,6 +369,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (dropout_probability != 0.0f); + bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); + if (is_ragged) { + NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + } try { FADescriptor_v1 descriptor{b, h, @@ -334,6 +400,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr, // dBias std::shared_ptr, // seq_q std::shared_ptr, // seq_kv + std::shared_ptr, // offset_q + std::shared_ptr, // offset_k + std::shared_ptr, // offset_v std::shared_ptr, // dropout_seed std::shared_ptr >; // dropout_offset @@ -358,8 +427,24 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr q, k, v, o, dO, stats, attn_scale; std::shared_ptr bias, dBias, seq_q, seq_kv; + std::shared_ptr offset_q, offset_k, offset_v; std::shared_ptr dropout_seed, dropout_offset; + offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_q") + .set_dim({b+1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_k") + .set_dim({b+1, 1, 1, 1}) + 
.set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_v") + .set_dim({b+1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -372,26 +457,55 @@ void fused_attn_arbitrary_seqlen_bwd_impl( layout, NVTE_QKV_Matrix::NVTE_V_Matrix); generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride(q_stride)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("O") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); - dO = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO") - .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); + + if (is_ragged) { + q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d}) + .set_stride(q_stride) + .set_ragged_offset(offset_q)); + k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride) + .set_ragged_offset(offset_k)); + v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride) + .set_ragged_offset(offset_v)); + o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim({b, h, s_q, d}) + .set_stride(o_stride) + .set_ragged_offset(offset_q)); + dO = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO") + .set_dim({b, h, s_q, d}) + .set_stride(o_stride) + .set_ragged_offset(offset_q)); + } else { + q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d}) + .set_stride(q_stride)); + k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride)); + v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride)); + o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim({b, h, s_q, d}) + .set_stride(o_stride)); + dO = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO") + .set_dim({b, h, s_q, d}) + .set_stride(o_stride)); + } stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("stats") .set_dim({b, h, s_q, 1}) @@ -465,15 +579,30 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto [dQ, dK, dV] = mha_graph->sdpa_backward( q, k, v, o, dO, stats, sdpa_backward_options); - dQ->set_output(true) - .set_dim({b, h, s_q, d}) - .set_stride(q_stride); - dK->set_output(true) - .set_dim({b, hg, s_kv, d}) - .set_stride(k_stride); - dV->set_output(true) - .set_dim({b, hg, s_kv, d}) - .set_stride(v_stride); + if (is_ragged) { + dQ->set_output(true) + .set_dim({b, h, s_q, d}) + .set_stride(q_stride) + .set_ragged_offset(offset_q); + dK->set_output(true) + .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride) + .set_ragged_offset(offset_k); + dV->set_output(true) + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride) + .set_ragged_offset(offset_v); + } else { + dQ->set_output(true) + .set_dim({b, h, s_q, d}) + .set_stride(q_stride); + dK->set_output(true) 
+ .set_dim({b, hg, s_kv, d}) + .set_stride(k_stride); + dV->set_output(true) + .set_dim({b, hg, s_kv, d}) + .set_stride(v_stride); + } std::tuple, // q std::shared_ptr, // k @@ -490,11 +619,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); + auto offset_tuple = is_ragged ? + std::make_tuple(offset_q, offset_k, offset_v) : + std::make_tuple(nullptr, nullptr, nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); - auto return_empty_tuple = std::tuple_cat( - std::make_tuple(nullptr), key_tensors_tuple, - bias_tuple, padding_tuple, dropout_tuple); NVTE_CHECK_CUDNN_FE(mha_graph->validate()); NVTE_CHECK_CUDNN_FE(mha_graph->build_operation_graph(handle)); @@ -504,14 +633,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto return_tuple = std::tuple_cat( std::make_tuple(mha_graph), key_tensors_tuple, - bias_tuple, padding_tuple, dropout_tuple); + bias_tuple, padding_tuple, offset_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, - bias, dBias, seq_q, seq_kv, dropout_seed, dropout_offset] = get_graph( + bias, dBias, seq_q, seq_kv, offset_q, offset_k, offset_v, + dropout_seed, dropout_offset] = get_graph( sdpa_f16_bprop_cache, descriptor); auto plan_workspace_size = mha_graph->get_workspace_size(); @@ -564,6 +694,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl( variant_pack[seq_kv] = devActualSeqlenKV; } + if (is_ragged) { + variant_pack[offset_q] = devPtrSeqOffsetsQ; + variant_pack[offset_k] = devPtrSeqOffsetsK; + variant_pack[offset_v] = devPtrSeqOffsetsV; + } + if (is_dropout) { variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; @@ -581,8 +717,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, + const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_QKV->data.dtype; @@ -609,6 +746,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( void *devPtrO = output_O->data.dptr; void *devPtrS = nullptr; void *devPtrCuSeqlens = cu_seqlens->data.dptr; + void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; + void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; + void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; if (Aux_CTX_Tensors->size == 0) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -665,6 +805,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -690,9 
+831,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, - Tensor *output_dQKV, Tensor *output_dBias, - const Tensor *cu_seqlens, const Tensor *rng_state, + Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, + const Tensor *cu_seqlens, const Tensor *seq_offsets_q, + const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -732,6 +874,9 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea devPtrSoftmaxStats = output_S->data.dptr; void *devPtrCuSeqlens = cu_seqlens->data.dptr; + void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; + void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; + void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutOffset = reinterpret_cast( @@ -747,6 +892,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -771,6 +917,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -800,6 +947,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; + void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; + void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; + void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; if (Aux_CTX_Tensors->size == 0) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -856,6 +1006,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -885,7 +1036,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, + const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -926,6 +1078,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; + void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; + void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; + void *devPtrSeqOffsetsV = 
seq_offsets_v->data.dptr; void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutOffset = reinterpret_cast( @@ -941,6 +1096,7 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -966,6 +1122,7 @@ void fused_attn_arbitrary_seqlen_fwd( NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -987,6 +1144,9 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; + void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; + void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; + void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; if (Aux_CTX_Tensors->size == 0) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -1043,6 +1203,7 @@ void fused_attn_arbitrary_seqlen_fwd( devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1072,11 +1233,11 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, + const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; - const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; void *devPtrK = input_K->data.dptr; @@ -1102,6 +1263,9 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; + void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; + void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; + void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutOffset = reinterpret_cast( @@ -1116,6 +1280,7 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index a8866908ce..baedf8ca74 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -24,8 +24,10 @@ 
void fused_attn_arbitrary_seqlen_fwd_qkvpacked( NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + const Tensor *cu_seqlens, const Tensor *seq_offsets_q, + const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, + const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, @@ -35,8 +37,10 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, - const Tensor *cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + const Tensor *cu_seqlens, const Tensor *seq_offsets_q, + const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, + const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, @@ -47,7 +51,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, + const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, + const Tensor *seq_offsets_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( @@ -59,7 +64,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, + const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, + const Tensor *seq_offsets_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_fwd( @@ -72,7 +78,8 @@ void fused_attn_arbitrary_seqlen_fwd( const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, + const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, + const Tensor *seq_offsets_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( @@ -86,7 +93,8 @@ void fused_attn_arbitrary_seqlen_bwd( Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, + const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, + const Tensor *seq_offsets_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8900 diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index c13a841067..48cebed28a 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -177,6 +177,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * 
\param[out] Aux_CTX_Tensors Auxiliary output tensors when training, * e.g. M, ZInv, rng_state. * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. + * \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. + * \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. + * \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. * \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] max_seqlen Max sequence length used for computing, * it may be >= max(seqlen_i) for i=0,...batch_size-1. @@ -196,6 +199,9 @@ void nvte_fused_attn_fwd_qkvpacked( NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, const NVTETensor rng_state, size_t max_seqlen, bool is_training, float attn_scale, float dropout, @@ -224,6 +230,9 @@ void nvte_fused_attn_fwd_qkvpacked( * \param[out] dQKV The gradient of the QKV tensor. * \param[out] dBias The gradient of the Bias tensor. * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. + * \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. + * \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. + * \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. * \param[in] max_seqlen Max sequence length used for computing, * it may be >= max(seqlen_i) for i=0,...batch_size-1. * \param[in] attn_scale Scaling factor for Q * K.T. @@ -244,6 +253,9 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor dQKV, NVTETensor dBias, const NVTETensor cu_seqlens, + const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -275,6 +287,9 @@ void nvte_fused_attn_bwd_qkvpacked( * e.g. M, ZInv, rng_state. * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. + * \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. + * \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. + * \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. * \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] max_seqlen_q Max sequence length used for computing for Q. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. @@ -298,6 +313,9 @@ void nvte_fused_attn_fwd_kvpacked( NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, @@ -328,6 +346,9 @@ void nvte_fused_attn_fwd_kvpacked( * \param[out] dBias The gradient of the Bias tensor. * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. + * \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. + * \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. + * \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. 
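For concreteness, these seq_offsets_* arguments hold per-sequence element offsets rather than token counts. A minimal PyTorch sketch of how they can be derived from cu_seqlens for a separate-tensor 'thd_thd_thd' layout (toy sizes, no GQA; variable names follow the tests above):

    import torch

    # Packed THD batch with three sequences of 2, 5 and 3 tokens.
    seqlens = torch.tensor([2, 5, 3], dtype=torch.int32)
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)        # [0, 2, 7, 10]

    num_heads, head_dim = 16, 64
    # Element offset of each sequence inside the packed Q/K/V buffers:
    # cu_seqlens scaled by the number of elements stored per token.
    seq_offsets_q = num_heads * head_dim * cu_seqlens    # [0, 2048, 7168, 10240]
    seq_offsets_k = num_heads * head_dim * cu_seqlens    # same scaling when Q and KV share head counts
    seq_offsets_v = num_heads * head_dim * cu_seqlens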
* \param[in] max_seqlen_q Max sequence length used for computing for Q. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. * \param[in] max_seqlen_kv Max sequence length used for computing for KV. @@ -353,6 +374,9 @@ void nvte_fused_attn_bwd_kvpacked( NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -388,6 +412,9 @@ void nvte_fused_attn_bwd_kvpacked( * e.g. M, ZInv, rng_state. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. + * \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. + * \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. + * \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. * \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] max_seqlen_q Max sequence length used for computing for Q. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. @@ -412,6 +439,9 @@ void nvte_fused_attn_fwd( NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, @@ -447,6 +477,9 @@ void nvte_fused_attn_fwd( * \param[out] dBias The gradient of the Bias tensor. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. + * \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. + * \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. + * \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. * \param[in] max_seqlen_q Max sequence length used for computing for Q. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. * \param[in] max_seqlen_kv Max sequence length used for computing for K and V. @@ -474,6 +507,9 @@ void nvte_fused_attn_bwd( NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor seq_offsets_q, + const NVTETensor seq_offsets_k, + const NVTETensor seq_offsets_v, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 31c32a9f93..6d20cc8b29 100755 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1617,8 +1617,6 @@ def forward( assert (qkv_layout in QKVLayouts ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!" 
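        # e.g. qkv_layout='thd_thd_thd' -> qkv_format='thd', 'sbh3d' -> 'sbhd':
        # digits are stripped from the first layout group to recover the base format.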
qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) - assert (qkv_format != 'thd' - ), """UnfusedDotProductAttention does not support variable sequence lengths!""" if qkv_format == 'bshd': # convert to sbhd and use sbhd implementation for now query_layer, key_layer, value_layer = [x.transpose(0, 1) @@ -2001,7 +1999,7 @@ def forward( else: query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous() for x in (query_layer, key_layer, value_layer)] - elif qkv_format == 'bshd': + elif qkv_format in ['bshd', 'thd']: query_layer, key_layer, value_layer = [x.contiguous() for x in (query_layer, key_layer, value_layer)] @@ -2064,14 +2062,11 @@ def forward( ) elif qkv_format == 'thd': assert not context_parallel, "thd format not supported with context parallelism!" - assert (cu_seqlens_q is not None and cu_seqlens_kv is not None - ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" - if max_seqlen_q is None: - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - max_seqlen_q = seqlens_q.max().item() - if max_seqlen_kv is None: - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - max_seqlen_kv = seqlens_kv.max().item() + assert (max_seqlen_q is not None + and max_seqlen_kv is not None + and cu_seqlens_q is not None + and cu_seqlens_kv is not None + ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" if context_parallel: assert ( @@ -2116,7 +2111,7 @@ def forward( **fa_optional_forward_kwargs, ) - if 'padding' in attn_mask_type: + if qkv_format in ['sbhd', 'bshd'] and 'padding' in attn_mask_type: output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) if qkv_format == 'sbhd': @@ -2165,83 +2160,20 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): """Function for FusedAttention with packed QKV input""" @staticmethod - def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale, + def forward(ctx, is_training, max_seqlen, cu_seqlens, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + qkv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen, fused_attention_backend, use_FAv2_bwd, - fp8, fp8_meta): - if fp8: - if _NVTE_DEBUG: - print('[DotProductAttention]: using FP8 forward') - if fp8_meta["recipe"].fp8_mha: - assert (isinstance(qkv, Float8Tensor)), "qkv must be Float8Tensors for FP8 MHA." - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv - fused_attention_backend = FusedAttnBackend["FP8"] - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split('_')) - assert (qkv_group == 1 - ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, \ - but found {qkv_layout}." 
- if fp8_meta["recipe"].fp8_mha: - qkv_fp8 = qkv._data - else: - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = cast_to_fp8(qkv_c, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(qkv.shape) - out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked( - is_training, max_seqlen, cu_seqlens, - qkv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], - attn_scale, dropout_p, fast_zero_fill, qkv_layout, - attn_bias_type, attn_mask_type, rng_gen) - if fp8_meta["recipe"].fp8_mha: - out_ret = Float8Tensor(data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=qkv.dtype, - ) - else: - out_ret = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], META_O, - fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) - out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv = cast_from_fp8(qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], META_O, - fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) - fp8_tensors = (qkv_fp8, out_fp8, - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone()) - else: - if _NVTE_DEBUG: - print('[DotProductAttention]: using non-FP8 forward') - out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked( - is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, - fused_attention_backend, attn_bias, - None, None, None, None, None, None, - attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen) - fp8_tensors = (None, None, None, None) - out_save = out_ret - - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) - ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors) - ctx.fp8_meta = fp8_meta + rng_gen, fused_attention_backend, use_FAv2_bwd): + out, aux_ctx_tensors = fused_attn_fwd_qkvpacked( + is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, + fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + attn_bias, None, None, None, None, None, + attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, + rng_gen) + + ctx.save_for_backward(qkv, out, cu_seqlens, seq_offsets_q, seq_offsets_k, seq_offsets_v) ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen = max_seqlen ctx.qkv_dtype = qkv_dtype @@ -2266,8 +2198,7 @@ def backward(ctx, d_out): d_out = d_out._data d_out = d_out.contiguous() - (qkv, out, cu_seqlens, - qkv_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors + qkv, out, cu_seqlens, seq_offsets_q, seq_offsets_k, seq_offsets_v = ctx.saved_tensors if not ctx.aux_ctx_tensors[0].is_contiguous(): ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: @@ -2284,73 +2215,22 @@ def backward(ctx, d_out): ) dqkv = dqkv[..., :d_out.shape[-1]] else: - with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"): - if ctx.fp8: - if _NVTE_DEBUG: - 
print('[DotProductAttention]: using FP8 backward') - fp8_dtype_forward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False) - if ctx.fp8_meta["recipe"].fp8_mha: - d_out_fp8 = d_out - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv - else: - d_out_fp8 = cast_to_fp8( - d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ).view(d_out.shape) - dqkv_fp8, *rest = fused_attn_bwd_qkvpacked( - ctx.max_seqlen, cu_seqlens, - qkv_fp8, out_fp8, d_out_fp8, - fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors, - ctx.fused_attention_backend, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp - ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) - if ctx.fp8_meta["recipe"].fp8_mha: - dqkv = Float8Tensor(data=dqkv_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - else: - dqkv_c_fp8 = dqkv_fp8.view(-1, - dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]) - dqkv = cast_from_fp8(dqkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape) - else: - if _NVTE_DEBUG: - print('[DotProductAttention]: using non-FP8 backward') - if d_out.dtype == torch.uint8: - d_out = d_out_f8tensor.from_float8(qkv.dtype) - dqkv, *rest = fused_attn_bwd_qkvpacked( - ctx.max_seqlen, cu_seqlens, qkv, out, d_out, - ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors, - ctx.fused_attention_backend, - None, None, None, None, None, None, None, None, None, None, - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + dqkv, *rest = fused_attn_bwd_qkvpacked( + ctx.max_seqlen, cu_seqlens, qkv, out, d_out, + ctx.qkv_dtype, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + None, None, None, None, None, None, None, None, None, + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: - return (None, None, None, dqkv, None, None, None, + return (None, None, None, None, None, None, dqkv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None) # else, return (dqkv, dbias) - return (None, None, None, dqkv, None, rest[0], None, + return (None, None, None, None, None, None, dqkv, None, rest[0], None, None, None, None, None, None, None, None, None, None, None, None, None) @@ -2360,89 +2240,20 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): @staticmethod def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + seq_offsets_q, seq_offsets_k, seq_offsets_v, q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, - qkv_layout, attn_bias_type, 
attn_mask_type, rng_gen, fused_attention_backend, - use_FAv2_bwd, fp8, fp8_meta): - if fp8: - if _NVTE_DEBUG: - print('[DotProductAttention]: using FP8 forward') - if fp8_meta["recipe"].fp8_mha: - assert (isinstance(q, Float8Tensor) - and isinstance(kv, Float8Tensor)), "q/kv must be Float8Tensors for FP8 MHA." - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv - fused_attention_backend = FusedAttnBackend["FP8"] - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if fp8_meta["recipe"].fp8_mha: - q_fp8, kv_fp8 = q._data, kv._data - else: - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split('_')) - assert (qkv_group == 2 - ), f"qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, \ - but found {qkv_layout}." - q_fp8 = cast_to_fp8(q, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(q.shape) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = cast_to_fp8(kv_c, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(kv.shape) - out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked( - is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q_fp8, kv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], - attn_scale, dropout_p, fast_zero_fill, qkv_layout, - attn_bias_type, attn_mask_type, rng_gen) - if fp8_meta["recipe"].fp8_mha: - out_ret = Float8Tensor(data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=q.dtype, - ) - else: - out_ret = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], META_O, - fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) - out_save = out_ret - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = cast_from_fp8(q._data, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv = cast_from_fp8(kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], META_O, - fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) - fp8_tensors = (q_fp8, kv_fp8, out_fp8, - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone()) - else: - if _NVTE_DEBUG: - print('[DotProductAttention]: using non-FP8 forward') - out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked( - is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, kv, qkv_dtype, fused_attention_backend, attn_bias, - None, None, None, None, None, None, - attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen) - out_save = out_ret - fp8_tensors = (None, None, None, None, None) - - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) - ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors) - ctx.fp8_meta = fp8_meta + qkv_layout, attn_bias_type, attn_mask_type, + rng_gen, fused_attention_backend, use_FAv2_bwd): + out, aux_ctx_tensors = 
fused_attn_fwd_kvpacked( + is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, kv, qkv_dtype, fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + attn_bias, None, None, None, None, None, + attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, + rng_gen) + + ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, + seq_offsets_q, seq_offsets_k, seq_offsets_v) ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv @@ -2469,7 +2280,7 @@ def backward(ctx, d_out): d_out = d_out.contiguous() (q, kv, out, cu_seqlens_q, cu_seqlens_kv, - q_fp8, kv_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors + seq_offsets_q, seq_offsets_k, seq_offsets_v) = ctx.saved_tensors if not ctx.aux_ctx_tensors[0].is_contiguous(): ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: @@ -2488,85 +2299,23 @@ def backward(ctx, d_out): dq = dq[..., :d_out.shape[-1]] dkv = dkv[..., :d_out.shape[-1]] else: - with torch.cuda.nvtx.range("_FusedAttn_kvpacked"): - if ctx.fp8: - if _NVTE_DEBUG: - print('[DotProductAttention]: using FP8 backward') - fp8_dtype_forward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False) - if ctx.fp8_meta["recipe"].fp8_mha: - d_out_fp8 = d_out - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv - else: - d_out_fp8 = cast_to_fp8( - d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ).view(d_out.shape) - dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked( - ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q_fp8, kv_fp8, out_fp8, d_out_fp8, - fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors, - ctx.fused_attention_backend, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp - ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) - if ctx.fp8_meta["recipe"].fp8_mha: - dq = Float8Tensor(data=dq_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - dkv = Float8Tensor(data=dkv_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - else: - dq = cast_from_fp8( - dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape) - dkv_c_fp8 = dkv_fp8.view(-1, - dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]) - dkv = cast_from_fp8(dkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape) - else: - if _NVTE_DEBUG: - print('[DotProductAttention]: using non-FP8 backward') - if d_out.dtype == torch.uint8: - d_out = d_out_f8tensor.from_float8(q.dtype) - dq, dkv, *rest = 
fused_attn_bwd_kvpacked( - ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, kv, out, d_out, - ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors, - ctx.fused_attention_backend, - None, None, None, None, None, None, None, None, None, None, - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + dq, dkv, *rest = fused_attn_bwd_kvpacked( + ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, kv, out, d_out, + ctx.qkv_dtype, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + None, None, None, None, None, None, None, None, None, + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: - return (None, None, None, None, None, dq, dkv, None, None, None, + return (None, None, None, None, None, None, None, None, dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None) # else, return (dqkv, dbias) - return (None, None, None, None, None, dq, dkv, None, rest[0], None, + return (None, None, None, None, None, None, None, None, dq, dkv, None, rest[0], None, None, None, None, None, None, None, None, None, None, None, None, None) @@ -2575,133 +2324,17 @@ class FusedAttnFunc(torch.autograd.Function): @staticmethod def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + seq_offsets_q, seq_offsets_k, seq_offsets_v, q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, - qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, - use_FAv2_bwd, fp8, fp8_meta): - if fp8: - if _NVTE_DEBUG: - print('[DotProductAttention]: using FP8 forward') - fused_attention_backend = FusedAttnBackend["FP8"] - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if fp8_meta["recipe"].fp8_mha: - assert (isinstance(q, Float8Tensor) - and isinstance(k, Float8Tensor) - and isinstance(v, Float8Tensor)), "q/k/v must be Float8Tensors for FP8 MHA." 
- fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv - q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data - else: - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split('_')) - if qkv_group == 1: - dim = qkv_layout.find('3') - qkv = _combine_tensors([q,k,v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = cast_to_fp8(qkv_c, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(qkv.shape) - q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1,1,1]) - q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]] - if qkv_group == 2: - q_fp8 = cast_to_fp8(q, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(q.shape) - dim = qkv_layout.split('_')[1].find('2') - kv = _combine_tensors([k,v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = cast_to_fp8(kv_c, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(kv.shape) - k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1,1]) - k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]] - if qkv_group == 3: - q_fp8 = cast_to_fp8(q, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(q.shape) - k_fp8 = cast_to_fp8(k, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(k.shape) - v_fp8 = cast_to_fp8(v, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward).view(v.shape) - out_fp8, aux_ctx_tensors = fused_attn_fwd( - is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q_fp8, k_fp8, v_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, - fp8_meta["scaling_fwd"].scale_inv[META_QKV], - fp8_meta["scaling_fwd"].scale_inv[META_S], - fp8_meta["scaling_fwd"].scale[META_S], - fp8_meta["scaling_fwd"].scale[META_O], - fp8_meta["scaling_fwd"].amax_history[0][META_S], - fp8_meta["scaling_fwd"].amax_history[0][META_O], - attn_scale, dropout_p, fast_zero_fill, qkv_layout, - attn_bias_type, attn_mask_type, rng_gen) - if fp8_meta["recipe"].fp8_mha: - out_ret = Float8Tensor(data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=q.dtype, - ) - else: - out_ret = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], META_O, - fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) - out_save = out_ret - - if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split('_')) - if qkv_group == 1: - dim = qkv_layout.find('3') - qkv = _combine_tensors([q,k,v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_no_fp8 = cast_from_fp8(qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape) - q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1,1,1]) - q, k, v = [x.squeeze(dim) for x in [q, k, v]] - if qkv_group == 2: - q = cast_from_fp8(q._data, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape) - dim = qkv_layout.split('_')[1].find('2') - kv = _combine_tensors([k,v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_no_fp8 = cast_from_fp8(kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape) - k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1,1]) - k, v = [x.squeeze(dim) for x in [k, v]] - if qkv_group == 3: - q = cast_from_fp8(q._data, - fp8_meta["scaling_fwd"], - META_QKV, 
fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape) - k = cast_from_fp8(k._data, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[k.dtype]).view(k.shape) - v = cast_from_fp8(v._data, - fp8_meta["scaling_fwd"], - META_QKV, fp8_dtype_forward, TE_DType[v.dtype]).view(v.shape) - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], META_O, - fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) - - fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8, - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone()) - else: - if _NVTE_DEBUG: - print('[DotProductAttention]: using non-FP8 forward') - out_ret, aux_ctx_tensors = fused_attn_fwd( - is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, k, v, qkv_dtype, fused_attention_backend, attn_bias, - None, None, None, None, None, None, - attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen) - out_save = out_ret - fp8_tensors = (None, None, None, None, None, None) + qkv_layout, attn_bias_type, attn_mask_type, + rng_gen, fused_attention_backend, use_FAv2_bwd): + out, aux_ctx_tensors = fused_attn_fwd( + is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, k, v, qkv_dtype, fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + attn_bias, None, None, None, None, None, + attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, + rng_gen) from .cpu_offload import CPUOffloadEnabled if CPUOffloadEnabled: @@ -2711,10 +2344,9 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql if tensor is not None: tensor.activation_offloading = True - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) - ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors) - ctx.fp8_meta = fp8_meta + + ctx.save_for_backward(q, k, v, out, + cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v) ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv @@ -2741,7 +2373,7 @@ def backward(ctx, d_out): d_out = d_out.contiguous() (q, k, v, out, cu_seqlens_q, cu_seqlens_kv, - q_fp8, k_fp8, v_fp8, out_fp8, fwd_scales, fwd_scale_invs) = ctx.saved_tensors + seq_offsets_q, seq_offsets_k, seq_offsets_v) = ctx.saved_tensors if not ctx.aux_ctx_tensors[0].is_contiguous(): ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: @@ -2762,120 +2394,23 @@ def backward(ctx, d_out): dk = dk[..., :d_out.shape[-1]] dv = dv[..., :d_out.shape[-1]] else: - with torch.cuda.nvtx.range("_FusedAttn"): - if ctx.fp8: - if _NVTE_DEBUG: - print('[DotProductAttention]: using FP8 backward') - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False) - if ctx.fp8_meta["recipe"].fp8_mha: - d_out_fp8 = d_out - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv - else: - d_out_fp8 = cast_to_fp8( - d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ).view(d_out.shape) - dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( - ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q_fp8, k_fp8, v_fp8, out_fp8, d_out_fp8, - fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors, - 
ctx.fused_attention_backend, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp - ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) - if ctx.fp8_meta["recipe"].fp8_mha: - dq = Float8Tensor(data=dq_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - dk = Float8Tensor(data=dk_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - dv = Float8Tensor(data=dv_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - else: - qkv_group = len(ctx.qkv_layout.split('_')) - if qkv_group == 1: - dim = ctx.qkv_layout.find('3') - dqkv_fp8 = _combine_tensors([dq_fp8,dk_fp8,dv_fp8], dim) - dqkv_c_fp8 = dqkv_fp8.view(-1, - dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]) - dqkv = cast_from_fp8(dqkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape) - dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1,1,1]) - dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]] - if qkv_group == 2: - dq = cast_from_fp8( - dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape) - dim = ctx.qkv_layout.split('_')[1].find('2') - dkv_fp8 = _combine_tensors([dk_fp8,dv_fp8], dim) - dkv_c_fp8 = dkv_fp8.view(-1, - dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]) - dkv = cast_from_fp8(dkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape) - dk, dv = _SplitAlongDim.apply(dkv, dim, [1,1]) - dk, dv = [x.squeeze(dim) for x in [dk, dv]] - if qkv_group == 3: - dq = cast_from_fp8( - dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape) - dk = cast_from_fp8( - dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dk_fp8.shape) - dv = cast_from_fp8( - dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], META_DQKV, - fp8_dtype_backward, ctx.qkv_dtype).view(dv_fp8.shape) - else: - if _NVTE_DEBUG: - print('[DotProductAttention]: using non-FP8 backward') - if d_out.dtype == torch.uint8: - d_out = d_out_f8tensor.from_float8(q.dtype) - dq, dk, dv, *rest = fused_attn_bwd( - ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, k, v, out, d_out, - ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors, - ctx.fused_attention_backend, - None, None, None, None, None, None, None, None, None, None, - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + dq, dk, dv, *rest = fused_attn_bwd( + ctx.max_seqlen_q, ctx.max_seqlen_kv, 
cu_seqlens_q, cu_seqlens_kv, + q, k, v, out, d_out, + ctx.qkv_dtype, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + None, None, None, None, None, None, None, None, None, + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: - return (None, None, None, None, None, dq, dk, dv, None, None, None, + return (None, None, None, None, None, None, None, None, dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None) # else, return (dqkv, dbias) - return (None, None, None, None, None, dq, dk, dv, None, rest[0], None, + return (None, None, None, None, None, None, None, None, dq, dk, dv, None, rest[0], None, None, None, None, None, None, None, None, None, None, None, None, None) @@ -2968,6 +2503,9 @@ def forward( qkv_layout: str = "sbh3d", cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, + seq_offsets_q: Optional[torch.Tensor] = None, + seq_offsets_k: Optional[torch.Tensor] = None, + seq_offsets_v: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_kv: Optional[int] = None, attn_mask_type: str = "causal", @@ -2983,7 +2521,6 @@ def forward( is_first_microbatch: Optional[bool] = None, ) -> torch.Tensor: """fused attention fprop""" - assert (fused_attention_backend != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend ), 'No fused attention backend supports this input combination!' @@ -3002,9 +2539,6 @@ def forward( context_parallel = (cp_group is not None) and (get_distributed_world_size(cp_group) != 1) qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()]) - assert ( - qkv_format != 'thd' - ), 'FusedAttention does not support qkv_format = thd!' if qkv_format in ['sbhd', 'bshd']: if qkv_format == 'sbhd': @@ -3040,6 +2574,30 @@ def forward( max_seqlen_kv, key_layer.device, ) + if qkv_format == 'thd': + assert not context_parallel, "thd format not supported with context parallelism!" + assert (max_seqlen_q is not None + and max_seqlen_kv is not None + and cu_seqlens_q is not None + and cu_seqlens_kv is not None + ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" 
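For reference, the cu_seqlens tensors that the thd path above requires can be built from per-sequence token counts with a running sum. A minimal sketch (the lengths are illustrative, not part of this change):

    import torch

    # Per-sequence token counts of a packed ("thd") batch -- example values only.
    seq_lens_q = torch.tensor([7, 3, 5])
    cu_seqlens_q = torch.zeros(seq_lens_q.numel() + 1, dtype=torch.int32)
    cu_seqlens_q[1:] = torch.cumsum(seq_lens_q, dim=0)
    # cu_seqlens_q == tensor([0, 7, 10, 15], dtype=torch.int32); the packed
    # query tensor then holds 15 tokens with no padding between sequences.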
+ if (seq_offsets_q is None or seq_offsets_k is None or seq_offsets_v is None): + qkv_group = ''.join([x for x in qkv_layout if x not in 'bst']) + num_heads = query_layer.shape[-2] + num_gqa_groups = key_layer.shape[-2] + head_dim = query_layer.shape[-1] + if qkv_group == 'hd_hd_hd': + seq_offsets_q = num_heads * head_dim * cu_seqlens_q + seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv + seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv + if qkv_group in ['3hd', 'h3d']: + seq_offsets_q = num_heads * head_dim * cu_seqlens_q + seq_offsets_k = num_heads * head_dim * 2 * cu_seqlens_q + seq_offsets_v = num_heads * head_dim * 3 * cu_seqlens_q + if qkv_group in ['hd_2hd', 'hd_h2d']: + seq_offsets_q = num_heads * head_dim * cu_seqlens_q + seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv + seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv qkv_dtype = TE_DType[query_layer.dtype] @@ -3073,41 +2631,25 @@ def forward( use_fused_attention=True, ) else: - with self.prepare_forward(query_layer, - is_first_microbatch, - num_gemms=3, - allow_non_contiguous=True) as query_layer: - with self.attention_dropout_ctx(): - forced_fp8_dpa = "" - if self.fp8_meta["recipe"].fp8_mha: - if not self.fp8_meta["recipe"].fp8_dpa: - self.fp8_meta["recipe"].fp8_dpa = True - forced_fp8_dpa = " (forced)" - if _NVTE_DEBUG: - print("[DotProductAttention]: " - f"""using fp8_recipe.fp8_mha={self.fp8_meta["recipe"].fp8_mha}, """ - f"""fp8_recipe.fp8_dpa={self.fp8_meta["recipe"].fp8_dpa}""" - f"""{forced_fp8_dpa} and """ - f"""NVTE_FP8_DPA_BWD={int(os.getenv("NVTE_FP8_DPA_BWD", "1"))}""") - output = FusedAttnFunc.apply( - self.training, - max_seqlen_q, max_seqlen_kv, - cu_seqlens_q, cu_seqlens_kv, - query_layer, key_layer, value_layer, - qkv_dtype, - core_attention_bias, - 1.0/self.norm_factor, - self.attention_dropout if self.training else 0.0, - fast_zero_fill, - qkv_layout, - core_attention_bias_type, - attn_mask_type, - None, # rng_gen - fused_attention_backend, - use_FAv2_bwd, - self.fp8 and self.fp8_meta["recipe"].fp8_dpa, - self.fp8_meta, - ) + with self.attention_dropout_ctx(): + output = FusedAttnFunc.apply( + self.training, + max_seqlen_q, max_seqlen_kv, + cu_seqlens_q, cu_seqlens_kv, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + query_layer, key_layer, value_layer, + qkv_dtype, + core_attention_bias, + 1.0/self.norm_factor, + self.attention_dropout if self.training else 0.0, + fast_zero_fill, + qkv_layout, + core_attention_bias_type, + attn_mask_type, + None, # rng_gen + fused_attention_backend, + use_FAv2_bwd, + ) # ...hd -> ...(hd) return output.view(*output.shape[:-2], -1) @@ -3367,6 +2909,9 @@ def forward( qkv_format: Optional[str] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, + seq_offsets_q: Optional[torch.Tensor] = None, + seq_offsets_k: Optional[torch.Tensor] = None, + seq_offsets_v: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_kv: Optional[int] = None, attn_mask_type: Optional[str] = None, @@ -3444,6 +2989,15 @@ def forward( cu_seqlens_kv: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + seqlen_offsets_q: Optional[torch.Tensor], default = `None` + Cumulative offset of different sequences in a batch for `query_layer`, + with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts. 
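When the offsets are not supplied, the derivation added above computes them from cu_seqlens; they are element offsets (not token offsets) into the flattened q/k/v storage, scaled by the number of heads and the head size. A standalone restatement of that arithmetic with made-up sizes:

    import torch

    cu_seqlens_q = torch.tensor([0, 7, 10, 15], dtype=torch.int32)
    cu_seqlens_kv = cu_seqlens_q.clone()
    num_heads, num_gqa_groups, head_dim = 16, 16, 64

    # 'hd_hd_hd' group (separate q/k/v, e.g. thd_thd_thd):
    seq_offsets_q = num_heads * head_dim * cu_seqlens_q
    seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
    seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv

    # '3hd'/'h3d' group (qkv packed, e.g. t3hd/th3d): k and v sit at 2x and 3x
    # the per-token stride of q inside the combined qkv tensor.
    seq_offsets_k_packed = num_heads * head_dim * 2 * cu_seqlens_q
    seq_offsets_v_packed = num_heads * head_dim * 3 * cu_seqlens_q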
+ seqlen_offsets_k: Optional[torch.Tensor], default = `None` + Cumulative offset of different sequences in a batch for `key_layer`, + with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts. + seqlen_offsets_v: Optional[torch.Tensor], default = `None` + Cumulative offset of different sequences in a batch for `value_layer`, + with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts. max_seqlen_q: Optional[int], default = `None` Maximum sequence length in `query_layer`. Calculated from `cu_seqlens_q` if not provided. @@ -3515,6 +3069,9 @@ def forward( assert (attn_mask_type in AttnMaskTypes ), f"Attention mask type {attn_mask_type} is not supported!" + if qkv_format == 'thd': + assert ('padding' in attn_mask_type + ), "Attention mask type must be padding or padding_causal for qkv_format=thd!" if self.rng_states_tracker is not None and is_graph_capturing(): assert ( @@ -3606,10 +3163,10 @@ def forward( ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" if max_seqlen_q is None: seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - max_seqlen_q = seqlens_q.max().item() + max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item()))) if max_seqlen_kv is None: seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - max_seqlen_kv = seqlens_kv.max().item() + max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item()))) if qkv_format in ['sbhd', 'bshd']: assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)) @@ -3647,6 +3204,10 @@ def forward( # The following section filters out some backends based on # certain asserts before executing the forward pass. + # Filter: QKV layout. + if qkv_format == 'thd': + use_unfused_attention = False + # Filter: ONNX export. if is_in_onnx_export_mode(): use_flash_attention = False @@ -3848,8 +3409,9 @@ def forward( qkv_layout=qkv_layout, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, + seq_offsets_q=seq_offsets_q, + seq_offsets_k=seq_offsets_k, + seq_offsets_v=seq_offsets_v, attn_mask_type=attn_mask_type, attention_mask=attention_mask, fused_attention_backend=fused_attention_backend, @@ -3867,8 +3429,9 @@ def forward( qkv_layout=qkv_layout, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, + seq_offsets_q=seq_offsets_q, + seq_offsets_k=seq_offsets_k, + seq_offsets_v=seq_offsets_v, attn_mask_type=attn_mask_type, attention_mask=attention_mask, fused_attention_backend=fused_attention_backend, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 74030ba809..1e0bc53fe1 100755 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -82,6 +82,9 @@ def fused_attn_fwd_qkvpacked( qkv: torch.Tensor, qkv_dtype: tex.DType, fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + seq_offsets_q: torch.Tensor = None, + seq_offsets_k: torch.Tensor = None, + seq_offsets_v: torch.Tensor = None, attn_bias: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, @@ -115,6 +118,12 @@ def fused_attn_fwd_qkvpacked( data type of QKV; in tex.DType, not torch.dtype fused_attention_backend: tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. 
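Putting together the DotProductAttention-level pieces from the hunks above (the thd asserts, the offset derivation, and the new default where a missing max_seqlen_q/kv is rounded up to the next power of two via pow(2, ceil(log2(...)))), a minimal usage sketch. The constructor arguments and the CUDA/cuDNN requirement are assumptions about the surrounding API, not something this series introduces:

    import torch
    import transformer_engine.pytorch as te

    num_heads, head_dim = 16, 64
    seq_lens = torch.tensor([7, 3, 5], device="cuda")
    total_tokens = int(seq_lens.sum())
    cu_seqlens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)

    # Packed "thd" tensors: [total_tokens, heads, head_dim], no padding between sequences.
    q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)

    dpa = te.DotProductAttention(num_heads, head_dim)   # assumed constructor signature
    out = dpa(q, k, v,
              qkv_format="thd",
              attn_mask_type="padding",                 # thd requires a padding mask type
              cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
              max_seqlen_q=8, max_seqlen_kv=8)          # next power of two >= 7

Leaving max_seqlen_q/kv unset would now produce the same rounded values from cu_seqlens.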
+ seq_offsets_q: torch.Tensor, default = None + cumulative sequence offsets for Q; shape [batch_size + 1] + seq_offsets_k: torch.Tensor, default = None + cumulative sequence offsets for K; shape [batch_size + 1] + seq_offsets_v: torch.Tensor, default = None + cumulative sequence offsets for V; shape [batch_size + 1] attn_bias: torch.Tensor, default = None input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv @@ -225,7 +234,8 @@ def fused_attn_fwd_qkvpacked( max_seqlen, is_training, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], cu_seqlens, qkv, qkv_dtype, - d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, rng_gen, rng_elts_per_thread, ) @@ -243,6 +253,9 @@ def fused_attn_bwd_qkvpacked( dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + seq_offsets_q: torch.Tensor = None, + seq_offsets_k: torch.Tensor = None, + seq_offsets_v: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, d_scale_o: torch.Tensor = None, @@ -286,6 +299,12 @@ def fused_attn_bwd_qkvpacked( e.g. aux_ctx_tensors = [M, ZInv, rng_state] fused_attention_backend: tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. + seq_offsets_q: torch.Tensor, default = None + cumulative sequence offsets for Q; shape [batch_size + 1] + seq_offsets_k: torch.Tensor, default = None + cumulative sequence offsets for K; shape [batch_size + 1] + seq_offsets_v: torch.Tensor, default = None + cumulative sequence offsets for V; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations d_scale_s: torch.Tensor, default = None @@ -360,8 +379,9 @@ def fused_attn_bwd_qkvpacked( output_tensors = tex.fused_attn_bwd_qkvpacked( max_seqlen, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens, qkv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, - d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, + cu_seqlens, qkv, o, d_o, qkv_dtype, aux_ctx_tensors, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) @@ -378,6 +398,9 @@ def fused_attn_fwd_kvpacked( kv: torch.Tensor, qkv_dtype: tex.DType, fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + seq_offsets_q: torch.Tensor = None, + seq_offsets_k: torch.Tensor = None, + seq_offsets_v: torch.Tensor = None, attn_bias: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, @@ -418,6 +441,12 @@ def fused_attn_fwd_kvpacked( data type of Q and KV; in tex.DType, not torch.dtype fused_attention_backend: tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. 
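The qkv_layout strings these wrappers accept for packed sequences follow the naming scheme the Python layer parses elsewhere in this change: strip the b/s/t dimensions to get the layout group, and count the '_'-separated parts to see how many tensors are passed. A quick, runnable decoding of the five thd layouts referenced by the tests later in this series:

    # Layout-name parsing, mirroring the expressions used in attention.py above.
    for qkv_layout in ("t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"):
        qkv_group = "".join(x for x in qkv_layout if x not in "bst")
        num_tensors = len(qkv_layout.split("_"))   # 1: qkv packed, 2: kv packed, 3: q/k/v separate
        print(f"{qkv_layout:12s} -> group {qkv_group:10s} ({num_tensors} tensor(s))")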
+ seq_offsets_q: torch.Tensor, default = None + cumulative sequence offsets for Q; shape [batch_size + 1] + seq_offsets_k: torch.Tensor, default = None + cumulative sequence offsets for K; shape [batch_size + 1] + seq_offsets_v: torch.Tensor, default = None + cumulative sequence offsets for V; shape [batch_size + 1] attn_bias: torch.Tensor, default = None input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv @@ -529,7 +558,8 @@ def fused_attn_fwd_kvpacked( max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype, - d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, rng_gen, rng_elts_per_thread, ) @@ -550,6 +580,9 @@ def fused_attn_bwd_kvpacked( dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + seq_offsets_q: torch.Tensor = None, + seq_offsets_k: torch.Tensor = None, + seq_offsets_v: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, d_scale_o: torch.Tensor = None, @@ -600,6 +633,12 @@ def fused_attn_bwd_kvpacked( e.g. aux_ctx_tensors = [M, ZInv, rng_state] fused_attention_backend: tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. + seq_offsets_q: torch.Tensor, default = None + cumulative sequence offsets for Q; shape [batch_size + 1] + seq_offsets_k: torch.Tensor, default = None + cumulative sequence offsets for K; shape [batch_size + 1] + seq_offsets_v: torch.Tensor, default = None + cumulative sequence offsets for V; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations d_scale_s: torch.Tensor, default = None @@ -678,8 +717,9 @@ def fused_attn_bwd_kvpacked( output_tensors = tex.fused_attn_bwd_kvpacked( max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, - d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, + cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, aux_ctx_tensors, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) @@ -697,6 +737,9 @@ def fused_attn_fwd( v: torch.Tensor, qkv_dtype: tex.DType, fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + seq_offsets_q: torch.Tensor = None, + seq_offsets_k: torch.Tensor = None, + seq_offsets_v: torch.Tensor = None, attn_bias: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, @@ -741,6 +784,12 @@ def fused_attn_fwd( data type of Q, K and V; in tex.DType, not torch.dtype fused_attention_backend: tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. 
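"Cumulative sequence offsets" in these docstrings means element offsets into the flattened tensor: for separate thd q/k/v, sequence i occupies q.reshape(-1)[seq_offsets_q[i]:seq_offsets_q[i+1]]. A small self-check with the same illustrative sizes as earlier:

    import torch

    num_heads, head_dim = 16, 64
    seq_lens = torch.tensor([7, 3, 5])
    cu_seqlens_q = torch.tensor([0, 7, 10, 15], dtype=torch.int32)
    seq_offsets_q = num_heads * head_dim * cu_seqlens_q

    q = torch.randn(int(seq_lens.sum()), num_heads, head_dim)
    i = 1                                            # second sequence in the packed batch
    start, stop = int(seq_offsets_q[i]), int(seq_offsets_q[i + 1])
    chunk = q.reshape(-1)[start:stop]
    assert chunk.numel() == int(seq_lens[i]) * num_heads * head_dim   # 3 * 16 * 64 elements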
+ seq_offsets_q: torch.Tensor, default = None + cumulative sequence offsets for Q; shape [batch_size + 1] + seq_offsets_k: torch.Tensor, default = None + cumulative sequence offsets for K; shape [batch_size + 1] + seq_offsets_v: torch.Tensor, default = None + cumulative sequence offsets for V; shape [batch_size + 1] attn_bias: torch.Tensor, default = None input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v @@ -840,8 +889,10 @@ def fused_attn_fwd( output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype, - d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, + cu_seqlens_q, cu_seqlens_kv, + q, k, v, qkv_dtype, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, rng_gen, rng_elts_per_thread, ) @@ -863,6 +914,9 @@ def fused_attn_bwd( dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + seq_offsets_q: torch.Tensor = None, + seq_offsets_k: torch.Tensor = None, + seq_offsets_v: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, d_scale_o: torch.Tensor = None, @@ -916,6 +970,12 @@ def fused_attn_bwd( e.g. aux_ctx_tensors = [M, ZInv, rng_state] fused_attention_backend: tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. + seq_offsets_q: torch.Tensor, default = None + cumulative sequence offsets for Q; shape [batch_size + 1] + seq_offsets_k: torch.Tensor, default = None + cumulative sequence offsets for K; shape [batch_size + 1] + seq_offsets_v: torch.Tensor, default = None + cumulative sequence offsets for V; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of Q, K and V in FP8 computations d_scale_s: torch.Tensor, default = None @@ -1001,8 +1061,10 @@ def fused_attn_bwd( output_tensors = tex.fused_attn_bwd( max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, - d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, + cu_seqlens_q, cu_seqlens_kv, + q, k, v, o, d_o, qkv_dtype, aux_ctx_tensors, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index abbecb1609..2f552fe28f 100755 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -31,6 +31,9 @@ std::vector fused_attn_fwd_qkvpacked( const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, + const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional scale_S, @@ -54,6 +57,9 @@ std::vector fused_attn_bwd_qkvpacked( const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + const c10::optional seq_offsets_q, + const c10::optional 
seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, @@ -76,6 +82,9 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, + const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional scale_S, @@ -101,6 +110,9 @@ std::vector fused_attn_bwd_kvpacked( const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, @@ -124,6 +136,9 @@ std::vector fused_attn_fwd( const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, + const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional scale_S, @@ -150,6 +165,9 @@ std::vector fused_attn_bwd( const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index cc747655c4..037ae72b2b 100755 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -96,6 +96,9 @@ std::vector fused_attn_fwd_qkvpacked( const at::Tensor cu_seqlens, const at::Tensor QKV, const transformer_engine::DType qkv_type, + const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional scale_S, @@ -123,6 +126,7 @@ std::vector fused_attn_fwd_qkvpacked( // construct NVTE tensors TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens; + TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 auto h = q_shape[q_shape.size() - 2]; @@ -169,6 +173,24 @@ std::vector fused_attn_fwd_qkvpacked( te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape, DType::kInt32, nullptr, nullptr, nullptr); + if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value())) { + auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); + std::vector seq_offsets_q_shape{ + seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; + auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); + std::vector seq_offsets_k_shape{ + seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; + auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); + std::vector seq_offsets_v_shape{ + seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), + seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), + 
seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), + seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); + } + // extract random number generator seed and offset auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); @@ -193,6 +215,9 @@ std::vector fused_attn_fwd_qkvpacked( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens.data(), + te_seq_offsets_q.data(), + te_seq_offsets_k.data(), + te_seq_offsets_v.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, @@ -241,6 +266,9 @@ std::vector fused_attn_fwd_qkvpacked( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens.data(), + te_seq_offsets_q.data(), + te_seq_offsets_k.data(), + te_seq_offsets_v.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, @@ -266,6 +294,9 @@ std::vector fused_attn_bwd_qkvpacked( const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, @@ -380,6 +411,25 @@ std::vector fused_attn_bwd_qkvpacked( TensorWrapper te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape, DType::kInt32, nullptr, nullptr, nullptr); + TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v; + if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value())) { + auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); + std::vector seq_offsets_q_shape{ + seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; + auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); + std::vector seq_offsets_k_shape{ + seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; + auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); + std::vector seq_offsets_v_shape{ + seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), + seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), + seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), + seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); + } + // create workspace TensorWrapper workspace; @@ -394,6 +444,9 @@ std::vector fused_attn_bwd_qkvpacked( te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), + te_seq_offsets_q.data(), + te_seq_offsets_k.data(), + te_seq_offsets_v.data(), max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, @@ -417,6 +470,9 @@ std::vector fused_attn_bwd_qkvpacked( te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), + te_seq_offsets_q.data(), + te_seq_offsets_k.data(), + te_seq_offsets_v.data(), max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, @@ -439,6 +495,9 @@ std::vector fused_attn_fwd_kvpacked( const at::Tensor Q, const at::Tensor KV, const transformer_engine::DType qkv_type, + const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional scale_S, @@ -462,6 +521,7 @@ std::vector 
fused_attn_fwd_kvpacked( // construct NVTE tensors TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; + TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 auto h = q_shape[q_shape.size() - 2]; @@ -516,6 +576,24 @@ std::vector fused_attn_fwd_kvpacked( te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32, nullptr, nullptr, nullptr); + if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value())) { + auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); + std::vector seq_offsets_q_shape{ + seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; + auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); + std::vector seq_offsets_k_shape{ + seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; + auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); + std::vector seq_offsets_v_shape{ + seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), + seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), + seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), + seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); + } + // extract rng seed and offset auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); @@ -542,6 +620,9 @@ std::vector fused_attn_fwd_kvpacked( &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_seq_offsets_q.data(), + te_seq_offsets_k.data(), + te_seq_offsets_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, @@ -592,6 +673,9 @@ std::vector fused_attn_fwd_kvpacked( &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_seq_offsets_q.data(), + te_seq_offsets_k.data(), + te_seq_offsets_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, @@ -620,6 +704,9 @@ std::vector fused_attn_bwd_kvpacked( const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, @@ -725,6 +812,25 @@ std::vector fused_attn_bwd_kvpacked( te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32, nullptr, nullptr, nullptr); + TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v; + if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value())) { + auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); + std::vector seq_offsets_q_shape{ + seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; + auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); + std::vector seq_offsets_k_shape{ + seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; + auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); + std::vector seq_offsets_v_shape{ + seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + te_seq_offsets_q = 
makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), + seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), + seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), + seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); + } + // convert auxiliary tensors from forward to NVTETensors NVTETensorPack nvte_aux_tensor_pack; nvte_tensor_pack_create(&nvte_aux_tensor_pack); @@ -771,6 +877,9 @@ std::vector fused_attn_bwd_kvpacked( te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_seq_offsets_q.data(), + te_seq_offsets_k.data(), + te_seq_offsets_v.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, @@ -797,6 +906,9 @@ std::vector fused_attn_bwd_kvpacked( te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_seq_offsets_q.data(), + te_seq_offsets_k.data(), + te_seq_offsets_v.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, @@ -820,6 +932,9 @@ std::vector fused_attn_fwd( const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, + const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional scale_S, @@ -844,6 +959,7 @@ std::vector fused_attn_fwd( // construct NVTE tensors TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; + TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 auto h = q_shape[q_shape.size() - 2]; @@ -902,6 +1018,24 @@ std::vector fused_attn_fwd( te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32, nullptr, nullptr, nullptr); + if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value())) { + auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); + std::vector seq_offsets_q_shape{ + seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; + auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); + std::vector seq_offsets_k_shape{ + seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; + auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); + std::vector seq_offsets_v_shape{ + seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), + seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), + seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), + seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); + } + // extract rng seed and offset auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); @@ -930,6 +1064,9 @@ std::vector fused_attn_fwd( &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_seq_offsets_q.data(), + te_seq_offsets_k.data(), + te_seq_offsets_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, @@ -981,6 +1118,9 @@ std::vector fused_attn_fwd( 
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_seq_offsets_q.data(), + te_seq_offsets_k.data(), + te_seq_offsets_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, @@ -1010,6 +1150,9 @@ std::vector fused_attn_bwd( const transformer_engine::DType qkv_type, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, + const c10::optional seq_offsets_q, + const c10::optional seq_offsets_k, + const c10::optional seq_offsets_v, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, @@ -1183,6 +1326,25 @@ std::vector fused_attn_bwd( te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32, nullptr, nullptr, nullptr); + TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v; + if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value())) { + auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); + std::vector seq_offsets_q_shape{ + seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; + auto seq_offsets_k_sizes = seq_offsets_k.value().sizes().vec(); + std::vector seq_offsets_k_shape{ + seq_offsets_k_sizes.begin(), seq_offsets_k_sizes.end()}; + auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); + std::vector seq_offsets_v_shape{ + seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), + seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), + seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), + seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); + } + // convert auxiliary tensors from forward to NVTETensors NVTETensorPack nvte_aux_tensor_pack; nvte_tensor_pack_create(&nvte_aux_tensor_pack); @@ -1231,6 +1393,9 @@ std::vector fused_attn_bwd( te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_seq_offsets_q.data(), + te_seq_offsets_k.data(), + te_seq_offsets_v.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, @@ -1259,6 +1424,9 @@ std::vector fused_attn_bwd( te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_seq_offsets_q.data(), + te_seq_offsets_k.data(), + te_seq_offsets_v.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, From 906f74e86fbf94d5bb82d3acfbe61b5406561c4f Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 15 Apr 2024 23:47:58 +0000 Subject: [PATCH 077/244] update test results Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Signed-off-by: Pawel Gadzinski --- tests/pytorch/fused_attn/test_fused_attn.py | 29 ++++++++------------- 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index a93fe75b16..ef8db8be51 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -514,25 +514,18 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): model_configs_layout_thd = { # test: b, h, hg, d, sq, skv, p, mask, bias - #"layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), 
#all 5 pass - #"layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), #th3d/thd_t2hd - #"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), #all 5 pass + "layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), #all 5 pass + "layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), #th3d/thd_t2hd + "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), #all 5 pass "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), #th3d/thd_t2hd - #"layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), #all 5 pass - #"layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), #th3d/t3hd/thd_t2hd - #"layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), #all 5 pass - #"layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), #th3d/t3hd/thd_t2hd - #"layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"), #all 5 fail - #"layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), #all 5 pass - #"layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), #all 5 fail - #"layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"), #all 5 skipped - -# Note: all failed tests were due to mismatches (30-50%) except for layout_2_1 tests which were exec errors: -#E RuntimeError: /code/fmha/github3/pr-thd/TransformerEngine/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu:633 in function operator(): cuDNN Error: CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: Check and Set CUDNN_ATTR_OPERATIONGRAPH_HANDLE cudnn_status: CUDNN_STATUS_BAD_PARAM. For more information, enable cuDNN error logging by setting CUDNN_LOGERR_DBG=1 and CUDNN_LOGDEST_DBG=stderr in the environment. -#E! CuDNN (v8907) function cudnnCreate() called: -#e! Error: CUDNN_STATUS_INTERNAL_ERROR; Reason: cudaStreamCreateWithFlags(&ctx->streamPool[0][i], 0x01) != cudaSuccess -#e! Time: 2024-03-21T03:36:55.887897 (0d+0h+0m+0s since start) -#e! Process=8573; Thread=8678; GPU=NULL; Handle=NULL; StreamId=NULL. 
+ "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), #all 5 pass + "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), #th3d/t3hd/thd_t2hd/thd_th2d + "layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), #all 5 pass + "layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), #th3d/t3hd/thd_t2hd/thd_th2d + "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"), #all 5 pass + "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), #all 5 pass + "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), #all 5 pass + "layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"), #all 5 skipped } @pytest.mark.skipif(_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.") From bd8a7dcdc3d321e07e666ffe30952726852c8d71 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 24 Apr 2024 22:28:40 +0000 Subject: [PATCH 078/244] THD generation Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/generate.py | 52 ++++ docs/examples/te_gemma/generate_baseline.py | 52 ++++ docs/examples/te_gemma/generate_convert.py | 4 +- .../examples/te_gemma/generate_cuda_graphs.py | 63 ++++ docs/examples/te_gemma/generate_fp8.py | 14 +- docs/examples/te_gemma/utils.py | 5 +- transformer_engine/pytorch/attention.py | 291 +++++++++++++++--- transformer_engine/pytorch/module/base.py | 19 ++ transformer_engine/pytorch/transformer.py | 2 + 9 files changed, 442 insertions(+), 60 deletions(-) create mode 100755 docs/examples/te_gemma/generate.py create mode 100755 docs/examples/te_gemma/generate_baseline.py mode change 100644 => 100755 docs/examples/te_gemma/generate_convert.py create mode 100644 docs/examples/te_gemma/generate_cuda_graphs.py diff --git a/docs/examples/te_gemma/generate.py b/docs/examples/te_gemma/generate.py new file mode 100755 index 0000000000..422b005bd8 --- /dev/null +++ b/docs/examples/te_gemma/generate.py @@ -0,0 +1,52 @@ +# Restart the notebook (to flush the GPU memory) +from utils import restart_jupyter_notebook +#restart_jupyter_notebook() + + +# Import necessary packages and methods +from utils import * +import accelerate + +# Default hyperparams, also defined in `utils.py` in class `Hyperparameters` +## !!! `model_name` attr must point to the location of the model weights !!! +## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ +hyperparams.model_name = "../../../../gemma-weights" # <== Add model weight location here e.g. 
"/path/to/downloaded/llama/weights" +hyperparams.mixed_precision = "bf16" +hyperparams.fuse_qkv_params = False + +# Init the model and accelerator wrapper +model = init_te_gemma_model(hyperparams).cuda() +#accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams) + +model = model.to(torch.bfloat16).cuda() + +tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) +inputs = tokenizer(["I love when "] * 64, return_tensors="pt", padding=True) + +inputs['input_ids'] = inputs['input_ids'].cuda() +inputs['attention_mask'] = inputs['attention_mask'].cuda() + +import time + +# Początek pomiaru czasu +start_time = time.time() + +outputs = model.generate( + **inputs, + max_new_tokens=40 +) + +# Koniec pomiaru czasu +end_time = time.time() + +# Obliczamy czas trwania operacji +duration = end_time - start_time +print(f"Generation time: {duration} seconds") + + +# Decode the output tensor to text +generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + +# Display the generated text +for text in generated_texts: + print(text) \ No newline at end of file diff --git a/docs/examples/te_gemma/generate_baseline.py b/docs/examples/te_gemma/generate_baseline.py new file mode 100755 index 0000000000..3db56c958f --- /dev/null +++ b/docs/examples/te_gemma/generate_baseline.py @@ -0,0 +1,52 @@ +# Restart the notebook (to flush the GPU memory) +from utils import restart_jupyter_notebook +#restart_jupyter_notebook() + + +# Import necessary packages and methods +from utils import * +import torch + + +# Default hyperparams, also defined in `utils.py` in class `Hyperparameters` +## !!! `model_name` attr must point to the location of the model weights !!! +## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ +hyperparams.model_name = "../../../../gemma-weights" # <== Add model weight location here e.g. 
"/path/to/downloaded/llama/weights" +hyperparams.mixed_precision = "bf16" + + +# Init the model and accelerator wrapper +model = init_baseline_model(hyperparams).cuda() +model = model.to(torch.bfloat16) + +tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) +inputs = tokenizer(["I love when ", "I "] * 32, return_tensors="pt", padding=True) + +inputs['input_ids'] = inputs['input_ids'].cuda() +inputs['attention_mask'] = inputs['attention_mask'].cuda() + + +# Początek pomiaru czasu +start_time = time.time() + +outputs = model.generate( + **inputs, + max_new_tokens=10 +) + +# Koniec pomiaru czasu +end_time = time.time() + +# Obliczamy czas trwania operacji +duration = end_time - start_time + + + +print(outputs) + +# Decode the output tensor to text +generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + +# Display the generated text +for text in generated_texts: + print(text) \ No newline at end of file diff --git a/docs/examples/te_gemma/generate_convert.py b/docs/examples/te_gemma/generate_convert.py old mode 100644 new mode 100755 index 66338c64a0..3bd9250b7d --- a/docs/examples/te_gemma/generate_convert.py +++ b/docs/examples/te_gemma/generate_convert.py @@ -33,8 +33,10 @@ batch["input_ids"] = batch["input_ids"].cuda() outputs = model.generate( **batch, - max_new_tokens=1 + max_new_tokens=10 ) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + print(generated_texts[0][:50]) print("calibration_finished") print("scale_fwd computation started") diff --git a/docs/examples/te_gemma/generate_cuda_graphs.py b/docs/examples/te_gemma/generate_cuda_graphs.py new file mode 100644 index 0000000000..69e6677ee7 --- /dev/null +++ b/docs/examples/te_gemma/generate_cuda_graphs.py @@ -0,0 +1,63 @@ +import os + +os.environ['CUDNN_LOGLEVEL_DBG'] = '3' +os.environ['CUDNN_LOGDEST_DBG'] = 'backlog.txt' +#Restart the notebook (to flush the GPU memory) +from utils import restart_jupyter_notebook +#restart_jupyter_notebook() +import transformer_engine.pytorch as te + +from torch.cuda.amp import autocast + + +# Import necessary packages and methods +from utils import * + +from transformer_engine.pytorch import fp8_model_init +from transformer_engine.common.recipe import Format, DelayedScaling + + +hyperparams.model_name = "../../../../gemma-weights" +hyperparams.fuse_qkv_params = True +model = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda() + +print("Loading model") +model_state_dict = torch.load('model_fp8_state_dict.pth') +model.load_state_dict(model_state_dict) +print("Model loaded") + +tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) +inputs = tokenizer(["I love when", "I "] * 32, return_tensors="pt", padding=True) + +inputs['input_ids'] = inputs['input_ids'].cuda() +inputs['attention_mask'] = inputs['attention_mask'].cuda() + +import time + + + +start_time = time.time() + +fp8_format = Format.HYBRID +fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") +torch.manual_seed(1234) +with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with autocast(dtype=torch.bfloat16, cache_enabled=False): + with torch.no_grad(): + model.eval() + outputs = model.generate( + **inputs, + max_new_tokens=40, + use_cuda_graphs=False + ) + + +end_time = time.time() +duration = end_time - start_time + +generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) +for text in generated_texts[:2]: + print("-" * 50) + print(text) + +print(f"Duration = {duration}") diff --git 
a/docs/examples/te_gemma/generate_fp8.py b/docs/examples/te_gemma/generate_fp8.py index 4a6bc1853e..85fcbff714 100755 --- a/docs/examples/te_gemma/generate_fp8.py +++ b/docs/examples/te_gemma/generate_fp8.py @@ -15,20 +15,20 @@ model = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda() print("Loading model") -model_state_dict = torch.load('model_fp8_state_dict.pth') -model.load_state_dict(model_state_dict) +#model_state_dict = torch.load('model_fp8_state_dict.pth') +#model.load_state_dict(model_state_dict) +#model = model.to(torch.bfloat16).cuda() print("Model loaded") -tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) -inputs = tokenizer(["I love when", "I love when"] * 32, return_tensors="pt", padding=True) +tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name, + torch_dtype=torch.bfloat16) +inputs = tokenizer(["I love when", "I love when"] * 16, return_tensors="pt", padding=True) inputs['input_ids'] = inputs['input_ids'].cuda() inputs['attention_mask'] = inputs['attention_mask'].cuda() import time - - start_time = time.time() fp8_format = Format.HYBRID @@ -39,7 +39,7 @@ model.eval() outputs = model.generate( **inputs, - max_new_tokens=40 + max_new_tokens=160 ) diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index 35bd0421d9..1746c3165d 100755 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -28,6 +28,7 @@ def __init__(self): self.gradient_accumulation_steps = 1 self.num_warmup_steps=5 self.num_training_steps=10 + self.fuse_qkv_params=False hyperparams = HyperParameters() @@ -86,15 +87,17 @@ def init_baseline_model(hyperparams): return model -def init_te_gemma_model(hyperparams): +def init_te_gemma_model(hyperparams, fp8_model_init=False): # Init the model from te_gemma import TEGemmaForCausalLM config = AutoConfig.from_pretrained(hyperparams.model_name) config._attn_implementation = "flash_attention_2" + config.fuse_qkv_params = hyperparams.fuse_qkv_params model = TEGemmaForCausalLM.from_pretrained_local( hyperparams.model_name, config=config, torch_dtype=torch.bfloat16, + fp8_init=fp8_model_init, ) # Needed for the cases when using TEGemmaForCausalLM diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6d20cc8b29..afe89483b5 100755 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -130,6 +130,7 @@ def __init__(self, max_batch_size, max_sequence_length): self.batch_size_offset = 0 self.key_value_memory_dict = {} self.thd = False + self.seq_len=torch.tensor((1000)) def swap_key_value_dict(self, batch_indices): """ @@ -2326,15 +2327,132 @@ class FusedAttnFunc(torch.autograd.Function): def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v, q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, - qkv_layout, attn_bias_type, attn_mask_type, - rng_gen, fused_attention_backend, use_FAv2_bwd): - out, aux_ctx_tensors = fused_attn_fwd( - is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, k, v, qkv_dtype, fused_attention_backend, - seq_offsets_q, seq_offsets_k, seq_offsets_v, - attn_bias, None, None, None, None, None, - attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen) + qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, + use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group): + if fp8: + if _NVTE_DEBUG: + 
print('[DotProductAttention]: using FP8 forward') + fused_attention_backend = FusedAttnBackend["FP8"] + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + if fp8_meta["recipe"].fp8_mha: + assert (isinstance(q, Float8Tensor) + and isinstance(k, Float8Tensor) + and isinstance(v, Float8Tensor)), "q/k/v must be Float8Tensors for FP8 MHA." + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data + else: + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_group = len(qkv_layout.split('_')) + if qkv_group == 1: + dim = qkv_layout.find('3') + qkv = _combine_tensors([q,k,v], dim) + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_fp8 = cast_to_fp8(qkv_c, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(qkv.shape) + q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1,1,1]) + q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]] + if qkv_group == 2: + q_fp8 = cast_to_fp8(q, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(q.shape) + dim = qkv_layout.split('_')[1].find('2') + kv = _combine_tensors([k,v], dim) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv_fp8 = cast_to_fp8(kv_c, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(kv.shape) + k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1,1]) + k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]] + if qkv_group == 3: + q_fp8 = cast_to_fp8(q, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(q.shape) + k_fp8 = cast_to_fp8(k, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(k.shape) + v_fp8 = cast_to_fp8(v, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(v.shape) + out_fp8, aux_ctx_tensors = fused_attn_fwd( + is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q_fp8, k_fp8, v_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, + fp8_meta["scaling_fwd"].scale_inv[META_QKV], + fp8_meta["scaling_fwd"].scale_inv[META_S], + fp8_meta["scaling_fwd"].scale[META_S], + fp8_meta["scaling_fwd"].scale[META_O], + fp8_meta["scaling_fwd"].amax_history[0][META_S], + fp8_meta["scaling_fwd"].amax_history[0][META_O], + attn_scale, dropout_p, fast_zero_fill, qkv_layout, + attn_bias_type, attn_mask_type, rng_gen) + if fp8_meta["recipe"].fp8_mha: + out_ret = Float8Tensor(data=out_fp8, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=q.dtype, + ) + else: + out_ret = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], META_O, + fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + out_save = out_ret + + if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_group = len(qkv_layout.split('_')) + if qkv_group == 1: + dim = qkv_layout.find('3') + qkv = _combine_tensors([q,k,v], dim) + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_no_fp8 = cast_from_fp8(qkv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape) + q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1,1,1]) + q, k, v = [x.squeeze(dim) for x in [q, k, v]] + if qkv_group == 2: + q = cast_from_fp8(q._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape) + dim = qkv_layout.split('_')[1].find('2') + kv = _combine_tensors([k,v], dim) + kv_c = kv.view(-1, 
kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv_no_fp8 = cast_from_fp8(kv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape) + k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1,1]) + k, v = [x.squeeze(dim) for x in [k, v]] + if qkv_group == 3: + q = cast_from_fp8(q._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape) + k = cast_from_fp8(k._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[k.dtype]).view(k.shape) + v = cast_from_fp8(v._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[v.dtype]).view(v.shape) + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], META_O, + fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + + fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8, + fp8_meta["scaling_fwd"].scale.clone(), + fp8_meta["scaling_fwd"].scale_inv.clone()) + else: + if _NVTE_DEBUG: + print('[DotProductAttention]: using non-FP8 forward') + out_ret, aux_ctx_tensors = fused_attn_fwd( + is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, k, v, qkv_dtype, fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, attn_bias, None, None, None, None, None, None, + attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, + rng_gen) + out_save = out_ret + fp8_tensors = (None, None, None, None, None, None) from .cpu_offload import CPUOffloadEnabled if CPUOffloadEnabled: @@ -2631,25 +2749,50 @@ def forward( use_fused_attention=True, ) else: - with self.attention_dropout_ctx(): - output = FusedAttnFunc.apply( - self.training, - max_seqlen_q, max_seqlen_kv, - cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v, - query_layer, key_layer, value_layer, - qkv_dtype, - core_attention_bias, - 1.0/self.norm_factor, - self.attention_dropout if self.training else 0.0, - fast_zero_fill, - qkv_layout, - core_attention_bias_type, - attn_mask_type, - None, # rng_gen - fused_attention_backend, - use_FAv2_bwd, - ) + with self.prepare_forward(query_layer, + is_first_microbatch, + num_gemms=3, + allow_non_contiguous=True) as query_layer: + with self.attention_dropout_ctx(): + forced_fp8_dpa = "" + if self.fp8_meta["recipe"].fp8_mha: + if not self.fp8_meta["recipe"].fp8_dpa: + self.fp8_meta["recipe"].fp8_dpa = True + forced_fp8_dpa = " (forced)" + if _NVTE_DEBUG: + print("[DotProductAttention]: " + f"""using fp8_recipe.fp8_mha={self.fp8_meta["recipe"].fp8_mha}, """ + f"""fp8_recipe.fp8_dpa={self.fp8_meta["recipe"].fp8_dpa}""" + f"""{forced_fp8_dpa} and """ + f"""NVTE_FP8_DPA_BWD={int(os.getenv("NVTE_FP8_DPA_BWD", "1"))}""") + + output = FusedAttnFunc.apply( + self.training, + max_seqlen_q, max_seqlen_kv, + cu_seqlens_q, cu_seqlens_kv, + seq_offsets_q, seq_offsets_k, seq_offsets_v, + query_layer, key_layer, value_layer, + qkv_dtype, + core_attention_bias, + 1.0/self.norm_factor, + self.attention_dropout if self.training else 0.0, + fast_zero_fill, + qkv_layout, + core_attention_bias_type, + attn_mask_type, + None, # rng_gen + fused_attention_backend, + use_FAv2_bwd, + self.fp8 and self.fp8_meta["recipe"].fp8_dpa, + self.fp8_meta, + self.tp_size, + self.tp_group, + ) + + + if self.layer_number == 1: + print(output.shape) + # ...hd -> ...(hd) return output.view(*output.shape[:-2], -1) @@ -3049,8 +3192,8 @@ def forward( first microbatch (since it is the first gradient being produced) """ - - + value_layer = value_layer.contiguous() + assert ( 
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), 'DotProductAttention only supports CUDA tensors.' @@ -3087,6 +3230,9 @@ def forward( if qkv_format is None: qkv_format = self.qkv_format + + + if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" @@ -3097,7 +3243,6 @@ def forward( (inference_key_memory, inference_value_memory, ) = inference_params.key_value_memory_dict[self.layer_number] - if not inference_params.thd: batch_start = inference_params.batch_size_offset batch_end = batch_start + key_layer.size(1) @@ -3107,6 +3252,8 @@ def forward( sequence_end = sequence_start + key_layer.size(0) assert sequence_end <= inference_key_memory.size(0) + + # Copy keys and values into KV-cache inference_key_memory[ sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer @@ -3115,26 +3262,44 @@ def forward( key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] else: - cuda.attention_copy(inference_key_memory, inference_params.seq_len + 1, key_layer, inference_params.max_batch_size, self.channels) - cuda.attention_copy(inference_value_memory, inference_params.seq_len + 1, value_layer, inference_params.max_batch_size, self.channels) + bs = query_layer.shape[0] + cuda.attention_copy( + inference_key_memory, + inference_params.seq_len, + key_layer, + inference_params.max_sequence_length, + bs, + self.channels + ) + cuda.attention_copy( + inference_value_memory, + inference_params.seq_len, + value_layer, + inference_params.max_sequence_length, + bs, + self.channels) - q = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]) - k = inference_key_memory.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]) - v = inference_value_memory.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]) + seqlens_q = torch.ones([bs], dtype=torch.int32, device="cuda") + cu_seqlens_q = torch.zeros(bs + 1, dtype=torch.int32, device="cuda") + cu_seqlens_kv = torch.zeros(bs + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_kv[1:] = torch.cumsum(inference_params.seq_len + 1, dim=0) - q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16), + max_seqlen_q = 1 + max_seqlen_kv = inference_params.max_sequence_length + - out, _, _ = fused_attn_fwd( - False, 1, key_layer.shape[1], inference_params.seq_len, inference_params.seq_len, - q, k, v, - TE_DType[q.dtype], FusedAttnBackend["F16_max512_seqlen"], - qkv_layout="t3hd", attn_bias_type=core_attention_bias_type, - attn_bias=core_attention_bias, fast_zero_fill=fast_zero_fill - ) - print("xd") - exit() - return out + seq_offsets_q = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_q + seq_offsets_k = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv + seq_offsets_k[1:] = seq_offsets_k[1:] + inference_params.begin_offsets * self.channels + seq_offsets_v = seq_offsets_k.clone() + + + query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16) + key_layer = inference_key_memory.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]).to(torch.bfloat16) + value_layer = inference_value_memory.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]).to(torch.bfloat16) + qkv_format="thd" if qkv_format == "bshd": key_layer = key_layer.transpose(0, 1) @@ 
-3204,6 +3369,7 @@ def forward( # The following section filters out some backends based on # certain asserts before executing the forward pass. + # Filter: QKV layout. if qkv_format == 'thd': use_unfused_attention = False @@ -3258,6 +3424,8 @@ def forward( use_fused_attention = False if (not _flash_attn_2_3_plus) or context_parallel: use_flash_attention = False + + # Filter: Attention mask type. # attn_mask_type(s) | supported backends @@ -3278,6 +3446,7 @@ def forward( ): use_unfused_attention = False + # Filter: bias. global _alibi_cache if alibi_slopes is not None: @@ -3334,6 +3503,9 @@ def forward( max_seqlen_kv, query_layer.shape[-1], # head_dim ) + if inference_params is not None: + if inference_params.thd: + fused_attention_backend = FusedAttnBackend["F16_arbitrary_seqlen"] # DPA does not support FP8; for FP8, use cpp_extensions modules directly is_backend_avail = (fused_attention_backend in [FusedAttnBackend["F16_max512_seqlen"], @@ -3374,6 +3546,13 @@ def forward( if self.device_compute_capability == (9, 0): use_flash_attention = False + if inference_params is not None: + if inference_params.thd: + use_flash_attention = False + + if len(query_layer.shape) == 4: + use_flash_attention=True + if use_flash_attention: if _NVTE_DEBUG: print("[DotProductAttention]: using flash-attn",_flash_attn_version) @@ -3422,7 +3601,8 @@ def forward( cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, is_first_microbatch=is_first_microbatch) - return self.fused_attention( + + out = self.fused_attention( query_layer, key_layer, value_layer, @@ -3442,6 +3622,13 @@ def forward( cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, is_first_microbatch=is_first_microbatch) + if inference_params.thd: + out = out.unsqueeze(1) + + + + + return out assert (not context_parallel), \ "Context parallelism is only implemented with Flash Attention and Fused Attention!" 
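# ---------------------------------------------------------------------------------------------
# Illustrative sketch, not taken from the patch: the thd decode path in the hunk above feeds the
# fused kernel one new query token per sequence while the KV cache already holds a per-sequence
# number of tokens, so a dense mask is replaced by cumulative-length and offset tensors. Assuming
# `seq_len` is an int32 tensor of cached lengths and `channels` / `max_seqlen_kv` carry the same
# meaning as in the diff, the bookkeeping looks roughly like this (the helper name is hypothetical):
import torch

def thd_decode_metadata(seq_len: torch.Tensor, channels: int, max_seqlen_kv: int):
    b = seq_len.shape[0]
    # One new query token per sequence, so max_seqlen_q == 1.
    seqlens_q = torch.ones(b, dtype=torch.int32, device="cuda")
    cu_seqlens_q = torch.zeros(b + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_kv = torch.zeros(b + 1, dtype=torch.int32, device="cuda")
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    # Keys/values span the cached tokens plus the token just written into the cache.
    cu_seqlens_kv[1:] = torch.cumsum(seq_len + 1, dim=0)
    # Element offsets of each sequence inside the flattened query and KV buffers.
    seq_offsets_q = torch.arange(0, b + 1, dtype=torch.int32, device="cuda") * channels  # max_seqlen_q == 1
    seq_offsets_kv = torch.arange(0, b + 1, dtype=torch.int32, device="cuda") * channels * max_seqlen_kv
    return seqlens_q, cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_kv
# ---------------------------------------------------------------------------------------------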
@@ -3831,7 +4018,7 @@ def __init__( def _allocate_memory( self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype ) -> torch.Tensor: - return torch.empty( + return torch.zeros( inference_max_sequence_len, batch_size, self.num_gqa_groups_per_partition, @@ -3956,8 +4143,6 @@ def forward( """ # hidden_states: [sq, b, h] - - if attn_mask_type is not None: window_size = check_set_window_size(attn_mask_type, window_size) if attn_mask_type is None: @@ -4021,6 +4206,7 @@ def forward( ) num_queries_per_key_value = (self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition) + if self.qkv_weight_interleaved: # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn] new_tensor_shape = mixed_x_layer.size()[:-1] + ( @@ -4042,6 +4228,7 @@ def forward( mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + # qkv_weight_interleaved: # [sq, b, ng, (np/ng + 2), hn] # --> [sq, b, ng, np/ng, hn], [sq, b, ng, 1, hn], [sq, b, ng, 1, hn] @@ -4127,6 +4314,7 @@ def forward( ) query_layer = query_layer.view(*new_tensor_shape) + # ====================================================== # Apply relative positional encoding (rotary embedding) # ====================================================== @@ -4156,6 +4344,7 @@ def forward( query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True) key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True) + # =========================== # Core attention computation diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0803b474f6..eaa9c82745 100755 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -858,3 +858,22 @@ def get_fp8_weights_scratchpad( is_first_microbatch: Union[bool, None], ) -> List[torch.Tensor]: """Needs override.""" + + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """ + The function loads an extra state containing fp8_meta weights. + This metadata is crucial when copying fp8 parameters. + For instance, when casting fp16 parameters to fp8, the _copy function + utilizes the scale_inv parameter from fp8_meta + to set the appropriate scaling factor for the new tensor. + Therefore, this extra state must be loaded before the tensor copying process, + not after, as is the default behavior in _load_from_state_dict. 
+ """ + extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + \ No newline at end of file diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index b59c1ce346..2219154903 100755 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -635,6 +635,7 @@ def forward( fast_zero_fill=fast_zero_fill, ) + if self.apply_residual_connection_post_layernorm and not self.output_layernorm: attention_output, attention_bias, residual = self_attention_outputs hidden_states = self._bias_dropout_add( @@ -673,6 +674,7 @@ def forward( hidden_states, is_first_microbatch=is_first_microbatch, ) + if self.apply_residual_connection_post_layernorm: mlp_output, mlp_bias, residual = mlp_outputs output = self._bias_dropout_add(mlp_output, mlp_bias, residual, self.drop_path) From eb76011277369ba39402b6a4a7f76f14002a34af Mon Sep 17 00:00:00 2001 From: root Date: Mon, 29 Apr 2024 18:27:44 +0000 Subject: [PATCH 079/244] Cuda graphs generation (which seems to be working) Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/attention_copy.cu | 96 ++++++ docs/examples/te_gemma/generate.py | 5 +- docs/examples/te_gemma/generate_baseline.py | 7 +- .../examples/te_gemma/generate_cuda_graphs.py | 6 +- docs/examples/te_gemma/generate_fp8.py | 37 ++- docs/examples/te_gemma/te_gemma.py | 314 ++++++++++++++++-- transformer_engine/pytorch/attention.py | 176 +++++++--- 7 files changed, 535 insertions(+), 106 deletions(-) create mode 100644 docs/examples/te_gemma/attention_copy.cu diff --git a/docs/examples/te_gemma/attention_copy.cu b/docs/examples/te_gemma/attention_copy.cu new file mode 100644 index 0000000000..810c66c377 --- /dev/null +++ b/docs/examples/te_gemma/attention_copy.cu @@ -0,0 +1,96 @@ +#include +#include +#include + +extern "C" +__global__ void attn_copy(__nv_bfloat16* A, int* seq_len, __nv_bfloat16* B, int max_seq_len, int b, int s) { + for(int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int per_block = s / blockDim.x; + int remainder = s % blockDim.x; + int copy_block_offset_begin = per_block * threadIdx.x + min(threadIdx.x, remainder); + + int offset = seq_len[batch_idx]; + + __nv_bfloat16* begin_A_copy = A + max_seq_len * s * batch_idx + s * offset; + __nv_bfloat16* begin_B_copy = B + s * batch_idx; + + int limit = copy_block_offset_begin + per_block + (threadIdx.x < remainder ? 1 : 0); + + for(int i = copy_block_offset_begin; i < limit; i++) { + *(begin_A_copy + i) = *(begin_B_copy + i); + } + } +} + +extern "C" +__global__ void gv(float* src, int* seq_len, float* dst, int d, int b) { + // src [s, 1, 1, d] + // dst [b] + for(int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int per_block = d / blockDim.x; + int remainder = d % blockDim.x; + int copy_block_offset_begin = per_block * threadIdx.x + min(threadIdx.x, remainder); + + int offset = seq_len[batch_idx]; + + float* begin_src_copy = src + d * offset; + float* begin_dst_copy = dst + d * batch_idx; + + int limit = copy_block_offset_begin + per_block + (threadIdx.x < remainder ? 
1 : 0); + + for(int i = copy_block_offset_begin; i < limit; i++) { + *(begin_dst_copy + i) = *(begin_src_copy + i); + } + } +} + + + + + + +void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int max_seq_len, int b, int s, void* stream_ptr) { + cudaStream_t stream = static_cast(stream_ptr); + attn_copy<<<16, 32, 0, stream>>>(reinterpret_cast<__nv_bfloat16*>(A.data_ptr()), + seq_len.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(B.data_ptr()), max_seq_len, b, s); +} + + +void attention_copy2(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int max_seq_len, int b, int s) { + attn_copy<<<16, 32, 0>>>(reinterpret_cast<__nv_bfloat16*>(A.data_ptr()), + seq_len.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(B.data_ptr()), max_seq_len, b, s); +} + + +void get_values(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int d, int b, void* stream_ptr) { + cudaStream_t stream = static_cast(stream_ptr); + gv<<<16, 32, 0, stream>>>(A.data_ptr(), + seq_len.data_ptr(), + B.data_ptr(), d, b); +} + + +void get_values2(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int d, int b) { + gv<<<16, 32, 0>>>((A.data_ptr()), + seq_len.data_ptr(), + (B.data_ptr()), d, b); +} + + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("attention_copy", &attention_copy, "Copy function for attention mechanism", + py::arg("A"), py::arg("seq_len"), py::arg("B"), py::arg("b"), py::arg("max_seq_len"), py::arg("s"), py::arg("stream_ptr")); + + m.def("attention_copy2", &attention_copy2, "Copy function for attention mechanism", + py::arg("A"), py::arg("seq_len"), py::arg("B"), py::arg("b"), py::arg("max_seq_len"), py::arg("s")); + + m.def("get_values", &get_values, "1Get values function", + py::arg("A"), py::arg("seq_len"), py::arg("B"), py::arg("d"), py::arg("b"), py::arg("stream_ptr")); + + m.def("get_values2", &get_values2, "2Get values function", + py::arg("A"), py::arg("seq_len"), py::arg("B"), py::arg("d"), py::arg("b")); +} \ No newline at end of file diff --git a/docs/examples/te_gemma/generate.py b/docs/examples/te_gemma/generate.py index 422b005bd8..ae63777438 100755 --- a/docs/examples/te_gemma/generate.py +++ b/docs/examples/te_gemma/generate.py @@ -21,7 +21,7 @@ model = model.to(torch.bfloat16).cuda() tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) -inputs = tokenizer(["I love when "] * 64, return_tensors="pt", padding=True) +inputs = tokenizer(["I love when ", "I "] * 32, return_tensors="pt", padding=True) inputs['input_ids'] = inputs['input_ids'].cuda() inputs['attention_mask'] = inputs['attention_mask'].cuda() @@ -49,4 +49,5 @@ # Display the generated text for text in generated_texts: - print(text) \ No newline at end of file + print(text) + print("=" * 100) \ No newline at end of file diff --git a/docs/examples/te_gemma/generate_baseline.py b/docs/examples/te_gemma/generate_baseline.py index 3db56c958f..872ce92ac8 100755 --- a/docs/examples/te_gemma/generate_baseline.py +++ b/docs/examples/te_gemma/generate_baseline.py @@ -20,7 +20,7 @@ model = model.to(torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) -inputs = tokenizer(["I love when ", "I "] * 32, return_tensors="pt", padding=True) +inputs = tokenizer(["I love when"] * 32, return_tensors="pt", padding=True) inputs['input_ids'] = inputs['input_ids'].cuda() inputs['attention_mask'] = inputs['attention_mask'].cuda() @@ -31,7 +31,7 @@ outputs = model.generate( **inputs, - max_new_tokens=10 + max_new_tokens=40 ) # Koniec pomiaru czasu @@ -49,4 +49,5 @@ # Display the 
generated text for text in generated_texts: - print(text) \ No newline at end of file + print(text) + print("=" * 100) \ No newline at end of file diff --git a/docs/examples/te_gemma/generate_cuda_graphs.py b/docs/examples/te_gemma/generate_cuda_graphs.py index 69e6677ee7..ae5e413afc 100644 --- a/docs/examples/te_gemma/generate_cuda_graphs.py +++ b/docs/examples/te_gemma/generate_cuda_graphs.py @@ -27,7 +27,7 @@ print("Model loaded") tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) -inputs = tokenizer(["I love when", "I "] * 32, return_tensors="pt", padding=True) +inputs = tokenizer(["I love when"] * 32, return_tensors="pt", padding=True) inputs['input_ids'] = inputs['input_ids'].cuda() inputs['attention_mask'] = inputs['attention_mask'].cuda() @@ -48,7 +48,7 @@ outputs = model.generate( **inputs, max_new_tokens=40, - use_cuda_graphs=False + use_cuda_graphs=True ) @@ -56,7 +56,7 @@ duration = end_time - start_time generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) -for text in generated_texts[:2]: +for text in generated_texts[:12]: print("-" * 50) print(text) diff --git a/docs/examples/te_gemma/generate_fp8.py b/docs/examples/te_gemma/generate_fp8.py index 85fcbff714..bde5be1def 100755 --- a/docs/examples/te_gemma/generate_fp8.py +++ b/docs/examples/te_gemma/generate_fp8.py @@ -1,8 +1,14 @@ -# Restart the notebook (to flush the GPU memory) +import os + +os.environ['CUDNN_LOGLEVEL_DBG'] = '3' +os.environ['CUDNN_LOGDEST_DBG'] = 'backlog.txt' +#Restart the notebook (to flush the GPU memory) from utils import restart_jupyter_notebook #restart_jupyter_notebook() import transformer_engine.pytorch as te +from torch.cuda.amp import autocast + # Import necessary packages and methods from utils import * @@ -10,44 +16,47 @@ from transformer_engine.pytorch import fp8_model_init from transformer_engine.common.recipe import Format, DelayedScaling + hyperparams.model_name = "../../../../gemma-weights" hyperparams.fuse_qkv_params = True model = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda() print("Loading model") -#model_state_dict = torch.load('model_fp8_state_dict.pth') -#model.load_state_dict(model_state_dict) -#model = model.to(torch.bfloat16).cuda() +model_state_dict = torch.load('model_fp8_state_dict.pth') +model.load_state_dict(model_state_dict) print("Model loaded") -tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name, - torch_dtype=torch.bfloat16) -inputs = tokenizer(["I love when", "I love when"] * 16, return_tensors="pt", padding=True) +tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) +inputs = tokenizer(["I love when", "I "] * 32, return_tensors="pt", padding=True) inputs['input_ids'] = inputs['input_ids'].cuda() inputs['attention_mask'] = inputs['attention_mask'].cuda() import time + + start_time = time.time() fp8_format = Format.HYBRID fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") torch.manual_seed(1234) with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - with torch.no_grad(): - model.eval() - outputs = model.generate( - **inputs, - max_new_tokens=160 - ) + with autocast(dtype=torch.bfloat16, cache_enabled=False): + with torch.no_grad(): + model.eval() + outputs = model.generate( + **inputs, + max_new_tokens=40, + use_cuda_graphs=False + ) end_time = time.time() duration = end_time - start_time generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) -for text in generated_texts[:2]: +for text in generated_texts[:12]: print("-" * 
50) print(text) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 27c079338d..376eb4bbd5 100755 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -14,10 +14,14 @@ import torch from torch import nn +from torch.utils.cpp_extension import load + + import transformer_engine as te -from transformer_engine.pytorch.attention import RotaryPositionEmbedding +from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding from transformer_engine.pytorch.fp8 import fp8_model_init +from transformer_engine.common.recipe import Format, DelayedScaling import transformers from transformers.models.gemma.modeling_gemma import GemmaModel, GemmaForCausalLM, GemmaRMSNorm, GemmaConfig @@ -25,6 +29,12 @@ from transformers.utils import WEIGHTS_INDEX_NAME from transformers.utils.hub import get_checkpoint_shard_files +cuda = load( + name='attention_copy', + sources=['attention_copy.cu'], + verbose=True +) + @contextmanager def replace_decoder(te_decoder_cls): @@ -83,6 +93,85 @@ def forward(self, """ return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb, inference_params=inference_params, self_attn_mask_type=self_attn_mask_type),) +class TeGraphed(torch.nn.Module): + def __init__(self, model, lm_head, inference_params, normalizer, generation_config, thd=True): + super().__init__() + self.model = model + self.inference_params = inference_params + self.inference_params.thd = thd + self.thd=thd + self.normalizer = normalizer + self.generation_config = generation_config + self.lm_head = lm_head + + + self.attn_mask = torch.ones([inference_params.max_batch_size, inference_params.max_sequence_length]).to(dtype=torch.bool) + + def forward(self, + hidden_states, + unfinished_sequences, + seq_len, + vl_space, + kl_space, + ql_space, + seqlens_q, + cu_seqlens_q, + cu_seqlens_kv, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + position_embedding_matrix, + k_pos_emb, + q_pos_emb, + *args + ): + hidden_states.data[:] = hidden_states.data[:] * self.normalizer + inference_params = InferenceParams(self.inference_params.max_batch_size, self.inference_params.max_sequence_length) + inference_params.thd = self.thd + inference_params.seq_len = seq_len + inference_params.value_layer = vl_space + inference_params.key_layer = kl_space + inference_params.query_layer = ql_space + inference_params.seqlens_q = seqlens_q + inference_params.cu_seqlens_q = cu_seqlens_q + inference_params.cu_seqlens_kv = cu_seqlens_kv + inference_params.seq_offsets_q = seq_offsets_q + inference_params.seq_offsets_k = seq_offsets_k + inference_params.seq_offsets_v = seq_offsets_v + inference_params.position_embedding_matrix = position_embedding_matrix + inference_params.k_pos_emb = k_pos_emb + inference_params.q_pos_emb = q_pos_emb + + assert len(args) == 28 * 2 + + + for i in range(0, len(args), 2): + inference_params.key_value_memory_dict[i // 2 + 1] = (args[i], args[i + 1]) + + for decoder_layer in self.model.layers: + hidden_states.copy_(decoder_layer( + hidden_states, + inference_params=inference_params, + self_attn_mask_type='padding', + attention_mask=None + )[0]) + + + seq_len.copy_(seq_len + 1) + + hidden_states.copy_(self.model.norm(hidden_states)) + logits = self.lm_head(hidden_states) + logits = logits.float() + logits = logits[:, -1, :] + next_tokens = torch.argmax(logits, dim=-1) + + # Sequences, which are finished should contain padding - taken from huggingface transformers. 
+ next_tokens = next_tokens * unfinished_sequences + self.generation_config.pad_token_id * (1 - unfinished_sequences) + + unfinished_sequences.copy_(unfinished_sequences & ~(next_tokens == self.generation_config.eos_token_id)) + + hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1)) + return next_tokens, logits class TEGemmaForCausalLM: """ @@ -153,56 +242,220 @@ def generate( input_ids: Optional[torch.Tensor] = None, generation_config: Optional[GenerationConfig] = None, max_new_tokens = 0, + use_cuda_graphs = False, **kwargs, ): + batch_size, seq_len = input_ids.shape - max_seq_len = seq_len + max_new_tokens generation_config, _ = self._prepare_generation_config(generation_config, **kwargs) unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) # inference_params object is a cache, where keys and values of previous tokens are stored inference_params = te.pytorch.InferenceParams( max_batch_size=batch_size, - max_sequence_length=seq_len+max_new_tokens+1) + max_sequence_length=max(128, input_ids.shape[1] + max_new_tokens) + ) # mask has shape [batch_size, num_heads, 1, max_seq_len] and contains False # when coressponding token is padding and True otherwise. pad_attention_mask = input_ids.ne(generation_config.pad_token_id).unsqueeze(1).unsqueeze(2) - mask = torch.ones((batch_size, 1, 1, max_seq_len), dtype=torch.bool).cuda() - mask[..., :seq_len] = mask[..., :seq_len] & pad_attention_mask.expand(-1, 1, -1, -1) + ############################################################################################# + # Encode part # + ############################################################################################# + + hidden_states = self.model.embed_tokens(input_ids) + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + output_tokens = [] - for i in range(max_new_tokens): - normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) - hidden_states = hidden_states * normalizer - for decoder_layer in self.model.layers: - hidden_states = decoder_layer( - hidden_states, - # In the case of arbiutrary mask, the meaning of True and False is switched, so negation is needed. - attention_mask=pad_attention_mask if i == 0 else ~mask[..., :seq_len], - self_attn_mask_type="causal" if i == 0 else "arbitrary", - inference_params=inference_params - )[0] + hidden_states = hidden_states * normalizer + for decoder_layer in self.model.layers: + hidden_states = decoder_layer( + hidden_states, + # In the case of arbiutrary mask, the meaning of True and False is switched, so negation is needed. + attention_mask=pad_attention_mask, + self_attn_mask_type="padding_causal", + inference_params=inference_params + )[0] + + hidden_states = self.model.norm(hidden_states) + logits = self.lm_head(hidden_states) + logits = logits.float() + logits = logits[:, -1, :] + next_tokens = torch.argmax(logits, dim=-1) + # Sequences, which are finished should contain padding - taken from huggingface transformers. 
+ next_tokens = next_tokens * unfinished_sequences + generation_config.pad_token_id * (1 - unfinished_sequences) + output_tokens.append(next_tokens) + + unfinished_sequences = unfinished_sequences & ~(next_tokens == generation_config.eos_token_id) + + hidden_states = self.model.embed_tokens(next_tokens).unsqueeze(1) + lengths = torch.sum(pad_attention_mask, dim=-1).squeeze() + + + def process(x): + """ + Args: + x: Tensor with shape [s, b, h, d], where s is sequence length, b is batch size, h is number of heads, and d is hidden dimension. + l: List of integers representing the actual lengths of each sequence in the batch before padding. + + Returns: + torch.Tensor: Tensor with switched contents such that padded zeros are moved to the end of the sequence. + """ + s1, b, h, d = x.shape + s = torch.max(lengths) + new_x = torch.zeros_like(x) + + for i in range(b): + seq_length = lengths[i] - # inference_params.sequence_len_offset should contain position of the current token in the sequence. - inference_params.sequence_len_offset += hidden_states.shape[1] + # Check if the sequence length is not the full length of the sequence dimension + if seq_length < s: + # Place the original data to the end part of the new tensor + new_x[:seq_length, i, :, :] = x[s - seq_length:s, i, :, :] + # Place the padding at the beginning of the new tensor + new_x[seq_length:, i, :, :] = 0 + else: + # If seq_length is the full length, just copy the entire sequence as is + new_x[:, i, :, :] = x[:, i, :, :] - hidden_states = self.model.norm(hidden_states) - logits = self.lm_head(hidden_states) - logits = logits.float() - logits = logits[:, -1, :] - next_tokens = torch.argmax(logits, dim=-1) + return new_x.permute((1, 0, 2, 3)).contiguous().cuda() - # Sequences, which are finished should contain padding - taken from huggingface transformers. 
- next_tokens = next_tokens * unfinished_sequences + generation_config.pad_token_id * (1 - unfinished_sequences) - output_tokens.append(next_tokens) + inference_params.seq_len = lengths.to(torch.int32) + seq_len_offset = torch.max(lengths).item() - unfinished_sequences = unfinished_sequences & ~(next_tokens == generation_config.eos_token_id) - hidden_states = self.model.embed_tokens(next_tokens).unsqueeze(1) - seq_len += 1 + seqlens_q = torch.zeros((batch_size), dtype=torch.int32).cuda() + cu_seqlens_q = torch.zeros((batch_size + 1), dtype=torch.int32).cuda() + cu_seqlens_kv = torch.zeros((batch_size + 1), dtype=torch.int32).cuda() + seq_offsets_q = torch.zeros((batch_size + 1), dtype=torch.int32).cuda() + seq_offsets_k = torch.zeros((batch_size + 1), dtype=torch.int32).cuda() + seq_offsets_v = torch.zeros((batch_size + 1), dtype=torch.int32).cuda() + + + + + + for k, v in inference_params.key_value_memory_dict.items(): + key_layer = process(v[0]) + value_layer = process(v[1]) + inference_params.key_value_memory_dict[k] = (key_layer, value_layer) + + ############################################################################################# + # Generate part # + ############################################################################################# + print("generate part") + + + graphed_generator = TeGraphed( + lm_head=self.lm_head, + model=self.model, + inference_params=inference_params, + normalizer=normalizer, + generation_config=generation_config, + thd=True + ) + + tensor_pointers = [(kc, vc) for kc, vc in inference_params.key_value_memory_dict.values()] + tensor_pointers = [element for tuple_ in tensor_pointers for element in tuple_] + + copy_hidden = hidden_states.clone() + copy_unfinished_sequences = unfinished_sequences.clone() + copy_tensor_pointers = [t.clone() for t in tensor_pointers] + copy_seq_len = inference_params.seq_len.clone() + + vl_space = torch.zeros((batch_size, 1, 16, 256)).to(torch.bfloat16).cuda() + kl_space = torch.zeros((batch_size, 1, 16, 256)).to(torch.bfloat16).cuda() + ql_space = torch.zeros((batch_size, 1, 16, 256)).to(torch.bfloat16).cuda() + q_pos_emb = torch.zeros((batch_size, 1, 1, 256)).to(torch.float32).cuda() + k_pos_emb = torch.zeros((batch_size, 1, 1, 256)).to(torch.float32).cuda() + + + te_rope = RotaryPositionEmbedding(256) + position_embedding_matrix = te_rope(8192).to(torch.float32).cuda() + + + graphed_layers = None + if use_cuda_graphs: + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + + print("recording...") + graphed_layers = te.pytorch.make_graphed_callables( + graphed_generator, + ( + hidden_states, + unfinished_sequences, + inference_params.seq_len, + vl_space, + kl_space, + ql_space, + seqlens_q, + cu_seqlens_q, + cu_seqlens_kv, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + position_embedding_matrix, + k_pos_emb, + q_pos_emb, + *tensor_pointers + ), + fp8_enabled=True, + fp8_recipe=fp8_recipe, + allow_unused_input=True + ) + print("recorded...") + hidden_states.data[:] = copy_hidden + unfinished_sequences.data[:] = copy_unfinished_sequences + inference_params.seq_len.data[:] = copy_seq_len + + + i = 0 + for t in tensor_pointers: + t.data[:] = copy_tensor_pointers[i] + i = i + 1 + + for i in range(max_new_tokens): + next_tokens, logits = graphed_layers( + hidden_states, + unfinished_sequences, + inference_params.seq_len, + vl_space, + kl_space, + ql_space, + seqlens_q, + cu_seqlens_q, + cu_seqlens_kv, + seq_offsets_q, + seq_offsets_k, + 
seq_offsets_v, + position_embedding_matrix, + k_pos_emb, + q_pos_emb, + *tensor_pointers + ) if use_cuda_graphs else graphed_generator( + hidden_states, + unfinished_sequences, + inference_params.seq_len, + vl_space, + kl_space, + ql_space, + seqlens_q, + cu_seqlens_q, + cu_seqlens_kv, + seq_offsets_q, + seq_offsets_k, + seq_offsets_v, + position_embedding_matrix, + k_pos_emb, + q_pos_emb, + *tensor_pointers + ) + output_tokens.append(next_tokens.clone()) + seq_len_offset += 1 result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) return result @@ -236,7 +489,7 @@ def replace_params(hf_state_dict, te_state_dict, config, fp8_init=False): # copy query dst[dst_offset:(dst_offset + config.head_dim), :] = \ q[(head_nr * config.head_dim):(head_nr * config.head_dim + config.head_dim), :] - + if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict: k = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'] for head_nr in range(config.num_attention_heads): @@ -259,6 +512,7 @@ def replace_params(hf_state_dict, te_state_dict, config, fp8_init=False): if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.k_proj.weight']) + if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.v_proj.weight']) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index afe89483b5..d8d5ec7560 100755 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -82,6 +82,9 @@ from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module +Z = torch.zeros((200, 200)).to(torch.bfloat16).cuda() +T = torch.zeros((200, 200)).to(torch.int32).cuda() + META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 META_O = tex.FP8FwdTensors.GEMM2_INPUT @@ -2329,6 +2332,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group): + if fp8: if _NVTE_DEBUG: print('[DotProductAttention]: using FP8 forward') @@ -2790,9 +2794,6 @@ def forward( ) - if self.layer_number == 1: - print(output.shape) - # ...hd -> ...(hd) return output.view(*output.shape[:-2], -1) @@ -2926,6 +2927,8 @@ def __init__( self.channels = channels + + self.hidden_size_per_attention_head = channels // num_attention_heads self.num_gqa_groups = ( num_attention_heads if num_gqa_groups is None else num_gqa_groups @@ -3192,7 +3195,16 @@ def forward( first microbatch (since it is the first gradient being produced) """ - value_layer = value_layer.contiguous() + if inference_params.thd: + inference_params.value_layer.copy_(value_layer.contiguous()) + value_layer = inference_params.value_layer + inference_params.key_layer.copy_(key_layer.contiguous()) + key_layer = inference_params.key_layer + else: + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda @@ -3262,37 
+3274,71 @@ def forward( key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] else: - bs = query_layer.shape[0] - cuda.attention_copy( - inference_key_memory, - inference_params.seq_len, - key_layer, - inference_params.max_sequence_length, - bs, - self.channels - ) - cuda.attention_copy( - inference_value_memory, - inference_params.seq_len, - value_layer, - inference_params.max_sequence_length, - bs, - self.channels) - - seqlens_q = torch.ones([bs], dtype=torch.int32, device="cuda") - cu_seqlens_q = torch.zeros(bs + 1, dtype=torch.int32, device="cuda") - cu_seqlens_kv = torch.zeros(bs + 1, dtype=torch.int32, device="cuda") - cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) - cu_seqlens_kv[1:] = torch.cumsum(inference_params.seq_len + 1, dim=0) + bs = query_layer.shape[0] + import ctypes + current_stream = torch.cuda.current_stream() + + stream_ptr_capsule = ctypes.pythonapi.PyCapsule_New(current_stream.cuda_stream, None, None) if current_stream.cuda_stream != 0 else None + + + if stream_ptr_capsule is not None: + cuda.attention_copy( + inference_key_memory, + inference_params.seq_len, + key_layer, + inference_params.max_sequence_length, + bs, + self.channels, + stream_ptr_capsule + ) + cuda.attention_copy( + inference_value_memory, + inference_params.seq_len, + value_layer, + inference_params.max_sequence_length, + bs, + self.channels, + stream_ptr_capsule) + else: + cuda.attention_copy2( + inference_key_memory, + inference_params.seq_len, + key_layer, + inference_params.max_sequence_length, + bs, + self.channels + ) + cuda.attention_copy2( + inference_value_memory, + inference_params.seq_len, + value_layer, + inference_params.max_sequence_length, + bs, + self.channels) + + + inference_params.seqlens_q.copy_(torch.ones([bs], dtype=torch.int32, device="cuda")) + inference_params.cu_seqlens_q.copy_(torch.zeros(bs + 1, dtype=torch.int32, device="cuda")) + inference_params.cu_seqlens_kv.copy_(torch.zeros(bs + 1, dtype=torch.int32, device="cuda")) + inference_params.cu_seqlens_q[1:].copy_(torch.cumsum(inference_params.seqlens_q, dim=0)) + inference_params.cu_seqlens_kv[1:].copy_(torch.cumsum(inference_params.seq_len + 1, dim=0)) + + seqlens_q = inference_params.seqlens_q + cu_seqlens_q = inference_params.cu_seqlens_q + cu_seqlens_kv = inference_params.cu_seqlens_kv max_seqlen_q = 1 max_seqlen_kv = inference_params.max_sequence_length - seq_offsets_q = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_q - seq_offsets_k = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv - seq_offsets_k[1:] = seq_offsets_k[1:] + inference_params.begin_offsets * self.channels - seq_offsets_v = seq_offsets_k.clone() + inference_params.seq_offsets_q.copy_(torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_q) + inference_params.seq_offsets_k.copy_(torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv) + inference_params.seq_offsets_k[1:].copy_(inference_params.seq_offsets_k[1:] ) + inference_params.seq_offsets_v.copy_(inference_params.seq_offsets_k) + + seq_offsets_q = inference_params.seq_offsets_q + seq_offsets_k = inference_params.seq_offsets_k + seq_offsets_v = inference_params.seq_offsets_v query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16) @@ -3552,6 +3598,7 @@ def forward( if len(query_layer.shape) == 4: 
use_flash_attention=True + if use_flash_attention: if _NVTE_DEBUG: @@ -4318,32 +4365,53 @@ def forward( # ====================================================== # Apply relative positional encoding (rotary embedding) # ====================================================== - + if rotary_pos_emb is not None: - assert (not isinstance(query_layer, Float8Tensor) - and not isinstance(key_layer, Float8Tensor) - ), "RoPE is not supported for Float8Tensors!" - # duplicate the pos_emb for self attention - if not isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = ((rotary_pos_emb,) * 2) - - q_pos_emb, k_pos_emb = rotary_pos_emb - - # adjust key and value for inference - if inference_params is not None: - if self.qkv_format == "sbhd": - sequence_length = key_layer.size(0) - elif self.qkv_format == "bshd": - sequence_length = key_layer.size(1) + if inference_params.thd: + import ctypes + current_stream = torch.cuda.current_stream() + stream_ptr_capsule = ctypes.pythonapi.PyCapsule_New(current_stream.cuda_stream, None, None) if current_stream.cuda_stream != 0 else None - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + sequence_length + d = query_layer.shape[-1] + b = query_layer.shape[0] - q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] - k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] + if stream_ptr_capsule is not None: + cuda.get_values(inference_params.position_embedding_matrix, inference_params.seq_len + 1, inference_params.q_pos_emb, d, b, stream_ptr_capsule) + cuda.get_values(inference_params.position_embedding_matrix, inference_params.seq_len + 1, inference_params.k_pos_emb, d, b, stream_ptr_capsule) + else: + cuda.get_values2(inference_params.position_embedding_matrix, inference_params.seq_len + 1, inference_params.q_pos_emb, d, b) + cuda.get_values2(inference_params.position_embedding_matrix, inference_params.seq_len + 1, inference_params.k_pos_emb, d, b) + inference_params.query_layer.copy_(apply_rotary_pos_emb(query_layer, inference_params.q_pos_emb, self.qkv_format, fused=True)) + inference_params.key_layer.copy_(apply_rotary_pos_emb(key_layer, inference_params.k_pos_emb, self.qkv_format, fused=True)) + else: + assert (not isinstance(query_layer, Float8Tensor) + and not isinstance(key_layer, Float8Tensor) + ), "RoPE is not supported for Float8Tensors!" + # duplicate the pos_emb for self attention + if not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = ((rotary_pos_emb,) * 2) + + q_pos_emb, k_pos_emb = rotary_pos_emb + + # adjust key and value for inference + if inference_params is not None: + if self.qkv_format == "sbhd": + sequence_length = key_layer.size(0) + elif self.qkv_format == "bshd": + sequence_length = key_layer.size(1) + + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + sequence_length + + q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] + k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] 
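# Illustrative sketch, not taken from the patch: in the thd branch above, the two cuda.get_values
# calls gather, for every sequence in the batch, the rotary-embedding row that corresponds to that
# sequence's current position. Assuming `freqs` has the [max_seq_len, 1, 1, d] layout produced by
# RotaryPositionEmbedding and `seq_len` holds the cached length of each sequence, an eager-PyTorch
# equivalent of that gather would be (the helper name is hypothetical):
def _gather_rope_rows(freqs: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
    # Row index = position of the token being generated; result has shape [b, 1, 1, d].
    return freqs.index_select(0, (seq_len + 1).long())
# The custom kernel performs the same lookup directly into the preallocated q_pos_emb / k_pos_emb
# buffers, presumably so the decode step keeps static shapes and stays capturable by CUDA graphs.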
+ + query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True) + key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True) + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + - query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True) - key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True) # =========================== @@ -4351,8 +4419,8 @@ def forward( # =========================== context_layer = self.core_attention( - query_layer, - key_layer, + inference_params.query_layer if inference_params.thd else query_layer, + inference_params.key_layer if inference_params.thd else key_layer, value_layer, qkv_format=self.qkv_format, cu_seqlens_q=None, From 41045ab33ef2b093ef934a0b8602780ce15ad94b Mon Sep 17 00:00:00 2001 From: root Date: Wed, 1 May 2024 17:54:21 +0000 Subject: [PATCH 080/244] fp8 cuda_graphs generation Signed-off-by: root Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/generate_baseline.py | 8 +- .../examples/te_gemma/generate_cuda_graphs.py | 6 +- docs/examples/te_gemma/generate_fp8.py | 6 +- docs/examples/te_gemma/te_gemma.py | 338 +++++------------- docs/examples/te_gemma/utils.py | 3 +- transformer_engine/pytorch/csrc/extensions.h | 3 + .../pytorch/csrc/extensions/attention.cu | 57 +++ 7 files changed, 172 insertions(+), 249 deletions(-) diff --git a/docs/examples/te_gemma/generate_baseline.py b/docs/examples/te_gemma/generate_baseline.py index 872ce92ac8..cb6fa86bf0 100755 --- a/docs/examples/te_gemma/generate_baseline.py +++ b/docs/examples/te_gemma/generate_baseline.py @@ -20,7 +20,7 @@ model = model.to(torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) -inputs = tokenizer(["I love when"] * 32, return_tensors="pt", padding=True) +inputs = tokenizer(["Some random initial str ", "Another string ... "] * 32, return_tensors="pt", padding=True) inputs['input_ids'] = inputs['input_ids'].cuda() inputs['attention_mask'] = inputs['attention_mask'].cuda() @@ -29,9 +29,11 @@ # Początek pomiaru czasu start_time = time.time() +import pdb +pdb.set_trace() outputs = model.generate( **inputs, - max_new_tokens=40 + max_new_tokens=1000 ) # Koniec pomiaru czasu @@ -42,7 +44,7 @@ -print(outputs) +print(duration) # Decode the output tensor to text generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) diff --git a/docs/examples/te_gemma/generate_cuda_graphs.py b/docs/examples/te_gemma/generate_cuda_graphs.py index ae5e413afc..694dabfd91 100644 --- a/docs/examples/te_gemma/generate_cuda_graphs.py +++ b/docs/examples/te_gemma/generate_cuda_graphs.py @@ -19,7 +19,7 @@ hyperparams.model_name = "../../../../gemma-weights" hyperparams.fuse_qkv_params = True -model = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda() +model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format="thd").cuda() print("Loading model") model_state_dict = torch.load('model_fp8_state_dict.pth') @@ -27,7 +27,7 @@ print("Model loaded") tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) -inputs = tokenizer(["I love when"] * 32, return_tensors="pt", padding=True) +inputs = tokenizer(["Some random initial str ", "Another string ... 
"] * 32, return_tensors="pt", padding=True) inputs['input_ids'] = inputs['input_ids'].cuda() inputs['attention_mask'] = inputs['attention_mask'].cuda() @@ -47,7 +47,7 @@ model.eval() outputs = model.generate( **inputs, - max_new_tokens=40, + max_new_tokens=1000, use_cuda_graphs=True ) diff --git a/docs/examples/te_gemma/generate_fp8.py b/docs/examples/te_gemma/generate_fp8.py index bde5be1def..3ff07adf18 100755 --- a/docs/examples/te_gemma/generate_fp8.py +++ b/docs/examples/te_gemma/generate_fp8.py @@ -19,7 +19,7 @@ hyperparams.model_name = "../../../../gemma-weights" hyperparams.fuse_qkv_params = True -model = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda() +model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format="thd").cuda() print("Loading model") model_state_dict = torch.load('model_fp8_state_dict.pth') @@ -27,7 +27,7 @@ print("Model loaded") tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) -inputs = tokenizer(["I love when", "I "] * 32, return_tensors="pt", padding=True) +inputs = tokenizer(["Some random initial str ", "Another string ... "] * 32, return_tensors="pt", padding=True) inputs['input_ids'] = inputs['input_ids'].cuda() inputs['attention_mask'] = inputs['attention_mask'].cuda() @@ -47,7 +47,7 @@ model.eval() outputs = model.generate( **inputs, - max_new_tokens=40, + max_new_tokens=1000, use_cuda_graphs=False ) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 376eb4bbd5..3d96a97934 100755 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -13,11 +13,6 @@ from transformers.generation.utils import * import torch -from torch import nn -from torch.utils.cpp_extension import load - - - import transformer_engine as te from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding from transformer_engine.pytorch.fp8 import fp8_model_init @@ -29,13 +24,6 @@ from transformers.utils import WEIGHTS_INDEX_NAME from transformers.utils.hub import get_checkpoint_shard_files -cuda = load( - name='attention_copy', - sources=['attention_copy.cu'], - verbose=True -) - - @contextmanager def replace_decoder(te_decoder_cls): """ @@ -71,7 +59,7 @@ def __init__(self, config, layer_idx, *args, **kwargs): fuse_qkv_params=config.fuse_qkv_params, normalization="RMSNorm", activation="geglu", - attn_input_format="bshd", + attn_input_format=config.qkv_format, num_gqa_groups=config.num_key_value_heads, attention_hidden_size=4096, layer_number=(layer_idx+1) @@ -91,73 +79,36 @@ def forward(self, forward pass of the `TransformerLayer`. Also, make sure the output format matches the output of the HF's `GemmaDecoderLayer`. 
""" - return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb, inference_params=inference_params, self_attn_mask_type=self_attn_mask_type),) + return (super().forward( + hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=self.te_rope_emb, + inference_params=inference_params, + self_attn_mask_type=self_attn_mask_type + ),) class TeGraphed(torch.nn.Module): - def __init__(self, model, lm_head, inference_params, normalizer, generation_config, thd=True): + def __init__(self, model, lm_head, inference_params, dtype, generation_config): super().__init__() self.model = model self.inference_params = inference_params - self.inference_params.thd = thd - self.thd=thd - self.normalizer = normalizer + self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) self.generation_config = generation_config self.lm_head = lm_head - - self.attn_mask = torch.ones([inference_params.max_batch_size, inference_params.max_sequence_length]).to(dtype=torch.bool) - - def forward(self, - hidden_states, - unfinished_sequences, - seq_len, - vl_space, - kl_space, - ql_space, - seqlens_q, - cu_seqlens_q, - cu_seqlens_kv, - seq_offsets_q, - seq_offsets_k, - seq_offsets_v, - position_embedding_matrix, - k_pos_emb, - q_pos_emb, - *args - ): + def forward(self, hidden_states, unfinished_sequences): hidden_states.data[:] = hidden_states.data[:] * self.normalizer - inference_params = InferenceParams(self.inference_params.max_batch_size, self.inference_params.max_sequence_length) - inference_params.thd = self.thd - inference_params.seq_len = seq_len - inference_params.value_layer = vl_space - inference_params.key_layer = kl_space - inference_params.query_layer = ql_space - inference_params.seqlens_q = seqlens_q - inference_params.cu_seqlens_q = cu_seqlens_q - inference_params.cu_seqlens_kv = cu_seqlens_kv - inference_params.seq_offsets_q = seq_offsets_q - inference_params.seq_offsets_k = seq_offsets_k - inference_params.seq_offsets_v = seq_offsets_v - inference_params.position_embedding_matrix = position_embedding_matrix - inference_params.k_pos_emb = k_pos_emb - inference_params.q_pos_emb = q_pos_emb - - assert len(args) == 28 * 2 - - - for i in range(0, len(args), 2): - inference_params.key_value_memory_dict[i // 2 + 1] = (args[i], args[i + 1]) for decoder_layer in self.model.layers: hidden_states.copy_(decoder_layer( hidden_states, - inference_params=inference_params, + inference_params=self.inference_params, self_attn_mask_type='padding', attention_mask=None )[0]) - seq_len.copy_(seq_len + 1) + self.inference_params.seq_len.copy_(self.inference_params.seq_len + 1) hidden_states.copy_(self.model.norm(hidden_states)) logits = self.lm_head(hidden_states) @@ -167,11 +118,10 @@ def forward(self, # Sequences, which are finished should contain padding - taken from huggingface transformers. 
next_tokens = next_tokens * unfinished_sequences + self.generation_config.pad_token_id * (1 - unfinished_sequences) - unfinished_sequences.copy_(unfinished_sequences & ~(next_tokens == self.generation_config.eos_token_id)) - hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1)) - return next_tokens, logits + + return next_tokens class TEGemmaForCausalLM: """ @@ -193,12 +143,12 @@ def __new__(cls, config: GemmaConfig): return gemma_for_causal_lm @classmethod - def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, fp8_init=False, **kwargs): + def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, fp8_init=False, qkv_format="bshd", **kwargs): """ Custom method adapted from `from_pretrained` method in HuggingFace Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 """ - + config.qkv_format = qkv_format with fp8_model_init(fp8_init): vanilla_model = cls(config) is_local = os.path.isdir(pretrained_model_name_or_path) @@ -236,35 +186,35 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, fp8 return vanilla_model - @torch.no_grad() - def generate( - self, - input_ids: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - max_new_tokens = 0, - use_cuda_graphs = False, - **kwargs, - ): - - batch_size, seq_len = input_ids.shape - generation_config, _ = self._prepare_generation_config(generation_config, **kwargs) - unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - - # inference_params object is a cache, where keys and values of previous tokens are stored - inference_params = te.pytorch.InferenceParams( - max_batch_size=batch_size, - max_sequence_length=max(128, input_ids.shape[1] + max_new_tokens) - ) - - # mask has shape [batch_size, num_heads, 1, max_seq_len] and contains False - # when coressponding token is padding and True otherwise. - pad_attention_mask = input_ids.ne(generation_config.pad_token_id).unsqueeze(1).unsqueeze(2) - - ############################################################################################# - # Encode part # - ############################################################################################# + @staticmethod + def _padding_to_beginning(inputs, lengths): + """ + Gets the tensor with sequence padded from the beginning and + return tensor padded from its end. + Parameters + ---------- + inputs : Tensor, tensor with shape [b, s] containing token numbers. + It's padded from the beggining. + lengths: Tensor, tensor with shape [s] with lengths of the sequences. + """ + max_seq_len = torch.max(lengths) + batch_size, max_seq_len = inputs.shape + new_input_ids = inputs.clone() + for i in range(batch_size): + new_input_ids[i,:lengths[i]] = inputs[i, (max_seq_len-lengths[i]):max_seq_len] + new_input_ids[i,lengths[i]:] = inputs[i, 0:(max_seq_len-lengths[i])] + inputs.copy_(new_input_ids) + + def _generate_context_phase( + self, + input_ids, + inference_params, + pad_token_id, + eos_token_id, + unfinished_sequences + ): hidden_states = self.model.embed_tokens(input_ids) normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) @@ -274,188 +224,98 @@ def generate( for decoder_layer in self.model.layers: hidden_states = decoder_layer( hidden_states, - # In the case of arbiutrary mask, the meaning of True and False is switched, so negation is needed. 
- attention_mask=pad_attention_mask, + attention_mask=None, self_attn_mask_type="padding_causal", inference_params=inference_params )[0] + hidden_states = self.model.norm(hidden_states) logits = self.lm_head(hidden_states) logits = logits.float() - logits = logits[:, -1, :] - next_tokens = torch.argmax(logits, dim=-1) + logits = logits[torch.arange(logits.size(0)), inference_params.seq_len - 1, :] + next_tokens = torch.argmax(logits, dim=1) + # Sequences, which are finished should contain padding - taken from huggingface transformers. - next_tokens = next_tokens * unfinished_sequences + generation_config.pad_token_id * (1 - unfinished_sequences) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) output_tokens.append(next_tokens) - unfinished_sequences = unfinished_sequences & ~(next_tokens == generation_config.eos_token_id) - + unfinished_sequences = unfinished_sequences & ~(next_tokens == eos_token_id) hidden_states = self.model.embed_tokens(next_tokens).unsqueeze(1) - lengths = torch.sum(pad_attention_mask, dim=-1).squeeze() + for k, v in inference_params.key_value_memory_dict.items(): + key_layer = v[0].permute((1, 0, 2, 3)).contiguous().cuda() + value_layer = v[1].permute((1, 0, 2, 3)).contiguous().cuda() + inference_params.key_value_memory_dict[k] = (key_layer, value_layer) - def process(x): - """ - Args: - x: Tensor with shape [s, b, h, d], where s is sequence length, b is batch size, h is number of heads, and d is hidden dimension. - l: List of integers representing the actual lengths of each sequence in the batch before padding. - - Returns: - torch.Tensor: Tensor with switched contents such that padded zeros are moved to the end of the sequence. - """ - s1, b, h, d = x.shape - s = torch.max(lengths) - new_x = torch.zeros_like(x) - - for i in range(b): - seq_length = lengths[i] - - # Check if the sequence length is not the full length of the sequence dimension - if seq_length < s: - # Place the original data to the end part of the new tensor - new_x[:seq_length, i, :, :] = x[s - seq_length:s, i, :, :] - # Place the padding at the beginning of the new tensor - new_x[seq_length:, i, :, :] = 0 - else: - # If seq_length is the full length, just copy the entire sequence as is - new_x[:, i, :, :] = x[:, i, :, :] - - return new_x.permute((1, 0, 2, 3)).contiguous().cuda() - - inference_params.seq_len = lengths.to(torch.int32) - seq_len_offset = torch.max(lengths).item() - - - seqlens_q = torch.zeros((batch_size), dtype=torch.int32).cuda() - cu_seqlens_q = torch.zeros((batch_size + 1), dtype=torch.int32).cuda() - cu_seqlens_kv = torch.zeros((batch_size + 1), dtype=torch.int32).cuda() - seq_offsets_q = torch.zeros((batch_size + 1), dtype=torch.int32).cuda() - seq_offsets_k = torch.zeros((batch_size + 1), dtype=torch.int32).cuda() - seq_offsets_v = torch.zeros((batch_size + 1), dtype=torch.int32).cuda() - + return hidden_states, output_tokens + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + max_new_tokens = 0, + use_cuda_graphs = False, + **kwargs, + ): + batch_size, _ = input_ids.shape + generation_config, _ = self._prepare_generation_config(generation_config, **kwargs) + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + # inference_params object is a cache, where keys and values of previous tokens are stored + inference_params = te.pytorch.InferenceParams( + max_batch_size=batch_size, + 
max_sequence_length=input_ids.shape[1] + max_new_tokens + ) + # lengths is a tensor of shape [s] representing lengths of sequences. + lengths = torch.sum(input_ids.ne(generation_config.pad_token_id), dim=-1).squeeze() + inference_params.seq_len = lengths.to(torch.int32).clone().cuda() - for k, v in inference_params.key_value_memory_dict.items(): - key_layer = process(v[0]) - value_layer = process(v[1]) - inference_params.key_value_memory_dict[k] = (key_layer, value_layer) - - ############################################################################################# - # Generate part # - ############################################################################################# - print("generate part") + TEGemmaForCausalLM._padding_to_beginning(input_ids, lengths) + + hidden_states, output_tokens = TEGemmaForCausalLM._generate_context_phase( + self, + input_ids, + inference_params, + generation_config.pad_token_id, + generation_config.eos_token_id, + unfinished_sequences + ) graphed_generator = TeGraphed( lm_head=self.lm_head, model=self.model, inference_params=inference_params, - normalizer=normalizer, generation_config=generation_config, - thd=True + dtype=hidden_states.dtype, ) - tensor_pointers = [(kc, vc) for kc, vc in inference_params.key_value_memory_dict.values()] - tensor_pointers = [element for tuple_ in tensor_pointers for element in tuple_] - - copy_hidden = hidden_states.clone() - copy_unfinished_sequences = unfinished_sequences.clone() - copy_tensor_pointers = [t.clone() for t in tensor_pointers] - copy_seq_len = inference_params.seq_len.clone() - - vl_space = torch.zeros((batch_size, 1, 16, 256)).to(torch.bfloat16).cuda() - kl_space = torch.zeros((batch_size, 1, 16, 256)).to(torch.bfloat16).cuda() - ql_space = torch.zeros((batch_size, 1, 16, 256)).to(torch.bfloat16).cuda() - q_pos_emb = torch.zeros((batch_size, 1, 1, 256)).to(torch.float32).cuda() - k_pos_emb = torch.zeros((batch_size, 1, 1, 256)).to(torch.float32).cuda() + args = (hidden_states, unfinished_sequences) - - te_rope = RotaryPositionEmbedding(256) - position_embedding_matrix = te_rope(8192).to(torch.float32).cuda() - - - graphed_layers = None + saved_args = [arg.clone() for arg in args] # Warmup iterations of graph will change the arguments, we want to revert that. 
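 # Roughly the capture/replay pattern that te.pytorch.make_graphed_callables automates in the
 # branch below, shown here to motivate the static buffers above. Everything in this sketch is
 # illustrative: capture_step, step_fn and the buffer names are assumptions for the example,
 # not TE or Hugging Face API. CUDA Graphs replay fixed device addresses, so inputs and outputs
 # must live in static tensors that are refreshed in place between replays - hence the clones
 # of `args` taken above, since the warmup iterations mutate them.

 import torch

 def capture_step(step_fn, static_inputs, warmup_iters=3):
     # Warm up on a side stream so one-time allocations and lazy init stay out of the graph.
     side_stream = torch.cuda.Stream()
     side_stream.wait_stream(torch.cuda.current_stream())
     with torch.cuda.stream(side_stream):
         for _ in range(warmup_iters):
             step_fn(*static_inputs)
     torch.cuda.current_stream().wait_stream(side_stream)

     graph = torch.cuda.CUDAGraph()
     with torch.cuda.graph(graph):
         static_output = step_fn(*static_inputs)   # record exactly one iteration
     return graph, static_output

 # Usage sketch: refresh the static inputs in place, then replay the recorded kernels.
 #   graph, static_out = capture_step(decode_step, (static_hidden, static_flags))
 #   static_hidden.copy_(next_hidden); graph.replay(); token = static_out.clone()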
if use_cuda_graphs: fp8_format = Format.HYBRID fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") - - print("recording...") graphed_layers = te.pytorch.make_graphed_callables( graphed_generator, - ( - hidden_states, - unfinished_sequences, - inference_params.seq_len, - vl_space, - kl_space, - ql_space, - seqlens_q, - cu_seqlens_q, - cu_seqlens_kv, - seq_offsets_q, - seq_offsets_k, - seq_offsets_v, - position_embedding_matrix, - k_pos_emb, - q_pos_emb, - *tensor_pointers - ), + args, fp8_enabled=True, fp8_recipe=fp8_recipe, - allow_unused_input=True - ) - print("recorded...") - hidden_states.data[:] = copy_hidden - unfinished_sequences.data[:] = copy_unfinished_sequences - inference_params.seq_len.data[:] = copy_seq_len - - - i = 0 - for t in tensor_pointers: - t.data[:] = copy_tensor_pointers[i] - i = i + 1 + allow_unused_input=True, + num_warmup_iters=10 + ) + + for i in range(len(saved_args)): + args[i].copy_(saved_args[i]) + inference_params.seq_len.copy_(lengths.to(torch.int32)) for i in range(max_new_tokens): - next_tokens, logits = graphed_layers( - hidden_states, - unfinished_sequences, - inference_params.seq_len, - vl_space, - kl_space, - ql_space, - seqlens_q, - cu_seqlens_q, - cu_seqlens_kv, - seq_offsets_q, - seq_offsets_k, - seq_offsets_v, - position_embedding_matrix, - k_pos_emb, - q_pos_emb, - *tensor_pointers - ) if use_cuda_graphs else graphed_generator( - hidden_states, - unfinished_sequences, - inference_params.seq_len, - vl_space, - kl_space, - ql_space, - seqlens_q, - cu_seqlens_q, - cu_seqlens_kv, - seq_offsets_q, - seq_offsets_k, - seq_offsets_v, - position_embedding_matrix, - k_pos_emb, - q_pos_emb, - *tensor_pointers - ) + next_tokens = graphed_layers(*args) if use_cuda_graphs else graphed_generator(*args) output_tokens.append(next_tokens.clone()) - seq_len_offset += 1 result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) return result diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index 1746c3165d..6ccce22f9a 100755 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -87,7 +87,7 @@ def init_baseline_model(hyperparams): return model -def init_te_gemma_model(hyperparams, fp8_model_init=False): +def init_te_gemma_model(hyperparams, fp8_model_init=False, qkv_format="thd"): # Init the model from te_gemma import TEGemmaForCausalLM config = AutoConfig.from_pretrained(hyperparams.model_name) @@ -98,6 +98,7 @@ def init_te_gemma_model(hyperparams, fp8_model_init=False): config=config, torch_dtype=torch.bfloat16, fp8_init=fp8_model_init, + qkv_format=qkv_format ) # Needed for the cases when using TEGemmaForCausalLM diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2f552fe28f..f49a68cd50 100755 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -182,6 +182,9 @@ std::vector fused_attn_bwd( at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); +void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int max_seq_len, int b, int s); +void get_values(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int d, int b); + /*************************************************************************************************** * GEMM **************************************************************************************************/ diff --git 
a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 037ae72b2b..5637166753 100755 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1608,3 +1608,60 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { return qkv; } + + +extern "C" +__global__ void attn_copy(__nv_bfloat16* A, int* seq_len, __nv_bfloat16* B, int max_seq_len, int b, int s) { + for(int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int per_block = s / blockDim.x; + int remainder = s % blockDim.x; + int copy_block_offset_begin = per_block * threadIdx.x + min(threadIdx.x, remainder); + + int offset = seq_len[batch_idx]; + + __nv_bfloat16* begin_A_copy = A + max_seq_len * s * batch_idx + s * offset; + __nv_bfloat16* begin_B_copy = B + s * batch_idx; + + int limit = copy_block_offset_begin + per_block + (threadIdx.x < remainder ? 1 : 0); + + for(int i = copy_block_offset_begin; i < limit; i++) { + *(begin_A_copy + i) = *(begin_B_copy + i); + } + } +} + +extern "C" +__global__ void gv(float* src, int* seq_len, float* dst, int d, int b) { + // src [s, 1, 1, d] + // dst [b] + for(int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int per_block = d / blockDim.x; + int remainder = d % blockDim.x; + int copy_block_offset_begin = per_block * threadIdx.x + min(threadIdx.x, remainder); + + int offset = seq_len[batch_idx]; + + float* begin_src_copy = src + d * offset; + float* begin_dst_copy = dst + d * batch_idx; + + int limit = copy_block_offset_begin + per_block + (threadIdx.x < remainder ? 1 : 0); + + for(int i = copy_block_offset_begin; i < limit; i++) { + *(begin_dst_copy + i) = *(begin_src_copy + i); + } + } +} + + + +void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int max_seq_len, int b, int s) { + attn_copy<<<16, 32, 0, at::cuda::getCurrentCUDAStream()>>>(reinterpret_cast<__nv_bfloat16*>(A.data_ptr()), + seq_len.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(B.data_ptr()), max_seq_len, b, s); +} + +void get_values(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int d, int b) { + gv<<<16, 32, 0, at::cuda::getCurrentCUDAStream()>>>(A.data_ptr(), + seq_len.data_ptr(), + B.data_ptr(), d, b); +} From c69664190b47673b5ab9ad415725b5657ab72c0f Mon Sep 17 00:00:00 2001 From: root Date: Wed, 1 May 2024 18:16:47 +0000 Subject: [PATCH 081/244] attention.py Signed-off-by: root Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 228 ++++++++++++------------ 1 file changed, 118 insertions(+), 110 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index d8d5ec7560..37efd8eb30 100755 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -16,7 +16,6 @@ import torch import torch.nn.functional as F -from torch.utils.cpp_extension import load import transformer_engine_extensions as tex from transformer_engine.pytorch.cpp_extensions import ( @@ -106,12 +105,6 @@ __all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"] -cuda = load( - name='attention_copy', - sources=['attention_copy.cu'], - verbose=True -) - class InferenceParams: # pylint: disable=too-few-public-methods """ @@ -132,7 +125,6 @@ def __init__(self, max_batch_size, max_sequence_length): self.sequence_len_offset = 0 self.batch_size_offset = 0 self.key_value_memory_dict = {} - self.thd = False 
self.seq_len=torch.tensor((1000)) def swap_key_value_dict(self, batch_indices): @@ -2025,6 +2017,7 @@ def forward( assert ( max_seqlen_q == max_seqlen_kv ), "Maximum sequence length for Q and KV should be the same." + if cu_seqlens_q is None: assert (attention_mask is not None ), "Please provide attention_mask for padding!" @@ -3000,6 +2993,8 @@ def __init__( self.unfused_attention = UnfusedDotProductAttention( norm_factor, **attn_kwargs, layer_number=layer_number) + self.offset_module = OffsetsModule() + def _checkpointed_attention_forward( self, attention_func: Callable, @@ -3195,14 +3190,10 @@ def forward( first microbatch (since it is the first gradient being produced) """ - if inference_params.thd: - inference_params.value_layer.copy_(value_layer.contiguous()) - value_layer = inference_params.value_layer - inference_params.key_layer.copy_(key_layer.contiguous()) - key_layer = inference_params.key_layer - else: - key_layer = key_layer.contiguous() - value_layer = value_layer.contiguous() + batch_size = key_layer.shape[0] + q_size = query_layer.shape[1] + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() @@ -3244,7 +3235,6 @@ def forward( - if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" @@ -3255,7 +3245,8 @@ def forward( (inference_key_memory, inference_value_memory, ) = inference_params.key_value_memory_dict[self.layer_number] - if not inference_params.thd: + + if not qkv_format == "thd": batch_start = inference_params.batch_size_offset batch_end = batch_start + key_layer.size(1) assert batch_end <= inference_key_memory.size(1) @@ -3264,8 +3255,6 @@ def forward( sequence_end = sequence_start + key_layer.size(0) assert sequence_end <= inference_key_memory.size(0) - - # Copy keys and values into KV-cache inference_key_memory[ sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer @@ -3274,86 +3263,87 @@ def forward( key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] 
else: - bs = query_layer.shape[0] - import ctypes - current_stream = torch.cuda.current_stream() - - stream_ptr_capsule = ctypes.pythonapi.PyCapsule_New(current_stream.cuda_stream, None, None) if current_stream.cuda_stream != 0 else None + if query_layer.shape[1] == 1: + bs = query_layer.shape[0] - if stream_ptr_capsule is not None: - cuda.attention_copy( - inference_key_memory, - inference_params.seq_len, - key_layer, - inference_params.max_sequence_length, - bs, - self.channels, - stream_ptr_capsule - ) - cuda.attention_copy( - inference_value_memory, - inference_params.seq_len, - value_layer, - inference_params.max_sequence_length, - bs, - self.channels, - stream_ptr_capsule) - else: - cuda.attention_copy2( + tex.attention_copy( inference_key_memory, inference_params.seq_len, key_layer, inference_params.max_sequence_length, bs, - self.channels - ) - cuda.attention_copy2( + self.channels) + tex.attention_copy( inference_value_memory, inference_params.seq_len, value_layer, inference_params.max_sequence_length, bs, self.channels) - - - inference_params.seqlens_q.copy_(torch.ones([bs], dtype=torch.int32, device="cuda")) - inference_params.cu_seqlens_q.copy_(torch.zeros(bs + 1, dtype=torch.int32, device="cuda")) - inference_params.cu_seqlens_kv.copy_(torch.zeros(bs + 1, dtype=torch.int32, device="cuda")) - inference_params.cu_seqlens_q[1:].copy_(torch.cumsum(inference_params.seqlens_q, dim=0)) - inference_params.cu_seqlens_kv[1:].copy_(torch.cumsum(inference_params.seq_len + 1, dim=0)) - - seqlens_q = inference_params.seqlens_q - cu_seqlens_q = inference_params.cu_seqlens_q - cu_seqlens_kv = inference_params.cu_seqlens_kv - - max_seqlen_q = 1 - max_seqlen_kv = inference_params.max_sequence_length - + + max_seqlen_q = 1 + max_seqlen_kv = inference_params.max_sequence_length + cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v = self.offset_module(bs, inference_params, max_seqlen_q, max_seqlen_kv, self.channels) - inference_params.seq_offsets_q.copy_(torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_q) - inference_params.seq_offsets_k.copy_(torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv) - inference_params.seq_offsets_k[1:].copy_(inference_params.seq_offsets_k[1:] ) - inference_params.seq_offsets_v.copy_(inference_params.seq_offsets_k) - seq_offsets_q = inference_params.seq_offsets_q - seq_offsets_k = inference_params.seq_offsets_k - seq_offsets_v = inference_params.seq_offsets_v + query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16) + key_layer = inference_key_memory.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]).to(torch.bfloat16) + value_layer = inference_value_memory.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]).to(torch.bfloat16) + else: + bs = query_layer.shape[0] + key_layer = key_layer.transpose(0, 1) + value_layer = value_layer.transpose(0, 1) - query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16) - key_layer = inference_key_memory.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]).to(torch.bfloat16) - value_layer = inference_value_memory.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]).to(torch.bfloat16) + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + + sequence_start = 
inference_params.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + + # Copy keys and values into KV-cache + inference_key_memory[ + sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer + inference_value_memory[ + sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer + key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] + value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] + + seqlens = (torch.zeros(bs + 1, dtype=torch.int32, device="cuda")) + seqlens[1:] = inference_params.seq_len + cu_seqlens_q = (torch.zeros(bs + 1, dtype=torch.int32, device="cuda")) + cu_seqlens_kv = (torch.zeros(bs + 1, dtype=torch.int32, device="cuda")) + cu_seqlens_q[1:] = (torch.cumsum(inference_params.seq_len, dim=0)) + cu_seqlens_kv[1:] = (torch.cumsum(inference_params.seq_len, dim=0)) + + max_seqlen_q = query_layer.shape[1] + max_seqlen_kv = key_layer.shape[0] + + seq_offsets_q = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_q + seq_offsets_k = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv + seq_offsets_v = seq_offsets_k + + + key_layer = key_layer.transpose(0, 1) + value_layer = value_layer.transpose(0, 1) + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + + query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16).contiguous() + key_layer = key_layer.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]).to(torch.bfloat16).contiguous() + value_layer = value_layer.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]).to(torch.bfloat16).contiguous() - qkv_format="thd" if qkv_format == "bshd": key_layer = key_layer.transpose(0, 1) value_layer = value_layer.transpose(0, 1) - key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() - + assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition and value_layer.shape[-2] == self.num_gqa_groups_per_partition ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" 
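 # A self-contained sketch of the THD metadata built in the hunk above: cu_seqlens are prefix
 # sums of the per-sequence lengths, and seq_offsets locate each sequence's slot inside the
 # padded KV cache. The helper name build_thd_metadata and its arguments are illustrative for
 # this example only, not TE API.

 import torch

 def build_thd_metadata(seq_lens, max_seqlen, num_heads, head_dim):
     # seq_lens: int32 tensor of shape [batch] with the current length of every sequence.
     batch = seq_lens.numel()
     cu_seqlens = torch.zeros(batch + 1, dtype=torch.int32, device=seq_lens.device)
     cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)     # cumulative end of each packed sequence
     # Element offset of each sequence's slot in a [batch, max_seqlen, heads, dim] cache buffer.
     seq_offsets = (torch.arange(0, batch + 1, dtype=torch.int32, device=seq_lens.device)
                    * num_heads * head_dim * max_seqlen)
     return cu_seqlens, seq_offsets

 lengths = torch.tensor([3, 1, 5], dtype=torch.int32)
 cu, offsets = build_thd_metadata(lengths, max_seqlen=8, num_heads=16, head_dim=256)
 # cu      -> tensor([0, 3, 4, 9], dtype=torch.int32)
 # offsets -> tensor([0, 32768, 65536, 98304], dtype=torch.int32)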
@@ -3549,9 +3539,6 @@ def forward( max_seqlen_kv, query_layer.shape[-1], # head_dim ) - if inference_params is not None: - if inference_params.thd: - fused_attention_backend = FusedAttnBackend["F16_arbitrary_seqlen"] # DPA does not support FP8; for FP8, use cpp_extensions modules directly is_backend_avail = (fused_attention_backend in [FusedAttnBackend["F16_max512_seqlen"], @@ -3592,14 +3579,6 @@ def forward( if self.device_compute_capability == (9, 0): use_flash_attention = False - if inference_params is not None: - if inference_params.thd: - use_flash_attention = False - - if len(query_layer.shape) == 4: - use_flash_attention=True - - if use_flash_attention: if _NVTE_DEBUG: print("[DotProductAttention]: using flash-attn",_flash_attn_version) @@ -3669,10 +3648,10 @@ def forward( cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, is_first_microbatch=is_first_microbatch) - if inference_params.thd: + if qkv_format == "thd": out = out.unsqueeze(1) - - + if q_size > 1: + out = out.view((batch_size, -1, out.shape[2])).contiguous() return out @@ -4061,11 +4040,14 @@ def __init__( **common_gemm_kwargs, ) + self._allocator = BufferAllocator() + + def _allocate_memory( self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype ) -> torch.Tensor: - return torch.zeros( + return torch.empty( inference_max_sequence_len, batch_size, self.num_gqa_groups_per_partition, @@ -4074,6 +4056,9 @@ def _allocate_memory( device=torch.cuda.current_device(), ) + def alloc(self, size, dtype, device): + return self._allocator(size, dtype, device) + def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ Set the tensor parallel group for the given @@ -4365,24 +4350,22 @@ def forward( # ====================================================== # Apply relative positional encoding (rotary embedding) # ====================================================== - if rotary_pos_emb is not None: - if inference_params.thd: - import ctypes - current_stream = torch.cuda.current_stream() - stream_ptr_capsule = ctypes.pythonapi.PyCapsule_New(current_stream.cuda_stream, None, None) if current_stream.cuda_stream != 0 else None + if self.qkv_format == "thd" and query_layer.shape[1] == 1: + if not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = ((rotary_pos_emb,) * 2) d = query_layer.shape[-1] b = query_layer.shape[0] - if stream_ptr_capsule is not None: - cuda.get_values(inference_params.position_embedding_matrix, inference_params.seq_len + 1, inference_params.q_pos_emb, d, b, stream_ptr_capsule) - cuda.get_values(inference_params.position_embedding_matrix, inference_params.seq_len + 1, inference_params.k_pos_emb, d, b, stream_ptr_capsule) - else: - cuda.get_values2(inference_params.position_embedding_matrix, inference_params.seq_len + 1, inference_params.q_pos_emb, d, b) - cuda.get_values2(inference_params.position_embedding_matrix, inference_params.seq_len + 1, inference_params.k_pos_emb, d, b) - inference_params.query_layer.copy_(apply_rotary_pos_emb(query_layer, inference_params.q_pos_emb, self.qkv_format, fused=True)) - inference_params.key_layer.copy_(apply_rotary_pos_emb(key_layer, inference_params.k_pos_emb, self.qkv_format, fused=True)) + q_pos_emb = self.alloc((b, 1, 1, d), torch.float32, "cuda") + k_pos_emb = self.alloc((b, 1, 1, d), torch.float32, "cuda") + q_freq, k_freq = rotary_pos_emb + + tex.get_values(q_freq, inference_params.seq_len + 1, q_pos_emb, d, b) + tex.get_values(k_freq, inference_params.seq_len + 1, k_pos_emb, d, b) + 
query_layer.copy_(apply_rotary_pos_emb(query_layer, q_pos_emb, "bshd", fused=True)) + key_layer.copy_(apply_rotary_pos_emb(key_layer, k_pos_emb, "bshd", fused=True)) else: assert (not isinstance(query_layer, Float8Tensor) and not isinstance(key_layer, Float8Tensor) @@ -4399,6 +4382,8 @@ def forward( sequence_length = key_layer.size(0) elif self.qkv_format == "bshd": sequence_length = key_layer.size(1) + elif self.qkv_format == "thd": + sequence_length = key_layer.size(1) sequence_start = inference_params.sequence_len_offset sequence_end = sequence_start + sequence_length @@ -4406,21 +4391,19 @@ def forward( q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] - query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True) - key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True) + query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format if self.qkv_format != "thd" else "bshd", fused=True) + key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format if self.qkv_format != "thd" else "bshd", fused=True) query_layer = query_layer.contiguous() key_layer = key_layer.contiguous() - - # =========================== # Core attention computation # =========================== context_layer = self.core_attention( - inference_params.query_layer if inference_params.thd else query_layer, - inference_params.key_layer if inference_params.thd else key_layer, + query_layer, + key_layer, value_layer, qkv_format=self.qkv_format, cu_seqlens_q=None, @@ -4456,3 +4439,28 @@ def forward( if self.input_layernorm and self.return_layernorm_output: outputs += (layernorm_output,) return outputs if len(outputs) > 1 else outputs[0] + +class OffsetsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, bs, inference_params, max_seqlen_q, max_seqlen_kv, channels): + + cu_seqlens_q = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") + cu_seqlens_kv = torch.zeros(bs + 1, dtype=torch.int32, device="cuda") + cu_seqlens_kv[1:].copy_(torch.cumsum(inference_params.seq_len + 1, dim=0)) + + + seq_offsets_q = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * channels * max_seqlen_q + seq_offsets_k = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * channels * max_seqlen_kv + seq_offsets_v = seq_offsets_k.clone() + + return cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v + +class BufferAllocator(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, size, dtype, device): + a = torch.zeros(size, dtype=dtype, device=device) + return a \ No newline at end of file From d572eb6e5644fb18938854970edd557cd742970c Mon Sep 17 00:00:00 2001 From: root Date: Wed, 1 May 2024 18:18:00 +0000 Subject: [PATCH 082/244] attention.py Signed-off-by: root Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 37efd8eb30..4565f27d2f 100755 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -81,9 +81,6 @@ from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module -Z = torch.zeros((200, 
200)).to(torch.bfloat16).cuda() -T = torch.zeros((200, 200)).to(torch.int32).cuda() - META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1 META_O = tex.FP8FwdTensors.GEMM2_INPUT From d94c50501bb3a3d00a0ade9880c08a24de34e5e6 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 3 May 2024 22:03:32 +0000 Subject: [PATCH 083/244] Low level fixes Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_save_load.py | 1 + transformer_engine/pytorch/attention.py | 5 +++++ transformer_engine/pytorch/cpp_extensions/fused_attn.py | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) mode change 100644 => 100755 tests/pytorch/test_torch_save_load.py diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py old mode 100644 new mode 100755 index 85ec7685b3..e29a986dd5 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -65,6 +65,7 @@ def __init__(self, precision, use_bias): self.inp_type = tex.DType.kFloat8E4M3 self.weights_type = tex.DType.kFloat8E4M3 self.outp_type = precision + def forward(self, inp, weight): inp_fp8 = cast_to_fp8( diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 4565f27d2f..3c23b08c87 100755 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3575,6 +3575,11 @@ def forward( and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]): if self.device_compute_capability == (9, 0): use_flash_attention = False + + if self.qkv_format == "thd": + use_flash_attention = False + use_fused_attention = True + fused_attention_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if use_flash_attention: if _NVTE_DEBUG: diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 1e0bc53fe1..12ef702d9a 100755 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -892,7 +892,7 @@ def fused_attn_fwd( cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype, seq_offsets_q, seq_offsets_k, seq_offsets_v, - d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, + d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, rng_gen, rng_elts_per_thread, ) From 78125c4019f69b987b4c9fa0f5e2f2db84afc2fc Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 3 May 2024 22:39:54 +0000 Subject: [PATCH 084/244] pybind Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/csrc/extensions/pybind.cpp | 4 ++++ 1 file changed, 4 insertions(+) mode change 100644 => 100755 transformer_engine/pytorch/csrc/extensions/pybind.cpp diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp old mode 100644 new mode 100755 index 4a7d51cada..246724130f --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -102,6 +102,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version"); m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available"); + + m.def("attention_copy", &attention_copy, "attention_copy"); + m.def("get_values", &get_values, "get_values"); + // Data structures py::class_(m, "FP8TensorMeta") .def(py::init<>()) From 3ad4714ea2fcf22293610788bec8a3da9f87f4d5 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Sat, 4 May 2024 00:06:04 
+0000 Subject: [PATCH 085/244] Prepare attention for generalized kernel Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 174 +++++++++--------------- 1 file changed, 67 insertions(+), 107 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 3c23b08c87..2af0a417e7 100755 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2989,8 +2989,13 @@ def __init__( self.unfused_attention = UnfusedDotProductAttention( norm_factor, **attn_kwargs, layer_number=layer_number) + + self._allocator = StaticBufferAllocator() + + + def alloc(self, size, dtype, device): + return self._allocator(size, dtype, device) - self.offset_module = OffsetsModule() def _checkpointed_attention_forward( self, @@ -3231,7 +3236,6 @@ def forward( qkv_format = self.qkv_format - if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" @@ -3242,8 +3246,7 @@ def forward( (inference_key_memory, inference_value_memory, ) = inference_params.key_value_memory_dict[self.layer_number] - - if not qkv_format == "thd": + if qkv_format in ["bshd", "sbhd"]: batch_start = inference_params.batch_size_offset batch_end = batch_start + key_layer.size(1) assert batch_end <= inference_key_memory.size(1) @@ -3259,80 +3262,49 @@ def forward( sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] - else: - - if query_layer.shape[1] == 1: - bs = query_layer.shape[0] - - tex.attention_copy( - inference_key_memory, - inference_params.seq_len, - key_layer, - inference_params.max_sequence_length, - bs, - self.channels) - tex.attention_copy( - inference_value_memory, - inference_params.seq_len, - value_layer, - inference_params.max_sequence_length, - bs, - self.channels) - - max_seqlen_q = 1 - max_seqlen_kv = inference_params.max_sequence_length - cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v = self.offset_module(bs, inference_params, max_seqlen_q, max_seqlen_kv, self.channels) - - - query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16) - key_layer = inference_key_memory.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]).to(torch.bfloat16) - value_layer = inference_value_memory.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]).to(torch.bfloat16) - else: - bs = query_layer.shape[0] - - key_layer = key_layer.transpose(0, 1) - value_layer = value_layer.transpose(0, 1) - - batch_start = inference_params.batch_size_offset - batch_end = batch_start + key_layer.size(1) - assert batch_end <= inference_key_memory.size(1) - - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + key_layer.size(0) - assert sequence_end <= inference_key_memory.size(0) - - # Copy keys and values into KV-cache - inference_key_memory[ - sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer - inference_value_memory[ - sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer - key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] - value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] 
- - seqlens = (torch.zeros(bs + 1, dtype=torch.int32, device="cuda")) - seqlens[1:] = inference_params.seq_len - cu_seqlens_q = (torch.zeros(bs + 1, dtype=torch.int32, device="cuda")) - cu_seqlens_kv = (torch.zeros(bs + 1, dtype=torch.int32, device="cuda")) - cu_seqlens_q[1:] = (torch.cumsum(inference_params.seq_len, dim=0)) - cu_seqlens_kv[1:] = (torch.cumsum(inference_params.seq_len, dim=0)) - - max_seqlen_q = query_layer.shape[1] - max_seqlen_kv = key_layer.shape[0] + elif qkv_format == "thd": + """ + inference_params.seq_len - lengths of processed sequences + """ + bs = query_layer.shape[0] + + tex.attention_copy( + inference_key_memory, + inference_params.seq_len, + inference_params.incoming_seq_len, + key_layer, + inference_params.max_incoming_seqence_length, + inference_params.max_sequence_length, + bs, + self.channels) + tex.attention_copy( + inference_value_memory, + inference_params.seq_len, + inference_params.incoming_seq_len, + value_layer, + inference_params.max_incoming_seqence_length, + inference_params.max_sequence_length, + bs, + self.channels) - seq_offsets_q = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_q - seq_offsets_k = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv - seq_offsets_v = seq_offsets_k - - - key_layer = key_layer.transpose(0, 1) - value_layer = value_layer.transpose(0, 1) - key_layer = key_layer.contiguous() - value_layer = value_layer.contiguous() - - - query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16).contiguous() - key_layer = key_layer.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]).to(torch.bfloat16).contiguous() - value_layer = value_layer.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]).to(torch.bfloat16).contiguous() + max_seqlen_q = inference_params.max_incoming_seqence_length + max_seqlen_kv = inference_params.max_sequence_length + cu_seqlens_q = self.alloc(bs + 1, dtype=torch.int32, device="cuda") + cu_seqlens_kv = self.alloc(bs + 1, dtype=torch.int32, device="cuda") + seq_offsets_q = self.alloc(bs + 1, dtype=torch.int32, device="cuda") + seq_offsets_k = self.alloc(bs + 1, dtype=torch.int32, device="cuda") + seq_offsets_v = self.alloc(bs + 1, dtype=torch.int32, device="cuda") + + cu_seqlens_q.copy_(torch.arange(0, bs + 1, dtype=torch.int32, device="cuda")) + cu_seqlens_kv[1:].copy_(torch.cumsum(inference_params.seq_len + 1, dim=0)) + + seq_offsets_q.copy_(torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_q) + seq_offsets_k.copy_(torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv) + seq_offsets_v.copy_(seq_offsets_k) + + query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16) + key_layer = inference_key_memory.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]).to(torch.bfloat16) + value_layer = inference_value_memory.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]).to(torch.bfloat16) if qkv_format == "bshd": @@ -4042,7 +4014,7 @@ def __init__( **common_gemm_kwargs, ) - self._allocator = BufferAllocator() + self._allocator = StaticBufferAllocator() @@ -4353,12 +4325,15 @@ def forward( # Apply relative positional encoding (rotary embedding) # ====================================================== if rotary_pos_emb is not None: - if self.qkv_format == "thd" and query_layer.shape[1] == 1: - if not 
isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = ((rotary_pos_emb,) * 2) - - d = query_layer.shape[-1] - b = query_layer.shape[0] + assert (not isinstance(query_layer, Float8Tensor) + and not isinstance(key_layer, Float8Tensor) + ), "RoPE is not supported for Float8Tensors!" + # duplicate the pos_emb for self attention + if not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = ((rotary_pos_emb,) * 2) + + if self.qkv_format == "thd" and inference_params is not None: + b, d = query_layer.shape[0], query_layer.shape[-1] q_pos_emb = self.alloc((b, 1, 1, d), torch.float32, "cuda") k_pos_emb = self.alloc((b, 1, 1, d), torch.float32, "cuda") @@ -4369,12 +4344,6 @@ def forward( query_layer.copy_(apply_rotary_pos_emb(query_layer, q_pos_emb, "bshd", fused=True)) key_layer.copy_(apply_rotary_pos_emb(key_layer, k_pos_emb, "bshd", fused=True)) else: - assert (not isinstance(query_layer, Float8Tensor) - and not isinstance(key_layer, Float8Tensor) - ), "RoPE is not supported for Float8Tensors!" - # duplicate the pos_emb for self attention - if not isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = ((rotary_pos_emb,) * 2) q_pos_emb, k_pos_emb = rotary_pos_emb @@ -4442,24 +4411,15 @@ def forward( outputs += (layernorm_output,) return outputs if len(outputs) > 1 else outputs[0] -class OffsetsModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, bs, inference_params, max_seqlen_q, max_seqlen_kv, channels): - - cu_seqlens_q = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") - cu_seqlens_kv = torch.zeros(bs + 1, dtype=torch.int32, device="cuda") - cu_seqlens_kv[1:].copy_(torch.cumsum(inference_params.seq_len + 1, dim=0)) - - seq_offsets_q = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * channels * max_seqlen_q - seq_offsets_k = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * channels * max_seqlen_kv - seq_offsets_v = seq_offsets_k.clone() - - return cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v - -class BufferAllocator(torch.nn.Module): +class StaticBufferAllocator(torch.nn.Module): + """ + This class is used when we use te.make_graphed_callable(). + CUDA Graphs require all tensors to be static. Neverthless, + torch API make_graphed_callable() takes care of output of torch modules, + and makes them static. Thus by wrapping allocation of memory into + torch.nn.Module, we can greatly simplify our code. 
+ """ def __init__(self): super().__init__() From 6dc12bc31e7cb6a0108b32313ecc23b81df581e5 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Sat, 4 May 2024 00:17:08 +0000 Subject: [PATCH 086/244] Prepare attention for generalized kernel Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 2af0a417e7..bfba0d5e29 100755 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -4339,12 +4339,25 @@ def forward( k_pos_emb = self.alloc((b, 1, 1, d), torch.float32, "cuda") q_freq, k_freq = rotary_pos_emb - tex.get_values(q_freq, inference_params.seq_len + 1, q_pos_emb, d, b) - tex.get_values(k_freq, inference_params.seq_len + 1, k_pos_emb, d, b) + tex.get_values( + q_freq, + inference_params.seq_len + 1, + inference_params.incoming_seq_len, + q_pos_emb, + d, + b + ) + tex.get_values( + k_freq, + inference_params.seq_len + 1, + inference_params.incoming_seq_len, + k_pos_emb, + d, + b + ) query_layer.copy_(apply_rotary_pos_emb(query_layer, q_pos_emb, "bshd", fused=True)) key_layer.copy_(apply_rotary_pos_emb(key_layer, k_pos_emb, "bshd", fused=True)) else: - q_pos_emb, k_pos_emb = rotary_pos_emb # adjust key and value for inference @@ -4353,8 +4366,6 @@ def forward( sequence_length = key_layer.size(0) elif self.qkv_format == "bshd": sequence_length = key_layer.size(1) - elif self.qkv_format == "thd": - sequence_length = key_layer.size(1) sequence_start = inference_params.sequence_len_offset sequence_end = sequence_start + sequence_length @@ -4362,8 +4373,8 @@ def forward( q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] 
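 # For reference, a minimal unfused sketch of the rotation that apply_rotary_pos_emb performs
 # just below, assuming GPT-NeoX-style RoPE over the full head dimension with freqs broadcastable
 # to the query/key shape (e.g. [s, 1, 1, d] against [s, b, h, d]); helper names are illustrative.

 import torch

 def _rotate_half_ref(x):
     x1, x2 = torch.chunk(x, 2, dim=-1)
     return torch.cat((-x2, x1), dim=-1)

 def _apply_rope_ref(t, freqs):
     cos_, sin_ = torch.cos(freqs).to(t.dtype), torch.sin(freqs).to(t.dtype)
     return (t * cos_) + (_rotate_half_ref(t) * sin_)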
- query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format if self.qkv_format != "thd" else "bshd", fused=True) - key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format if self.qkv_format != "thd" else "bshd", fused=True) + query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True) + key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True) query_layer = query_layer.contiguous() key_layer = key_layer.contiguous() From 894cf584620b02140bc5aee9747a6b289da6eaae Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 6 May 2024 23:05:41 +0000 Subject: [PATCH 087/244] Drafts of tutorials Signed-off-by: Pawel Gadzinski --- ...tutorial_accelerate_hf_gemma_with_te.ipynb | 243 +++++++ .../tutorial_generation_gemma_with_te.ipynb | 622 ++++++++++++------ ...tutorial_accelerate_hf_llama_with_te.ipynb | 12 + 3 files changed, 678 insertions(+), 199 deletions(-) create mode 100755 docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb diff --git a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb new file mode 100755 index 0000000000..c6a236a366 --- /dev/null +++ b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb @@ -0,0 +1,243 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Accelerating a Hugging Face Gemma model with Transformer Engine" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the previous [tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) we have demonstrated how to accelerate HF Llama models using Transformer Engine. Now, we will make similar thing with Gemma model. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Dependencies for this tutorial\n", + "\n", + "Following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_gemma.py`\n", + " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n", + "2. `utils.py`\n", + " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", + "3. `media/`\n", + " - This directory contains the images used in the following tutorial." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Differences between Llama and Gemma" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Thr Llama and the Gemma are very similar models - both are based on Transformer Decoder architecture. The differences between them are the following:\n", + "1. The Gemma uses RMSNorm with zero centered gamma parameter, and Llama uses stardard RMSNorm.\n", + "2. The Gemma uses different head dimension than embedding dimension, but in Llama this numbers are equal.\n", + "3. The Gemma uses GeGlu activation function, the Llama uses SwiGlu." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# [Baseline] Running HF `GemmaModel` (Precision: `BF16`)\n", + "\n", + "Similarly to the Llama tutorial, we begin the experiments by running baseline training in BF16 precision.\n", + "\n", + "
\n", + "\n", + "Note\n", + " \n", + "This tutorial loads and trains a Llama 2 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", + "\n", + "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n", + "\n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", + "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_baseline_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | - | 1 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# [Improvement 1] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", + "\n", + "Now we substitute *GemmaDecoderLayer* with highly tuned *TransformerLayer*. Let's see how this will impact the speed of the mode." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "#restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", + "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/llama/weights\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `GemmaDecoderLayer` gives a speedup of **??%** even when using only BF16 precision!\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 315 | 1 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 252 | 1.25 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# [Improvement 2] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", + "\n", + "The last improvement is about enabling FP8 precision. Let's see how it works." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "#restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", + "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "hyperparams.mixed_precision = \"fp8\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | - | 1 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | - | - |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | FP8 | - | - |\n", + "\n", + "\n", + "After turning on FP8 precision, we get even more speedup of almost **??%**!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Conclusion\n", + "\n", + "We can see, that similar to the Llama model, using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `GemmaDecoderLayer` provides a speedup over Hugging Face's native Gemma implementation." 
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## See more\n",
+    "\n",
+    "We have also prepared a [tutorial](./tutorial_generation_gemma_with_te.ipynb) covering CUDA Graphs and THD attention, which we use to speed up Gemma generation."
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
index 9fb353b8ea..cf851bbdf1 100755
--- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
+++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
@@ -2,31 +2,50 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "id": "2cac9d39",
+   "id": "8581f0e4",
    "metadata": {},
    "source": [
-    "# Accelerating a Hugging Face Gemma model generation with Transformer Engine\n",
+    "# Speeding up Hugging Face Gemma model generation with CUDA Graphs, THD attention, and FP8 precision\n",
+    "\n",
+    "As can be seen in the [tutorial for Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) or the [tutorial for Gemma](./tutorial_accelerate_hf_gemma_with_te.ipynb), transformer models can be accelerated with Transformer Engine's `TransformerLayer`. In this tutorial we present a few more advanced features, namely:\n",
+    "1. THD attention layout.\n",
+    "2. FP8 weight calibration - for running inference in FP8 precision with models that were trained in higher precision.\n",
+    "3. CUDA Graphs API.\n",
+    "\n",
+    "We will compare generation time on two benchmarks:\n",
+    "- long input sequences (max 128 tokens) with short generation (max 128 new tokens),\n",
+    "- short input sequences (max 64 tokens) with long generation (max 1000 new tokens).\n",
+    "\n",
+    "All benchmarks above are run with batch size 64 on the dataset \"timdettmers/openassistant-guanaco\".\n",
+    "\n",
+    "
\n", "\n", - "Goal\n", + "Note\n", + " \n", + "This tutorial aims to demonstrate features of TransformerEngine mentioned above on the example of generation. It's important to note though, that NVIDIA offers other library to use for inference - namely [TensorRT](https://developer.nvidia.com/tensorrt), which should be used in such cases.\n", "\n", - "This tutorial showcases how to accelerate generation done by a full Gemma model from [Hugging Face](https://huggingface.co/google/gemma-7b-it) by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` precision.\n", + "
\n", "\n", - "\n" + "\n" ] }, { "cell_type": "markdown", - "id": "401f7fb1", + "id": "b18f91a9", + "metadata": {}, + "source": [ + "## Dependencies for this tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "e5201d77", "metadata": {}, "source": [ - "## Dependencies for this tutorial\n", - "\n", "Following files and media are necessary to effectively run this tutorial:\n", "\n", "1. `te_gemma.py`\n", - " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `LlamaDecoderLayer`. Also it contains the logic of the generation using TransformerEngine. \n", + " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. It does also contain code for generation with THD attention and weight calibration.\n", "2. `utils.py`\n", " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", "3. `media/`\n", @@ -35,68 +54,60 @@ }, { "cell_type": "markdown", - "id": "b564503c", + "id": "84bfbe6c", "metadata": {}, "source": [ - "## Baseline HuggingFace Gemma generation" + "## Table of contents" ] }, { "cell_type": "markdown", - "id": "24a8d0a5", + "id": "f09c29e7", "metadata": {}, "source": [ - "
\n", - "\n", - "Note\n", - " \n", - "This tutorial loads and trains a Llama 2 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", - "\n", - "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n", - "\n", - "
\n" + "1. [Baseline] Running Hugging Face generation with Gemma model\n", + "2. [Improvement 1] Speeding up generation by using Transformer Engine THD attention.\n", + "3. [Improvement 2] Running generation of the model trained in hign precision in FP8.\n", + "4. [Improvement 3] Speeding up generation with CudaGraphs.\n", + "5. Conclusions." + ] + }, + { + "cell_type": "markdown", + "id": "e8dfabbf", + "metadata": {}, + "source": [ + "## [Baseline] Running Hugging Face generation with Gemma model" + ] + }, + { + "cell_type": "markdown", + "id": "59560bff", + "metadata": {}, + "source": [ + "Hugging Face Transformers library offers generation API. We will treat this as our baseline." ] }, { "cell_type": "code", - "execution_count": 5, - "id": "e36ff380", + "execution_count": null, + "id": "7477e469", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.60it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Generation time: 26.482454538345337 seconds\n", - "I like the new look of the app. I like the new features. I like the new look of \n", - "==============================\n", - "I do not like the way the new version of the app is set up. I do not like the fa\n" - ] - } - ], + "outputs": [], "source": [ "# Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", "#restart_jupyter_notebook()\n", "\n", - "\n", "# Import necessary packages and methods\n", "from utils import *\n", - "import torch\n", "\n", "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", - "hyperparams.mixed_precision = \"no\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", "\n", "\n", "# Init the model and accelerator wrapper\n", @@ -104,211 +115,107 @@ "model = model.to(torch.bfloat16)\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)\n", - "inputs = tokenizer([\"I like\", \"I do not like\"] * 32, return_tensors=\"pt\", padding=True)\n", + "inputs = tokenizer([\"Some random initial str \", \"Another string ... 
\"] * 32, return_tensors=\"pt\", padding=True)\n", "\n", "inputs['input_ids'] = inputs['input_ids'].cuda()\n", "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n", "\n", - "\n", "start_time = time.time()\n", "\n", "outputs = model.generate(\n", " **inputs,\n", - " max_new_tokens=400\n", + " max_new_tokens=1000\n", ")\n", "\n", "end_time = time.time()\n", "duration = end_time - start_time\n", - "print(f\"Generation time: {duration} seconds\")\n", "\n", + "print(duration)\n", "\n", "# Decode the output tensor to text\n", "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", "\n", - "# Display the first two samples of the generated text\n", - "print(generated_texts[0][:80])\n", - "print(30 * \"=\")\n", - "print(generated_texts[1][:80])" + "# Display the generated text\n", + "for text in generated_texts:\n", + " print(text)\n", + " print(\"=\" * 100)" ] }, { "cell_type": "markdown", - "id": "a64f0f33", + "id": "b3698dc6", "metadata": {}, "source": [ - "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", + "We will put these times into the table for later comparison.\n", "\n", - "| Models | Precision | Generation time | Speedup (over baseline) |\n", - "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 26.48 | 1 |" + "| Models | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | - | - | " ] }, { "cell_type": "markdown", - "id": "e2fb88e9", + "id": "2bbf3d47", "metadata": {}, "source": [ - "## [Improvement] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` and use generation within TE\n", + "## [Improvement 1] Speeding up generation by using Transformer Engine THD attention\n", + "\n", + "Similarly to the Gemma tutorial, we substitute `GemmaDecoderLayer` with `TransformerLayer` from Transformer Engine. Since initial sequences have different lengths, we have following choices:\n", + "1. Use padding from the beginning and then use standard attention with `\"bshd\"` or `\"sbhd\"` layout.\n", + "2. Do not pad from the beginning and use THD attention.\n", + "\n", + "In this tutorial we will show the second option. We illustrate THD attention idea on the two pictures below.\n", + "\n", + "
\n", + "\"Logo\n", + "\"Logo\n", + "
\n", "\n" ] }, - { - "cell_type": "markdown", - "id": "6f7fefac", - "metadata": {}, - "source": [ - "```\n", - "@torch.no_grad()\n", - " def generate(\n", - " self,\n", - " input_ids: Optional[torch.Tensor] = None,\n", - " generation_config: Optional[GenerationConfig] = None,\n", - " max_new_tokens = 0,\n", - " **kwargs,\n", - " ):\n", - " num_heads = self.model.config.num_attention_heads\n", - " batch_size, seq_len = input_ids.shape\n", - " max_seq_len = seq_len + max_new_tokens\n", - " generation_config, _ = self._prepare_generation_config(generation_config, **kwargs)\n", - " unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)\n", - "\n", - " # inference_params object is a cache, where keys and values of previous tokens are stored\n", - " inference_params = te.pytorch.InferenceParams(\n", - " max_batch_size=batch_size, \n", - " max_sequence_length=seq_len+max_new_tokens+1) \n", - "\n", - " # mask has shape [batch_size, num_heads, 1, max_seq_len] and contains False \n", - " # when coressponding token is padding and True otherwise.\n", - " pad_attention_mask = input_ids.ne(generation_config.pad_token_id)\n", - " mask = torch.ones((batch_size, num_heads, 1, max_seq_len), dtype=torch.bool).cuda()\n", - " mask[..., :seq_len] = mask[..., :seq_len] & pad_attention_mask.unsqueeze(1).unsqueeze(2).expand(-1, num_heads, -1, -1)\n", - "\n", - " hidden_states = self.model.embed_tokens(input_ids)\n", - " output_tokens = []\n", - " for i in range(max_new_tokens):\n", - " normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)\n", - " hidden_states = hidden_states * normalizer\n", - " for decoder_layer in self.model.layers:\n", - " hidden_states = decoder_layer(\n", - " hidden_states,\n", - " # In the case of arbiutrary mask, the meaning of True and False is switched, so negation is needed.\n", - " attention_mask=pad_attention_mask if i == 0 else ~mask[..., :seq_len],\n", - " self_attn_mask_type=\"padding_causal\" if i == 0 else \"arbitrary\",\n", - " inference_params=inference_params\n", - " )[0]\n", - "\n", - " # inference_params.sequence_len_offset should contain position of the current token in the sequence.\n", - " inference_params.sequence_len_offset += hidden_states.shape[1]\n", - "\n", - " hidden_states = self.model.norm(hidden_states)\n", - " logits = self.lm_head(hidden_states)\n", - " logits = logits.float()\n", - " logits = logits[:, -1, :]\n", - " next_tokens = torch.argmax(logits, dim=-1)\n", - "\n", - " # Sequences, which are finished should contain padding - taken from huggingface transformers.\n", - " next_tokens = next_tokens * unfinished_sequences + generation_config.pad_token_id * (1 - unfinished_sequences)\n", - " output_tokens.append(next_tokens)\n", - "\n", - " unfinished_sequences = unfinished_sequences & ~(next_tokens == generation_config.eos_token_id)\n", - "\n", - " hidden_states = self.model.embed_tokens(next_tokens).unsqueeze(1)\n", - " seq_len += 1\n", - "\n", - " result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1)\n", - " return result\n", - "```" - ] - }, { "cell_type": "code", - "execution_count": 8, - "id": "8f2b752e", + "execution_count": null, + "id": "4fc5e1cd", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in 
torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function 
operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Generation time: 16.87099289894104 seconds\n", - "I like the idea of a \"re-do\" of the original \"The Man from U.N.C.L.E.\" movie. I \n", - "==============================\n", - "I do not like the way the \"new\" (2011) version of the 1099-MISC is set up. I ha\n" - ] - } - ], + "outputs": [], "source": [ - "# Restart the notebook (to flush the GPU memory)\n", - "from utils import restart_jupyter_notebook\n", - "#restart_jupyter_notebook()\n", - "\n", - "\n", "# Import necessary packages and methods\n", "from utils import *\n", - "import accelerate\n", "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "hyperparams.fuse_qkv_params = False\n", "\n", "# Init the model and accelerator wrapper\n", - "model = init_te_gemma_model(hyperparams)\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", "#accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", "\n", "model = model.to(torch.bfloat16).cuda()\n", "\n", - "\n", "tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)\n", - "inputs = tokenizer([\"I like\", \"I do not like\"] * 32, return_tensors=\"pt\", padding=True)\n", + "inputs = tokenizer([\"I love when \", \"I \"] * 32, return_tensors=\"pt\", padding=True)\n", "\n", "inputs['input_ids'] = inputs['input_ids'].cuda()\n", "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n", "\n", "import time\n", "\n", + "# PoczÄ…tek pomiaru czasu\n", "start_time = time.time()\n", "\n", "outputs = model.generate(\n", " **inputs,\n", - " max_new_tokens=400\n", + " max_new_tokens=40\n", ")\n", "\n", + "# Koniec pomiaru czasu\n", "end_time = time.time()\n", + "\n", + "# Obliczamy czas trwania operacji\n", "duration = end_time - start_time\n", "print(f\"Generation time: {duration} seconds\")\n", "\n", @@ -316,35 +223,352 @@ "# Decode the output tensor to text\n", "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", "\n", - "# Display the first two samples of the generated text\n", - "print(generated_texts[0][:80])\n", - "print(30 * \"=\")\n", - "print(generated_texts[1][:80])" + "# Display the generated text\n", + "for text in generated_texts:\n", + " print(text)\n", + " print(\"=\" * 100)" ] }, { "cell_type": "markdown", - "id": "67ec126c", + "id": "8e397a65", "metadata": {}, "source": [ - "| Models | Precision | Generation time | Speedup (over baseline) |\n", - 
"|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 26.48 | 1 |\n", - "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 16.87 | 1.56 |\n", + "By using THD attention we obtained following speedups:\n", "\n", + "| Models | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | - | - |\n", + "| THD attention with TE | - | - | " + ] + }, + { + "cell_type": "markdown", + "id": "e6b171a0", + "metadata": {}, + "source": [ + "## [Improvement 2] Running generation of the model trained in high precision in FP8" + ] + }, + { + "cell_type": "markdown", + "id": "1a80288b", + "metadata": {}, + "source": [ + "Now we want to run FP8 generation with Gemma model. But it's not that simple! Since model was trained in BF16 precision, the FP8 scaling factors are not computed. Running the model with such low precision without proper scaling will lead to serious numerical divergence, which will lead to wrong output.\n", + "\n", + "##### Weight calibration\n", "\n", + "The wieght calibration is solution of the problem mentioned above. We will run few forward iterations on BF16 precision within context `te.fp8_autocast(enabled=False, calibration=True)`. This means that the forward pass will be done in higher precision, but we will store `amax_history`, which will be used to compute FP8 scaling factors. \n", + "\n", + "In the code below, we initialize BF16 model, run few iterations of forward passes within mentioned context. Then, we save model - we will also use these weights in the next chapter. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aecee0e1", + "metadata": {}, + "outputs": [], + "source": [ + "# Import necessary packages and methods\n", + "import transformer_engine.pytorch as te\n", + "from utils import *\n", + "import accelerate\n", + "from transformer_engine.pytorch import fp8_model_init\n", + "from transformer_engine.common.recipe import Format, DelayedScaling\n", + "import torch\n", + "\n", + "\n", + "hyperparams.model_name = \"../../../../gemma-weights\"\n", + "hyperparams.fuse_qkv_params = True\n", + "model = init_te_gemma_model(hyperparams, fp8_model_init=False).cuda()\n", + "model = model.to(torch.bfloat16)\n", + "\n", + "\n", + "accelerator = Accelerator(\n", + " log_with=\"wandb\",\n", + " gradient_accumulation_steps=hyperparams.gradient_accumulation_steps,\n", + " mixed_precision=hyperparams.mixed_precision\n", + " )\n", + "train_dataloader = get_dataloaders(accelerator, hyperparams)\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)\n", + "\n", + "print(\"Calibration started\")\n", + "with te.fp8_autocast(enabled=False, calibrating=True):\n", + " model.train()\n", + " train_dataloader = enumerate(train_dataloader)\n", + "\n", + " for i in range(100):\n", + " step, batch = next(train_dataloader)\n", + " batch[\"input_ids\"] = batch[\"input_ids\"].cuda()\n", + " outputs = model.generate(\n", + " **batch,\n", + " max_new_tokens=10\n", + " )\n", + " generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", + " print(generated_texts[0][:50])\n", + "print(\"calibration_finished\")\n", + "\n", + "print(\"scale_fwd computation started\")\n", + "with te.fp8_autocast(enabled=True):\n", + " for i in range(10):\n", + " step, batch = next(train_dataloader)\n", + " batch[\"input_ids\"] = batch[\"input_ids\"].cuda()\n", + " outputs = model.generate(\n", + " **batch,\n", + " max_new_tokens=1\n", + " )\n", + "print(\"scale_fwd_computation ended\")\n", + "\n", + "print(\"Casting weights...\")\n", + "model_fp8 = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda()\n", + "model_fp8.load_state_dict(model.state_dict())\n", + "print(\"Weights casted\")\n", + "\n", + "\n", + "print(\"Saving model...\")\n", + "torch.save(model_fp8.state_dict(), 'model_fp8_state_dict.pth')\n", + "print(\"Model saved!\")" + ] + }, + { + "cell_type": "markdown", + "id": "b6dcd135", + "metadata": {}, + "source": [ + "Now we are ready to run FP8 inference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a913f54d", + "metadata": {}, + "outputs": [], + "source": [ + "#Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "#restart_jupyter_notebook()\n", + "import transformer_engine.pytorch as te\n", "\n", - "After converting to TE Transformer Layers, we obtained the speedup of **56%**!" 
+    "import os\n",
+    "from torch.cuda.amp import autocast\n",
+    "\n",
+    "\n",
+    "# Import necessary packages and methods\n",
+    "from utils import *\n",
+    "\n",
+    "from transformer_engine.pytorch import fp8_model_init\n",
+    "from transformer_engine.common.recipe import Format, DelayedScaling\n",
+    "\n",
+    "\n",
+    "hyperparams.model_name = \"../../../../gemma-weights\"\n",
+    "hyperparams.fuse_qkv_params = True\n",
+    "model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format=\"thd\").cuda()\n",
+    "\n",
+    "print(\"Loading model\")\n",
+    "model_state_dict = torch.load('model_fp8_state_dict.pth')\n",
+    "model.load_state_dict(model_state_dict)\n",
+    "print(\"Model loaded\")\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)\n",
+    "inputs = tokenizer([\"Some random initial str \", \"Another string ... \"] * 32, return_tensors=\"pt\", padding=True)\n",
+    "\n",
+    "inputs['input_ids'] = inputs['input_ids'].cuda()\n",
+    "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n",
+    "\n",
+    "import time\n",
+    "\n",
+    "\n",
+    "\n",
+    "start_time = time.time()\n",
+    "\n",
+    "fp8_format = Format.HYBRID\n",
+    "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n",
+    "torch.manual_seed(1234)\n",
+    "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
+    "    with autocast(dtype=torch.bfloat16, cache_enabled=False):\n",
+    "        with torch.no_grad():\n",
+    "            model.eval()\n",
+    "            outputs = model.generate(\n",
+    "                **inputs,\n",
+    "                max_new_tokens=40,\n",
+    "                use_cuda_graphs=False\n",
+    "            )\n",
+    "\n",
+    "\n",
+    "end_time = time.time()\n",
+    "duration = end_time - start_time\n",
+    "\n",
+    "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
+    "for text in generated_texts[:12]:\n",
+    "    print(\"-\" * 50)\n",
+    "    print(text)\n",
+    "\n",
+    "print(f\"Duration = {duration}\")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8cdbb56c",
+   "metadata": {},
+   "source": [
+    "We add the speedups to the table:\n",
+    "\n",
+    "| Models | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 | \n",
+    "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n",
+    "| HF (baseline) | - | - |\n",
+    "| THD attention with TE | - | - | \n",
+    "| THD attention + FP8 with TE | - | - | "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "21a89d9c",
+   "metadata": {},
+   "source": [
+    "## [Improvement 3] Speeding up generation with CudaGraphs"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e2d53e7b",
+   "metadata": {},
+   "source": [
+    "The inference code runs on the CPU, which launches the GPU kernels. When the GPU kernels are fast enough, the CPU-side launch overhead becomes noticeable. That's where CUDA Graphs come in. When a series of kernel launches is repeatable, it can be recorded once and then replayed without involving the CPU. You can read more about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n",
+    "\n",
+    "PyTorch supports CUDA Graphs with the `torch.cuda` API. Nevertheless, there are some requirements for a sequence of tensor operations to be captured and replayed correctly. Namely, all the operations need to be static - meaning that tensors should not \"move\" between iterations. PyTorch also offers a simpler way of using CUDA Graphs - the method `torch.cuda.make_graphed_callables`. 
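As a rough PyTorch-only sketch (`module` and `sample_input` are hypothetical placeholders, not objects from this tutorial), it is used like this:\n\n```python\nimport torch\n\n# Captures the module's forward (and backward) work into CUDA graphs after a few warmup runs.\ngraphed_module = torch.cuda.make_graphed_callables(module, (sample_input,))\nout = graphed_module(sample_input)  # later calls replay the recorded kernels\n```\n\n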
With it, we can easily graph any PyTorch module.\n",
+    "\n",
+    "Since version 1.6, Transformer Engine also supports the `make_graphed_callables` API. In the following code we run the `generate` method from `te_gemma.py`. This is the part of the code responsible for graphing:\n",
+    "\n",
+    "```\n",
+    "graphed_generator = TeGraphed(...)\n",
+    "(...)\n",
+    "    if use_cuda_graphs:\n",
+    "        fp8_format = Format.HYBRID\n",
+    "        fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n",
+    "        graphed_layers = te.pytorch.make_graphed_callables(\n",
+    "            graphed_generator, \n",
+    "            args, \n",
+    "            fp8_enabled=True, \n",
+    "            fp8_recipe=fp8_recipe, \n",
+    "            allow_unused_input=True,\n",
+    "            num_warmup_iters=10\n",
+    "        )\n",
+    "    \n",
+    "    for i in range(max_new_tokens):\n",
+    "        next_tokens = graphed_layers(*args) if use_cuda_graphs else graphed_generator(*args)\n",
+    "        output_tokens.append(next_tokens.clone())\n",
+    "```\n",
+    "\n",
+    "Now, let's see how big the speedup is going to be."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "31a3a8a3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "os.environ['CUDNN_LOGLEVEL_DBG'] = '3'\n",
+    "os.environ['CUDNN_LOGDEST_DBG'] = 'backlog.txt'\n",
+    "#Restart the notebook (to flush the GPU memory)\n",
+    "from utils import restart_jupyter_notebook\n",
+    "#restart_jupyter_notebook()\n",
+    "import transformer_engine.pytorch as te\n",
+    "\n",
+    "from torch.cuda.amp import autocast\n",
+    "\n",
+    "\n",
+    "# Import necessary packages and methods\n",
+    "from utils import *\n",
+    "\n",
+    "from transformer_engine.pytorch import fp8_model_init\n",
+    "from transformer_engine.common.recipe import Format, DelayedScaling\n",
+    "\n",
+    "\n",
+    "hyperparams.model_name = \"../../../../gemma-weights\"\n",
+    "hyperparams.fuse_qkv_params = True\n",
+    "model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format=\"thd\").cuda()\n",
+    "\n",
+    "print(\"Loading model\")\n",
+    "model_state_dict = torch.load('model_fp8_state_dict.pth')\n",
+    "model.load_state_dict(model_state_dict)\n",
+    "print(\"Model loaded\")\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)\n",
+    "inputs = tokenizer([\"Some random initial str \", \"Another string ... 
\"] * 32, return_tensors=\"pt\", padding=True)\n", + "\n", + "inputs['input_ids'] = inputs['input_ids'].cuda()\n", + "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n", + "\n", + "import time\n", + "\n", + "start_time = time.time()\n", + "\n", + "fp8_format = Format.HYBRID\n", + "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n", + "torch.manual_seed(1234)\n", + "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", + " with autocast(dtype=torch.bfloat16, cache_enabled=False):\n", + " with torch.no_grad():\n", + " model.eval()\n", + " outputs = model.generate(\n", + " **inputs,\n", + " max_new_tokens=10,\n", + " use_cuda_graphs=True\n", + " )\n", + "\n", + "end_time = time.time()\n", + "duration = end_time - start_time\n", + "\n", + "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", + "for text in generated_texts[:12]:\n", + " print(\"-\" * 50)\n", + " print(text)\n", + "\n", + "print(f\"Duration = {duration}\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "53bb430f", + "metadata": {}, + "source": [ + "We finally obtained the **??%** speedup.\n", + "\n", + "| Models | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | - | - |\n", + "| THD attention with TE | - | - | \n", + "| THD attention + FP8 with TE | - | - | \n", + "| THD attention + FP8 + Cuda Graphs with TE | - | - | " + ] + }, + { + "cell_type": "markdown", + "id": "c6e87275", + "metadata": {}, + "source": [ + "## Conclusions" ] }, { "cell_type": "markdown", - "id": "41b80b0f", + "id": "7bb2452d", "metadata": {}, "source": [ - "## Conclusion\n", + "In this tutorial we showed three features of Transformer Engine:\n", + "1. Support of THD attention layout,\n", + "2. FP8 weights calibration.\n", + "3. Support of Cuda Graphs.\n", "\n", - "Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides a speedup over Hugging Face's native Gemma generation implementation. `TransformerLayer` provides a speedup over the baseline implementation" + "Each one of them can be used in different context, here we showed how to use them to obtain fast inference. We remind though, that this is not the fastest possible way of doing inference - for doing do we reccommend looking at the [TensorRT](https://developer.nvidia.com/tensorrt) library from NVIDIA." ] } ], diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb index cc77b484f9..59a04c2599 100755 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -676,6 +676,18 @@ "\n", "Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides a speedup over Hugging Face's native Llama 2 implementation. This needs careful initialization of the model such that the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`. Even with `BF16` precision, `TransformerLayer` provides a speedup over the baseline implementation. With `FP8` precision, the speed up is even more pronounced!" 
] + }, + { + "cell_type": "markdown", + "id": "0edb6dab", + "metadata": {}, + "source": [ + "# See more\n", + "\n", + "We have prepared similar [tutorial](../te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb) for the Gemma model.\n", + "\n", + "We also prepared [tutorial](../te_gemma/tutorial_generation_gemma_with_te.ipynb) covering CUDA graphs and THD attention which we use to speedup Gemma generation." + ] } ], "metadata": { From b03543b0e7a3380f033734299fcd30a6e2d2c869 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 6 May 2024 23:07:12 +0000 Subject: [PATCH 088/244] Drafts of tutorials Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/media/pic1.png | Bin 0 -> 2709 bytes docs/examples/te_gemma/media/pic2.png | Bin 0 -> 2709 bytes 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 docs/examples/te_gemma/media/pic1.png create mode 100644 docs/examples/te_gemma/media/pic2.png diff --git a/docs/examples/te_gemma/media/pic1.png b/docs/examples/te_gemma/media/pic1.png new file mode 100644 index 0000000000000000000000000000000000000000..b8baa59a07455a18134be2f24e95e51399c6b82c GIT binary patch literal 2709 zcmd^BYfO_@7%ox4p#>KgQItz`%0-|K&<^A>7=i@af*=S;L7-ejiUKXr(%~eE3j;PS zWvmL*N=wC-u|gT8*dZa!sI)MOAC#Md+*&C@Z7Y|y`%IS1-%kIyA8$_HbMl^?yw7=_ z=lw1W83Z-8G=)GQ(BP8+ryvj`92kXLjlrEnUmgcsjB-x}`9s*^eG6c+In58@2Z1zK zZreDw1vq5F$XKk%F5Qm6VPZj^%%6YTh~w-NaJNl z8uNkr{SMde3H4O-EZRs*+?5`fz#hYPiB%UkpnUatHMeiG$}~kA%R)?Y zGL8T0uMpfIeXf?s=D(gQn|b-H8%%}P#0UJ40K^0^slWREguxK@>hMK7xL~9?; z-ykke^$_F}8_?C=Imzr;R6S+<> z7=WN&>pnLklR44`Ca5$-V`C!*gQ@>sIL%wA`{-eO!{d@*23nis4tE^&9cDA|2)&m> ziH$ig`ga%0mi$bp-qXHKGJ%F?@m7`&&-BlS?I{lgAuKO1&siz86Gka|{pq)!y z@Rkb&ME(VR3Wpr0xO-Z#IGYEb9Pw>k;3$pPg};P;GpKOZEUFaP+NIfkJu}`HHZzh| z>+I&{h8v)_wzlRxdvx_J1Un2@#UT4MK4j8F-+(THD#pQwS6$>bTq|{#wgY&&syupz zT^0om8R7F|Dy*}#h-MxyHYzHrx4?@{)vqGolODYTKR&EL{_~(8^fVu4)W>@$0^eFo-sa}8V2)uDV*Prs_I9SCb z4*dD&Ipl8Kzv{Z#6LRBvl_ALn3w7t^q`;1!Y(7+Je%;=WY4(1jZBROfW0wn%aij#8L^p zcV}sxeN(3UU;K7~U1JD+Szt4D&j8G9kQ=67mpLS^ zy6x)W?%+nN;bjbpmB{w|>4_Y0hyZOvT<4pBto4d9F2PUkB+?lS5vV#0z1muBfiAd% zZe7%ACt)EWAs9TKi+$cOA-}|+>cOjt;*u>#L)ypyHwAGtT%%q48 s>%!+h5{3j5G_yYp75~YAlE^Cwvf$~VdrZ?;6vqn+I^`jx3GGc{dTM%1DTT7JJM4<90g+8p{3POQWq0mv<--^q!zj4OI?2nV1dv9)X^ZlOB zIVV3NEYNfVbOQteF%1s#kAy%B2>QrbZ>&E-eFlA^UkuJf2Kqv1JzFO9o3+S;@PiOY zeeuTSWF!6llT$%Q&Ojh$Eh}NrfvHY`KsJ0C?0+!&EMj_ySCF8x=}?aMba#t~p1O8b z8@0uy?f`-vsYVtecvQyLKF8gCX??7qJ40|QZ0xRW8GM&3yH8sAIcHcZc72<$(<0KS zCV!J#8}s_Z!mz7lTcPe=jXO#d&8LutBmr8T-ZQt0cUyB}Y@xl}(6KX`ZMeyQ*}yR0 z=YXSf-TKdUkc+Jjj@Q&{)@~+TI529yDQI<<5g750i}JCi@d`9<7?C9_6|!-dX>n5n z@P?yZ5TON^)|t~1qRPE8QPV7pl*>R<g6SKCj05oH8pRVR7=46 zOyCKLRJJETMG!3$2wIo$*z+>m;a(U)cTR#80m@!j)|(Uw0n`GW477j-Ej$F2T{SHL zM@QNy+W`yIb0WoBezB|u4wQ^DSkcWAELW%L)c_?EON0x7#r9OcJTnHI{5UD%*|TR~ zZ!r@&@t!>K?dI{qy>;BV_aMl-OZj15s@lu3&1ty}EiCNpQDQoP=f;|=W?tuP1T0=J z$#FkSYK2-^w#ca;rr< zDO?(zhmh}lm{)p77z1kufO|3nP^rq+Ze^5WdRv7aj@uaG94V`NWCeY7}! 
z6>uGe!6Bx5n0mn#?3=SPL?L)P0erDt;;znGG_l7NL}-T*cJy_6TQkl_^Q5fkaTml% z;y=K)wG_kz3b*m=U;;ujY>*tv%WxeDNl+t9i@{hnayEL=#EcP8owFDVCdx&OZ`8IE zmYMZM)2uk~=u)C6G*_{@YV&b9wes7^LfCYUM9kGJ;1$-27>UTL#(ttFF-_d=HT>vj zR~2e`p4FV|S^A^;nW6@E?lppVWwEYMb+?LR_24cz+EmzZw1_tc$3Ro@Yvj`@v(Gi(Kuof;An zf))%94>Qo)crT54Be^?cX#?}~8Gw}g2uSbl>U#KgU|CaFzuaQF+IeY}X@8pnu0a-^ WqyFmJl&&8oA;AG*{xsjDpZ)^GP0?)t literal 0 HcmV?d00001 From d0b62895725bd84e290e6dd6c6a4cce9c5ea33cf Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 7 May 2024 09:23:33 -0700 Subject: [PATCH 089/244] File permission updates Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/csrc/extensions/attention.cu | 0 transformer_engine/pytorch/csrc/extensions/normalization.cu | 0 2 files changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 transformer_engine/pytorch/csrc/extensions/attention.cu mode change 100755 => 100644 transformer_engine/pytorch/csrc/extensions/normalization.cu diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cu b/transformer_engine/pytorch/csrc/extensions/normalization.cu old mode 100755 new mode 100644 From 370dd1ef7a801a624a75e76f79f78269a39d7691 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 7 May 2024 09:24:22 -0700 Subject: [PATCH 090/244] File permission updates Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/generate.py | 0 docs/examples/te_gemma/generate_baseline.py | 0 docs/examples/te_gemma/generate_convert.py | 0 docs/examples/te_gemma/generate_fp8.py | 0 docs/examples/te_gemma/te_gemma.py | 0 docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb | 0 docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb | 0 docs/examples/te_gemma/utils.py | 0 docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb | 0 tests/pytorch/test_torch_save_load.py | 0 transformer_engine/pytorch/attention.py | 0 transformer_engine/pytorch/cpp_extensions/fused_attn.py | 0 transformer_engine/pytorch/cpp_extensions/normalization.py | 0 transformer_engine/pytorch/csrc/comm_gemm_overlap.h | 0 transformer_engine/pytorch/csrc/extensions.h | 0 transformer_engine/pytorch/csrc/extensions/pybind.cpp | 0 transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt | 0 transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp | 0 transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu | 0 transformer_engine/pytorch/csrc/userbuffers/userbuffers.h | 0 transformer_engine/pytorch/distributed.py | 0 transformer_engine/pytorch/float8_tensor.py | 0 transformer_engine/pytorch/fp8.py | 0 transformer_engine/pytorch/module/_common.py | 0 transformer_engine/pytorch/module/base.py | 0 transformer_engine/pytorch/module/layernorm.py | 0 transformer_engine/pytorch/module/layernorm_linear.py | 0 transformer_engine/pytorch/module/layernorm_mlp.py | 0 transformer_engine/pytorch/module/linear.py | 0 transformer_engine/pytorch/module/rmsnorm.py | 0 transformer_engine/pytorch/transformer.py | 0 transformer_engine/pytorch/utils.py | 0 32 files changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 docs/examples/te_gemma/generate.py mode change 100755 => 100644 docs/examples/te_gemma/generate_baseline.py mode change 100755 => 100644 docs/examples/te_gemma/generate_convert.py mode change 100755 => 100644 docs/examples/te_gemma/generate_fp8.py mode change 100755 => 100644 docs/examples/te_gemma/te_gemma.py mode 
change 100755 => 100644 docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb mode change 100755 => 100644 docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb mode change 100755 => 100644 docs/examples/te_gemma/utils.py mode change 100755 => 100644 docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb mode change 100755 => 100644 tests/pytorch/test_torch_save_load.py mode change 100755 => 100644 transformer_engine/pytorch/attention.py mode change 100755 => 100644 transformer_engine/pytorch/cpp_extensions/fused_attn.py mode change 100755 => 100644 transformer_engine/pytorch/cpp_extensions/normalization.py mode change 100755 => 100644 transformer_engine/pytorch/csrc/comm_gemm_overlap.h mode change 100755 => 100644 transformer_engine/pytorch/csrc/extensions.h mode change 100755 => 100644 transformer_engine/pytorch/csrc/extensions/pybind.cpp mode change 100755 => 100644 transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt mode change 100755 => 100644 transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp mode change 100755 => 100644 transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu mode change 100755 => 100644 transformer_engine/pytorch/csrc/userbuffers/userbuffers.h mode change 100755 => 100644 transformer_engine/pytorch/distributed.py mode change 100755 => 100644 transformer_engine/pytorch/float8_tensor.py mode change 100755 => 100644 transformer_engine/pytorch/fp8.py mode change 100755 => 100644 transformer_engine/pytorch/module/_common.py mode change 100755 => 100644 transformer_engine/pytorch/module/base.py mode change 100755 => 100644 transformer_engine/pytorch/module/layernorm.py mode change 100755 => 100644 transformer_engine/pytorch/module/layernorm_linear.py mode change 100755 => 100644 transformer_engine/pytorch/module/layernorm_mlp.py mode change 100755 => 100644 transformer_engine/pytorch/module/linear.py mode change 100755 => 100644 transformer_engine/pytorch/module/rmsnorm.py mode change 100755 => 100644 transformer_engine/pytorch/transformer.py mode change 100755 => 100644 transformer_engine/pytorch/utils.py diff --git a/docs/examples/te_gemma/generate.py b/docs/examples/te_gemma/generate.py old mode 100755 new mode 100644 diff --git a/docs/examples/te_gemma/generate_baseline.py b/docs/examples/te_gemma/generate_baseline.py old mode 100755 new mode 100644 diff --git a/docs/examples/te_gemma/generate_convert.py b/docs/examples/te_gemma/generate_convert.py old mode 100755 new mode 100644 diff --git a/docs/examples/te_gemma/generate_fp8.py b/docs/examples/te_gemma/generate_fp8.py old mode 100755 new mode 100644 diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py old mode 100755 new mode 100644 diff --git a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb old mode 100755 new mode 100644 diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb old mode 100755 new mode 100644 diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py old mode 100755 new mode 100644 diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb old mode 100755 new mode 100644 diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py old mode 100755 new mode 100644 diff --git 
a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt b/transformer_engine/pytorch/csrc/userbuffers/CMakeLists.txt old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp b/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py old mode 100755 new mode 100644 diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py old mode 100755 new mode 100644 From 3363a673a6261361cf1a7547c9a12594e3c73879 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 7 May 2024 09:25:34 -0700 Subject: [PATCH 091/244] Removed draft attention_copy.cu Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/attention_copy.cu | 96 ------------------------ 1 file changed, 96 deletions(-) delete mode 100644 docs/examples/te_gemma/attention_copy.cu 
diff --git a/docs/examples/te_gemma/attention_copy.cu b/docs/examples/te_gemma/attention_copy.cu deleted file mode 100644 index 810c66c377..0000000000 --- a/docs/examples/te_gemma/attention_copy.cu +++ /dev/null @@ -1,96 +0,0 @@ -#include -#include -#include - -extern "C" -__global__ void attn_copy(__nv_bfloat16* A, int* seq_len, __nv_bfloat16* B, int max_seq_len, int b, int s) { - for(int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int per_block = s / blockDim.x; - int remainder = s % blockDim.x; - int copy_block_offset_begin = per_block * threadIdx.x + min(threadIdx.x, remainder); - - int offset = seq_len[batch_idx]; - - __nv_bfloat16* begin_A_copy = A + max_seq_len * s * batch_idx + s * offset; - __nv_bfloat16* begin_B_copy = B + s * batch_idx; - - int limit = copy_block_offset_begin + per_block + (threadIdx.x < remainder ? 1 : 0); - - for(int i = copy_block_offset_begin; i < limit; i++) { - *(begin_A_copy + i) = *(begin_B_copy + i); - } - } -} - -extern "C" -__global__ void gv(float* src, int* seq_len, float* dst, int d, int b) { - // src [s, 1, 1, d] - // dst [b] - for(int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int per_block = d / blockDim.x; - int remainder = d % blockDim.x; - int copy_block_offset_begin = per_block * threadIdx.x + min(threadIdx.x, remainder); - - int offset = seq_len[batch_idx]; - - float* begin_src_copy = src + d * offset; - float* begin_dst_copy = dst + d * batch_idx; - - int limit = copy_block_offset_begin + per_block + (threadIdx.x < remainder ? 1 : 0); - - for(int i = copy_block_offset_begin; i < limit; i++) { - *(begin_dst_copy + i) = *(begin_src_copy + i); - } - } -} - - - - - - -void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int max_seq_len, int b, int s, void* stream_ptr) { - cudaStream_t stream = static_cast(stream_ptr); - attn_copy<<<16, 32, 0, stream>>>(reinterpret_cast<__nv_bfloat16*>(A.data_ptr()), - seq_len.data_ptr(), - reinterpret_cast<__nv_bfloat16*>(B.data_ptr()), max_seq_len, b, s); -} - - -void attention_copy2(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int max_seq_len, int b, int s) { - attn_copy<<<16, 32, 0>>>(reinterpret_cast<__nv_bfloat16*>(A.data_ptr()), - seq_len.data_ptr(), - reinterpret_cast<__nv_bfloat16*>(B.data_ptr()), max_seq_len, b, s); -} - - -void get_values(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int d, int b, void* stream_ptr) { - cudaStream_t stream = static_cast(stream_ptr); - gv<<<16, 32, 0, stream>>>(A.data_ptr(), - seq_len.data_ptr(), - B.data_ptr(), d, b); -} - - -void get_values2(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int d, int b) { - gv<<<16, 32, 0>>>((A.data_ptr()), - seq_len.data_ptr(), - (B.data_ptr()), d, b); -} - - - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("attention_copy", &attention_copy, "Copy function for attention mechanism", - py::arg("A"), py::arg("seq_len"), py::arg("B"), py::arg("b"), py::arg("max_seq_len"), py::arg("s"), py::arg("stream_ptr")); - - m.def("attention_copy2", &attention_copy2, "Copy function for attention mechanism", - py::arg("A"), py::arg("seq_len"), py::arg("B"), py::arg("b"), py::arg("max_seq_len"), py::arg("s")); - - m.def("get_values", &get_values, "1Get values function", - py::arg("A"), py::arg("seq_len"), py::arg("B"), py::arg("d"), py::arg("b"), py::arg("stream_ptr")); - - m.def("get_values2", &get_values2, "2Get values function", - py::arg("A"), py::arg("seq_len"), py::arg("B"), py::arg("d"), py::arg("b")); -} \ No newline at end of file 
From 9ea62c38e305839c92d13461577119b577985383 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 7 May 2024 09:32:40 -0700 Subject: [PATCH 092/244] New vesrion of tutorial markdown Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/media/pic1.png | Bin 2709 -> 19382 bytes docs/examples/te_gemma/media/pic2.png | Bin 2709 -> 25116 bytes .../tutorial_generation_gemma_with_te.ipynb | 230 +++++++----------- 3 files changed, 93 insertions(+), 137 deletions(-) diff --git a/docs/examples/te_gemma/media/pic1.png b/docs/examples/te_gemma/media/pic1.png index b8baa59a07455a18134be2f24e95e51399c6b82c..7c639fab31e8d71c619f8c5cf776d8964a5eb514 100644 GIT binary patch literal 19382 zcmeI4cT`i^+wY@{;*0{LC`iqWsEATT>D59RMMMOogA$Q$=nzN<%81ejmEIJjH$$(9 z2m~QO4ANT=2!t9)Ac2G=_u%|~?|s)@cdh&0Kki-kt~Gx^NW$4??{oHk_Orj=&puK2 z?i%v_BKivm1me5(hu(b<=qG*v+7ATcZ{Pnp&;>1e3<4P$-O{^m9_+9-h4gmVMo2RgY&jpru6U^4 zGVD29p0x7u_%rL6v-UWzI+SbbnMlvXu164oC8tEs*gu}c{$}g+##o|l{m{{>c!|`L ziCG`7G+ut7@kmzX_Xh*FqzD;yIHih{^2eUPs2Dt2Js^L#?5-^8kNj>UjgdL1%vknhFX(+qVDZUuV_! zKYRTXFYx=?>zVyGf1D8Cf4TL1|DfM){d3U&y$fMfWoOC@x=Wbv(7UrhOmS$?N_M!n z!~;^xQJVvSgxt$ipEk}?8}qx-}Z1zJ^c5L@Z^zw6yHWjV=o|RJiVxSrE>a)000pWh5 zddIl$6qrbj*CS`JhD^xJ*!?wEB?>u#u1t!eK3lOmrq!fC*Cq%?v`Z=ae~_22D36IW zsnJ^*>NDlioti|>JBq%_e%VaR5sCkvNGK{VjU$z&L)g-)`ndFTE1Ns@mjxB6!9o7h z_0n|q+1Xhm51O-R+kKEFy2;GPe*Tp!aS^s+(cn92C_b=Q!l@Pc`0-Q}pDt0- zYb28y@|HoXXXq0l5sK91KVtR5?3B#uG%Q>fue7-)=G6Ip_{&qMY{(PU4bnLjJtirn z>zdF+L45meEMLd8E+u-3xElS;HMXG1*~kBb!r}NqLz=P|fzw&%su7kLFC!s_;Mvt^ zd96y_CL?_t_B4IA@+1hzl`U$#cB3kNdrKex{{3l@s3IIO@98$yb*x1vofw2oc>V$= zr)!WiDm<02os2{p;xJGT945zWM9FZIWVJ{qii`^F9W$Rj<-%WcGn+ACo zFoIs4`4X(6qm$0J9pECd>k%Lph#$q{odxXFxmC%%lP1!Xgt1r`c2Mo(08)S}zKORh zh7-LB>2k0iX3pH(J*{I(9Zj{FwtQK|TTzKm*QwQ<(A7$hOL^k!uO+s;_Jv{LJdVY$ z<434*Z!z3ydRcKWx2c%exCz&d74N05x%Z_%kl?alHBei0rES%^a11a^jiuB;LRooI ztZMquH=crl@HO%^!{Hiv@M*9v|0VH0{z)6r*yI9eZyXADyUZz)UK#>YRiChi9BhEryy>&j48$-c#IIC`8=We%By zu(nlHyOcx)dCfF>!ry=ID*jSW!0xq*QZ%gmKS;EcT2a&w&*dYmtnra$<dRV4VNL(u=z|m6trDgx|KTb_3 z=boQv2#WvCF2y^upk8Wp4v8cu@p2`-DjwVyCN9Z#V4;33++Bg*^S{S`C3@z(?)n>6 z-N}~nVw??Frd21K8m;di~EuNl*EL^6vDyb^QG~=BlgC$BbXJ-50DE zM6gID<)VJvGW@IcR*t9A%JjYTr|}2IMSPSry=~ooom8UK$!Mkv{9+XGeemJl=tK3Y z$O+css6wO2#K(nUo^H1;R-G}l8g?`C_HTsX#b=&KSnGn-u!xYPc<-?|YTL#igsjU= z_4%oR*%0W(`3p%Byj2lFD$Z-f438DkM8#;oX|IrG8;t>fc$q1Ni#$p(sv(bpb zw*Y-Dw{u=_B99c4lwaIfE@$%~x%8H@Wu1X-)M!ig+MkI_1R^qvCo=SJNX1G>RkGgD@JKiu=wsIvAj8& zAPT({3GX`)?)PE_WVIE-s}Qs8$Di!0T3gZ* zATVQ7`|wBsW#QUa*)E3BV zyq{)+?B`m3|FkKZTl{@cCyoLOt1x+=V(=}Fm4~o*yI?Ea7qk{t*Lb)@SoYGz#+;s(hW<>wsa-^VCk7cqS!uyL_=+Qd{4M;HDs>($ooD7z#Pqf61D(B0GrFj!d^~#$jS?PYs!$_Zuio_|o-~dt z{HX6v-M(1j(laB_1Yu^ripdLn=2F<(BEC5D<=||7`YF`x$so@2g;0(2th^xXIk54F z63MXl#*jhZ!G4VcplP+|JP!%B8xLv5m(>-8+D+*C#QNr3+Il~EZ;#JscerA9=v%da zUmbb68Aa{hdjkv|NJ>{hhLi?4o1($FTDl;p2_bG@jf{Y_6GDj+J~weBAM2@TtGW*= z@|;oQU0=Qqt7XqWU3ejmrgGMsJ@>ZQzY!g1$q`*I4t>KqZ5Sf-wHX9vk!Dic)`0U- z9nD9NuuhM?vH2%4NF{?O_Bn5ixEnVrOj5tM2Yh?fld-mO&0l!vb)*?n!d?_NZDS4$ zxaTvVnLDB9z}Jx51r5tUCR>-QQZe478scapv<r-8fSW|c^dtM6>kTWW1Z+j7J zpFoP}-T6!(ZFpxkga6Rv3fp+NMI3s)qXd(U1WyRod)kTBZm&S_CN3gJFb~q*$H^NH zk8%8ecfPVr85hwsa<-g{Wq{+pOGmqRQ=Jv(KL4fK>BBg#cICv)>I!_f^L%+<(l!pq zWL6{a(N{e@ACnu{O|#|DUf5;OC+quK_6HuCriIzrKb8x9BcPl4UU%g}fg@RJ`U$%0 zhGd^*hy#K#Tri7D$nUzM^Vc>;4M$r)tWmJo{9cny#_qPXD0pP0nR~b7kS#}u%NqBJ{V1y`Vd>AHCD3g2oV2W#6g)q;~7|EaI2ci)Pyj^G(>#X zTfCuAS(F%1r{k6t`-~Nc<&@gyHo|Ayu6Yv2tOz|Y0p#WG!fpXDd7v_l#CjFE`~N|5*5bVn`MW!!2D5db}QUUISC zA7(sQ=@wO2xt6f#u%3`*DZKgV{ThMY5o;tgZGoKZJL1UgW4h-(U6^9!h{O?cnF~cx zO-vEOm!DOAwH)QjHN!CRXm~Wk1E__}el6TK@Dq?wbtNDg9G4Kg%ecA~b$Z9Z@>uUx 
zY&x78Kle1UVQ)E$P@PB3e1UqY(pkV&3my>U>u*Usy<-=;B`Kfo2MdMNP}ZWnI>W>8 z@4w&9|4{IIdWGzABWKZ(avAAnH#}-H==YKzI~L;o>Bdl{Tb9p*&U>-wZX7MrD!^I^ zf5e+EQ4!hj6R6*DzediNnP}Lm+ctq3Wd>WFj5-gQ-<&?3w&ELM88VmbLGQ94uBZE4 z?Z_tthX(mWi)s#|zY1lUWV6Zd;KOY*2YLS#RxwzKtJWrelpjXAhrH8GhTGiuJP%xL zl(wQ&^eYnhY5bDpLY$t7skT{ig)O*VTuhG=$H~bWM~lE*%_gZkPOJ{} z@98QkC%54kodMOk#J4n&n8ch@7ND8T-_$_C+UE@*^Y={g_VnppNJL^zU$EVbj7U_g zYCV>0KZWw@ixTc1>BP1mA-j2N&ql+HJysFuz0cM|jN0r&0!LEpts3EZ-#^c#Aa<3# zP=3aQU?=Dnr=>{YQ^%Y=ioI6br5m}iCPxS~Asvf-tu3%E8T69dc#Pw^(EdCAZKUgz zsomYzEATSP$^IiEvDu-Kd3Aw4+7|CWd}qeHX4ov2uxKmzxqO z+Q1shZ2=)B6jQI-MYk7=fgmo|QOFLT#?B^Xgaw|G1IU3{_BL~ zv~YRp?KrpyMh!109e>PlM7Z7%UXxF@deyDQbT{pb!)T100QyMfwvcL+_^Y*?wO44% ziuPwp@DskAfG6?sZ3|1Ko`c3n%6hyO4bwT=)ONSw_ux#>h+^yw7 z+}wS8%Ft2EYH4MyKhS}wrj2k|Rsoiotx3BQr?Z(+B9zUCp~G46$0$1wIo zm5`d9#MhPrPjZD7I z5>l0sty}KS>U8Y}IgjQh9UQNH*Mhx6ynS`AZqQaIr;Js%AXt!4_Qc+!&X2=oT*gSpiojexJZMk~bkm{8~Nr6nr z6BDG>HPR(y!91+$b>KQ-fkhn++84+$kIxaTrzv7WneI)IQc^J+Xf0825xN zQ0E+{ELP1MX_AfG@!piwZmu3K4!D#8&p@~b)m65@d8^*1#DG&h(rC`E$c=eo)@xn% zH{GlNF5dEPppsa~N9EE>$v$%<3CNu2y;OW*vExYMHbx^o0GeupO)g_~I-LMkvDF{8 za`uq@pv8t~4Uq2f(d?1sVI}=3+e@uoEh?fnbx+5m$gen_wxl;JFGmd5u-Wjk`(Wb~ z$)Hp4F+w8xbg)Y0Dmh`5K42odrEQbT>O`P68?N%4ciNc<&7F&BhETqvUReO^ji`bm zWY<6{+~+b~Tjzua=A&jEG8e4|d3Yrf<6~x5+gQ5uwn5%zeL23uR*W`s35R06E@;^K zv$H~$wx!XpL}bQu$`m4XCnfR{?qummZdL?k1&6fnadr~8YJrq#a|bMYw{8VjKQ(W{ z(3{xKQnVTomjhSCjOXu$n8ahhm@so01XK~GRw^*;5oAzd{ytkuYCr8FY;IS;?LBhC zrYs0>XT%^{OT0CKlJMR&+BPu2fYnr7n~Lv6RLiENT|GhFwdi;W zAj(;VX_dRjCMT$Lj>8|f{yrA{Q=w^&5>aSr zBg5~TT|-{=Q$hmTTlWYy4Ji&G#YI^U%M|GIE0ZussCesRh$nO2ISB{9P2qY&W!YOi zC>Ia2LA%RBp}G!Y89ga$z4Ye8pcVokF}sY;cl>coAig7*5PvEFMStC76*3!M{@C}- z`P(iH3igsX7xg*Lq_T+AO9M#fa?_lARw_K~GSe`=941@ZswVnis29CXOOK6Od^tB0 zyIWplE5>)p1bNdKkm_bOT+Q!(brh=I<}-D?3JIef;&mBMM4i^rI+Da7M^Dg zqY0_<=eq>x99=^RYSIBa@V*?t-)Nci;3@o)`VCVQ8Iork#l;- z6b;TKheT2-OZ{p|sS4Inp{3vm&Z zSdVK<5zS?1Xy!E*kmq}uOW6&3N3~IK3rJbHt*L+i#_}Iqy6WPBmXuhY9Ip!+2YwW_ z-&)^<|LSp$BXWO<5AL=q1r_zyAlqnVUF;Z*vVWD#@SAvK+;{154J!c?-Bz8 zgV`(!++{M1wRZ5QnHDLvqWf7{S(};|nkMPfObTYtnQ#>JX@x)_9H)dansMO%Bfo~z zI$z6L4yxm|xqGxta$W+q9=%F#!$L{B%Wp9F#K;^A#=Z<8)rCbvgyuZ(nI_ENZ2IH{vQ<+5*6>=dL{t>t)gKz+IRPhq0arSgPEZ}@y3kq z2lGxx?zUsf&FA+YZp?wX4Z)Lyq^B5GZ;nqm~tJ65m0j*gzjr6?#k z90;cU%rBHU@J-VDeje=L;{6ZL@C4Q;bOnygWmFj;+cl~BXFkg`@*YFyLY_3 zv#G$L;i~aniU$|3wp~9(UR#snJ(5q2aoo!#Yp2~j0MxWsSX@a}HH9bQJ`A|iK;vZj zpGRDz?$Pim?GaGKimlJaOj~YH`rrR7B;G>HqQn9>D5?EmLhE56pxC74)zw)0=D2Ni!CvgBZR{ujCA2fDs0jZEHJ&a>gY$a5=0P zArNTjF>RrbFf?RQ(;qLL1HUp{W^GmjDgb&&6r`&^I*ux|{Aydz%(0rX^Iuy4}L2_9n3h3&-`}!`rujTjIYP`a!m>AitJ?>&6r-rxUqze&caSXNI)&HIfWptS>y{?OdqT+~PM$6DYTJ+#Y`6u$P@<|5IJj<=U+#I-e(fwNOzNdUGX$Q~U)fKLN!P>N3}LwfKaW!HrPt-xwM(DfG7rzOzyGE$NqgS^s$3|Cs#$&(jWVXA+&Q zO>yiu_XC_GM;t<#3jA-d>cw+DC8l4NFPm9!^~gAnnPU#k1;v`y}Rhi(~l3 z@xA^Gn}S3h>95) z1b=@)efM4|_+67-=vwXO7pVk2OG``7>6S)vzKz!QpcyJ;dzp$>00hzqASNY*{zIVb zhu?otoPU_4|7|#{W}mp}f3%-(KwMofIC&AS2kgFV+0%ugxTPa8Nz~zbbnQSviMYD& z+1uRau>A!QhVF5N>wZ;TT?E4%s&pDEvDmc5Y}*2P(gzGiSH6jy8_ZlPjLE*ba>;S` zw&zq+t*gGc4!mIxY7#EY$UH^76KXc@|Oj_EIlT30@bEI|`GD6A3FEylw&eO+i&3PJNRTTN$=pnxs==AS1g2 zBy&?cZ*+Ix>?m+an3Jt8$@|d($oodCoQ)=tT=wxS*xbTxu*C;m1x!jxa?9$afO`S* zUvLvtDuX06^u*br;CVz0aMC1gSfk7=bEPtY7YzS6U^o=KTwLe`tUXXL+CtCEKU|*e zmwfab0LA@-#{E|r!I*D}QK_k4aEZBnD_5+?lk=}w`(L}hpN`2Mgb+e$Bp=vZs8jdr zb3ms%g>uq&SeLj^@aX$t2>#SAy`_deO1FD3o(ywi)2$Tyc0KNE_}<=6);^v2sPw>Y zT)aY1Y4H2f&KT_0r_A=#>2}zI;yhl6vEXklD-l z_`=nk?%U0mG-$J>OmtXg$^Zu{Gmb8ytc=j8wp}*plC_l}>V2lzhChLBVei;Xdrjf4 zRASIOmIsFBzrMdoVq<_%le^g4%$N?ngoH%*q65uODFz_{Vs zC(WV!Pw;WzK532gxCGslZOYmgn+K$CLxV2m0{qT%N$R~T93VzJI$cptK;@Bloe>LP13X= 
z4%(?}EluZ38=IswH)eacGgQV$5|u~pcs0k%>Yjc@`1B-TW0eRG+18fRpTmb0nz z#05^*`V9}D--lth`o?(hCuwQ=ZkXE2((b}b5{|r8gt@>8*T8q9SQ^+CdJLFCepf4^ zDzrUQXuBSO{F%RmW#oNt=qFykRP`d3r;c}Xll;Waw)%}{n>{ntN{GGyfyXy1h1fDu zQ)AflWKRbE3ApIN56CRdWze&?$uN2yw(TNehLe=DX2j?-WcCTQGL%+AI<=`VWy@8C znME(t0h5D{ojp7){S&VtZEXVXtOLpkh>+X|>&h5fJo6)!qgao?C%i}t|NLX?vUOeE zZEcE%UDZ6Aq6y_Ic5Q70O08vq^-#L@Mm21iY$uP_KXvp+mgunUI3zv-x?XeK;|x}^>G$-flVVbx zM$#Fv=*s%8X;gN@(m)J^jfjX^>0fc0%gQ3$FNp^nP1V3YMf_C{YPV)?W5fLWh2eE_ z&j`YelrMNkkFp35s0jgBZ=G#==z^LuvPaL+-%}!iZdj*3S0?7~rc^t5a<`y&8%yV9 z^^XVazR@K|(_dn9-TaJ{?C$v&ls%1!pSWn0RO1g*LgXr^M=#i&caT}@!;bEfc?co5 zs9$>GI;ItBfEMRlOoMZZN|o+z?M=J3J_&_klqb>Bfevp{<%!t1`apxOuH`l7d3f9;q_)2>MzKB0Nz0=*Gn~G)N;~|ldk*Y`gt!N5pU_qf)OTEV zq6H}%t|aGjcIcz5$ELa&=XrDejGQ+_2q&^~JZ}GaxG{=DW-rI8%@Gn|?QtB535c(u zp_09ofA14#HmZ#=Iv24kq^TJxu$>$cPwGCE>E>2uUte5=f3#HP|H(g>PqnjPUmuf?1?oabT*HBE5c-EgMCLma?fM4exo>udCI$x{Sw-!Ob_!N(pVJF^$3s@K82Atjot+1%f1- z?|dkDw0>4uhDWA#TklIxppq}tA_QgYqpvrVcq?K3l66y)99{d3zKX_4 z0ngN+A~sfiP(W5uV`_&B2((GRJYw|H5Jl9`K#)-KQ66Bae12JC(S`{ zd*NkrR_5ubHt?4IzhB^Gd$Im#TA4Zg=R2p$1O)tBDt zvskAcl(tXT!nO~rJ=CeQ>YQoll7-K=sZhvFY?wGkDLqj9Rh@uPpJ1$(KW+=b_W&A` zUFns{x)cAO_$^WwVUIMEAtdxeDFl^RA$x;tpLeGmcA}&}U)56VPv9Pu#Thf`dfdle zO7M>|Hk8%qsT^C|J1}~j4j2342RwT0tvX~5?aJ({6EVK{k;c|CzJT4T%3qvZzX=M<>cP1S(Eio~i z$$W9#Lw8Jatv=gn*Yn%w^%!rVAA#p$vH*+G}(~d6i49HB7A0C@)BD$B=g$YD5#~a5?)R0 z(M*gNVRwG)6+bq25t}^fpygY=(6m0vE*`@lNL8B@Zodkh55BI0Ftv~UyPK$oOg?P+ zefcgt^E)jxZ@b4fw|?(?7{p&|_*6M1lwCiesl5B!_P0y9v_k#rd(FNJw1Wmt{^-JE zvc3@AQNd7wb2q||rWp3IMDs}1K!UpwyhZ$GcdVZm!ct}R@0>d-B~WiT0bC3B@!XYc z;1!!2)I}m(rJ3dcM4)0Kv;(4v|DlfEv%qGjHR)*A=gY`z$_u8xnrf!+a zWGV76;9Un%L|PlC1|n+-HiJey^NGTE`YATFA`#OhGY$sJ+xRs+geg6p64tI4ZtFor zpu%%(8g6Ex+~%fdbq>W5HrKKUnOdyJN(MP@>Kts6G9SU;*%ZI6t7|hT#Jit&-qSa5 zDpDJ2@u%+8@fo99T$7Bq3@M0jIfv_2R9oW>S#g~+i@Y&y3nEl{f$P{6?K4xfyof@7 zZSI$wsvYz$Vaw6Wbu-dOW#h0WbrOtjj)>3W-_}Qc0drqeX}VE*e6gvrt;03;tz;VP zdcL)(@QSXo#I7tn>z6mXSKyxbG`^d6ft4VT+fN#F35VG*-)E1~^h_~TQZ^%5mqG9_ zwSpYmuP1~nWlL_oQB+c!I~IA!TGYr=c<4i(kN@nSCgr$fLs}s;3Y&W>h!Qt_$g$$` zX>Eph?5Wx={d0?GX}X??+i&P!jZ!S0zZ=?Bkq>Hdctfb4+hK|N6NfJ|T96JQTeHs6 z_6Mu?((6VW3tY7lDfKzR+lNwvTQ1dMT2-dYc%6Uwd)7o(kP+x2yxW-&4EicqXVG27 zfH7iYf;P|wXBsp5xJ6jC0dOm^B8`m=Qi$tl%MEE^y6(7|%C&YG=?4^aC!P+=m3M!V zvAJ9~>bGd_|Miq67w&!CMdFTDvWe}4kQ^;iXfeHQw6Pxg$mK!H6&uc3Oo6jkZ;M#R zk|?Y}S!GH)V5$d8o&>^S+uOa+?V8spHHo{jMzSqncu*FV|Kl6;Dd2s(_J! 
zy<$JS5?LoX6fJw)mAjXHTIat1xmc}1p>h)L*V`jSL7mvK{`0|>W@W^CgS1Ce=&r?( z8sg()lS3j#7=mE^9X_m~K_Ncu0-uxWVBeLmcA_4YlMnvn+YQ|-5_cANq4H`I*kghe z%w=acrO9ClRz+nmZ04vJVY5SP(M}7aI{s8L4#0H&Jl9A-EeS4$MhbhMV0Q$I2PK6q zCnmdm8J()v>vHT9mY8Sd6>1))L2_w8n0cJ(NhNRT^rWy{O|@^K`<|)?Iz~i}(p_g$ z)4q=?s;q?J8I$Q=1a?-PZ}8=0lZkzV>ZdaR4314Edn~~v7~^U@SA<^Rfn|iXTadns z1Kff*l0dsZ&Xv=ju-rzlLma|ZtZtxOoOh=7uycTD-gI(99-@nYl)3eY)iUK*C*NnR1FZYtHR;>k-&NuZSy zJh^Me);jm_PCs_e5N2QenaBn<+Z21E>nvQmeGMSUtbWeW-+rr&Ef0 zF$>GdWsL4O;%GR`D-YSH2nFrn7 z4Sc84kv+^F^P8L74a+LmN|ETTztt7lw6yC`uRv^fo6{Z3sx2#NL4F%v_ayHaKTzTH za>}F}hp`jz~Tm z9DgYQj%Q%FZ2c8+Izar)r*Eol^!B?peH>8Uw9F#ZSpV$dv&T)&ZQGC~DBZC|CRg7EbkH14Ar8s#Y?%Mnw9 zwLb|G7@gawB^)Bjq{t&Zmz_OD*o%5^{m+z@@+BH>FjbY42Ol=8|1FMPjqF5 z1g)Y&*5Aw)-VtAq(hho>FeM<+Wo3*|LNv&T>Y59`X>TQP^WLZKe3e77vgxiG>s2rR z=fKyOMPUg-j#nxC1_#&%F*W+8(v~Bck`1qA@8ez&JNQTa2tB)W0ic^Y%x{+MLg9jz!=?2u@W_srm+aU#wYC=VJ#sN zHO*Ro6<--dnkYZNJYs*U&vMvs4cAU!?Ql*<7c@9qF2+Ju`=kI@1+e$;dWX-vYXVlQ z88h!ne&)Z6pQl9-bgwwnB-N>p@*8Clxt!CYqR~6>eO;XhYt&Vc zgIQZFzgzZSAbG)d>jBw$f!hNp+ne0*_4i=E&Rphv;q99O!lI*fBl$dFUP(P{_Z~U_eAMqEe;v`AhB^a(G!tg(wrm zvY9^b&^(4S!WG5rkrd{}Ju2gnbUeqJU{t%qVPL>`ZH66uZw{OwfFR!lH2(QvTWFjv zGfy|V3yf?bt5*X%gWv!X)R8P_JIi5FDl;qKe)$O(2Xaz~IHRTeFA%`_`k{t<<3zU2 z-kdH60<_J#(bu_U^89^#R>x8|%{l3DUo8rDXXlBy6JXDHfq5Y3# z@&Q{}be_f^0eas*8`A&TWdG082L9__Z}Q{D+ySA9pgJvm!QW9KZCB^0axMJoV(4sW--x`-Hrv@D zOiH2ttww5qX3|@x5yGr8B~z{M?V#3^bg%ScM(P!!l8cih7wRAx|HC-6s`IqYuk4Yv zzEYc5f+up;h*+EgMjW}pJH$!|Ph)XAkWF`6lbO2s^XI>b2ggS9@BiI$9LWFWN6L4UO08?zLp18wj-{r<Kf%9QKr@ zo#|4Z`@fNgw+a@cEigYI3jf=CBq1C%4)-5!J8&5c?A`@J;Xlve_~(Hn|L@w*|H^>= zl@%nAd&N{OnMG&Vu-(>{nt5l(;)oKz>kj{3`H0b|-W@?0))YV*U%Ch^c04LuRU~2a ziR9(wXn@NWZZ(hj*ZRkPnvoBg@&6m!mJS6NVjR-3U@~|vZZkfpCC^oYBEa1cm}4Nh z3;-CE15%SeXWsSv%>S#nxoX>W3qf)7`m3kJ-^$uIDRI7oIqv~V8Q`c#F-Rl;GDaDl z0*K3BR>x~gi?XsMGy_j6zZ>6||K9xqXMW3xXS6J8>#q#Lm(6mb1>iG00K};BH^j*0 z5V@?}KzGg*h83jGS7h)nwy~Q3?so69Z@VHhdw+A0z9p(-E#D)>qDAkO@5_4k^i@1Z z|AJ4>M%^~Kzw`B_j(JVEI}jKFQuF>Vd;n;9L)KqoX76vtvHGuT4@o^q&w1Y6*>}z@ z7No$te>C>ypA>B&$w6got|{+3z1L2qQ#=p4h=I(0oZ!Xnwom(Mq)_>sJ&cs}H2R1* z@OTKIeDkJTPBOox8kc~J?*6FmeIU{L;f>$RzkQ%BD&^8!=7Vw@iyF|6a%ykCZ;Nsf zR+|8+CC&h6FG>9YSnU60C3B$x$mej|`J5And0zgf>qS#BF>MytX;0t}$Ewy#Oi8jBzQ6E?>20&+?!t!2-_!G=m* z{E<_VfD~Fp9#}W4l1SfQ?GyK#Z31EYyv7B#*i53KWE3!0?Eop~+4^)GD--* z9Frh3o@uG6uO^;0%K`HmI+yp~d@yi30{R2J?@`O4h)bM4%&zLc*x^)W&9mh#AUp0X ztwf6&CCf&Ua4A5qe6zW^NoAb6^k5%{ju%%ykKgQItz`%0-|K&<^A>7=i@af*=S;L7-ejiUKXr(%~eE3j;PS zWvmL*N=wC-u|gT8*dZa!sI)MOAC#Md+*&C@Z7Y|y`%IS1-%kIyA8$_HbMl^?yw7=_ z=lw1W83Z-8G=)GQ(BP8+ryvj`92kXLjlrEnUmgcsjB-x}`9s*^eG6c+In58@2Z1zK zZreDw1vq5F$XKk%F5Qm6VPZj^%%6YTh~w-NaJNl z8uNkr{SMde3H4O-EZRs*+?5`fz#hYPiB%UkpnUatHMeiG$}~kA%R)?Y zGL8T0uMpfIeXf?s=D(gQn|b-H8%%}P#0UJ40K^0^slWREguxK@>hMK7xL~9?; z-ykke^$_F}8_?C=Imzr;R6S+<> z7=WN&>pnLklR44`Ca5$-V`C!*gQ@>sIL%wA`{-eO!{d@*23nis4tE^&9cDA|2)&m> ziH$ig`ga%0mi$bp-qXHKGJ%F?@m7`&&-BlS?I{lgAuKO1&siz86Gka|{pq)!y z@Rkb&ME(VR3Wpr0xO-Z#IGYEb9Pw>k;3$pPg};P;GpKOZEUFaP+NIfkJu}`HHZzh| z>+I&{h8v)_wzlRxdvx_J1Un2@#UT4MK4j8F-+(THD#pQwS6$>bTq|{#wgY&&syupz zT^0om8R7F|Dy*}#h-MxyHYzHrx4?@{)vqGolODYTKR&EL{_~(8^fVu4)W>@$0^eFo-sa}8V2)uDV*Prs_I9SCb z4*dD&Ipl8Kzv{Z#6LRBvl_ALn3w7t^q`;1!Y(7+Je%;=WY4(1jZBROfW0wn%aij#8L^p zcV}sxeN(3UU;K7~U1JD+Szt4D&j8G9kQ=67mpLS^ zy6x)W?%+nN;bjbpmB{w|>4_Y0hyZOvT<4pBto4d9F2PUkB+?lS5vV#0z1muBfiAd% zZe7%ACt)EWAs9TKi+$cOA-}|+>cOjt;*u>#L)ypyHwAGtT%%q48 s>%!+h5{3j5G_yYp75~YAlE^Cwvf$~VzXIbFcv7UXQu5=E3vV8U|Gc1kLhi@$Pm5T=AICVciG9G~ z!o+`W|Co;d*=g~2IPpAn-?u+maJYJQD*2DeVv;6ej8s=>Cxa@=>C=aNZ?&tzb^~u7T#j|7MI_|B!j8t9jm=OM%w>N7gR*CHcO&Joj%&cXDn3*;^XSpi*gV&nch2SNkK 
zV^&UT6c!_5n@K_>2b-p7VH!$O;M@9%#{KtiTU8j^em52nxSw~jaX+vK_Zxwc0zZEQ zUNZb~yqbAf^vCh2!vDT~|F)VO91hR(HY`DPqt`d|le5aZSnB+HuFm(DTl_YpqPW82 zlNdBPlirYF^PAlKb2hWTMc9o?V=lEUy&HRce53kz0p)b}=o_vFTP3KeUHe;$m?bY!_g&Elw}3rsfFen_GMVI!>ALkh_a)rlfk=&7}C(?%}phlLsTwT3tm9`$ouC?f>rQ$Erom85X4 zOT7u6i~YkX22_^jYa{~Ew~nxsA-^>eam8X#2}1!01a=J)rTaKW{$3fgfP5!d6IO8| zeYA4Mw7xx?6DQ0WJW6g_XjvoJii(mB*=2<{Qa7#&2^;sQG!eEB zX!CdOtYKI02pPc=^-KNM0>W+_t0NR@)o}bF#3X9=Yyz3LW(*6<6f#nmQR*m3N*@@k zv?~kyK2By-6TF=0?4Y%GmPrTASr^oe;!X#RD6ORw! zFE^N;9IZL+gSK$SPiX6JgoG{NdvJ1bFR^h{X}b2NJ3d2HF774hV_f>eD+sq5 zb=j%h71FR>7w&rus}mmd_Sja^vC7p|hwE@j1>$FC z4TaB&ibyGFh1u zdja(r^j=OL(W!ZwJ>^Zjtt~s$?9LwgTox^l+U72NUWBuU;;4}*Y5irnG~JMF6||_% zF-J?U%bvmI0UJsx(SyCTXdI;>TE66aZ}rm$S6|J%S9ilx>uH^2F79QHNSY+e55QHJvj-<6-_ofUx&eCC#s3H9_I@?9kFreE;0~y9PBh=J#)rs ztVtVBU!8>_Ob(7n=c9ZfG>(R#z}bYuqDMnQ!eKe=zs0amjr2QRy+~)~|GF`R(09}g zh$K9}iNRdvwnY&9-+to3g*A|v$H_>c!jb_Uv|!mx-ot2aMW}-y66l{6l}0xZ9tbKI zZO`ovC|9mZ&n$F0g!Y^R@7_V=Xo=p7zwL(-j3T%Xc?EMFg)3!Sl{~^`h#0cxq2`ro zB}lO%x+9I)QD&5}dNV9I2d!I@dGY;f3e;2aN!Iv=E`~RC$!~c>(LcW)cWKwTE#1jJ zmgk#;_KA`y#-Eab5M>JwWuQHNzu_7|{KVPXf|tM?M<)Wqt>8Jk8QtI*Q}8_l@dE9U z$aFW-=A&@qhFZ&o#1aUhnE2#kk$8VRZ`mtuemj z7{u?h=J(%<#qJ~-3(@;hHZJ=Losd>?<*F0cKB3hl$a zfaHoC*T0UhOaM;hUN}^Bx|Uk2d1Y6NlDp#j4w~?v)ui@HWF3jVnV@fxQ~h_ zRvzt$YBy2sDvyh_hZUWB?3S@R&@#3!{&pmwC**@=5=;-tM-CBw3gXN;H2-QPI;YO zh|({~u5Y{=JpoUU|H`GwSj!mMd*Oxt0CjrnuFzUKEVNyogwCvNJvZl1QtqZ$Fx+$< zb(aR*JU<0?_}vcDoHrCVNZ0rco_AOiR+khUI@M;PP&8)5f+w$~A~7EGJIBpKj!F9) zJqaycu~>ZbIDB%Se_wT`JjQTTGMm+(AtVE+9>Ie#{(j2kMw(8kYC|0;ucN)dHV}ul z=nh~mli+RI>V;lg;dh_ze7l*meMmMMH)JtZUgr7kwf7O30Xr|1p=QA{?nhML6C&o3 zWmuw-mG5O85Vu47Y|(+(<)y+PUuV`K7rExy*N%%U7 zlc9Dm(`j=Km59yhFyhDouDCV~vg&z-@jG?QZVI)N6%SreDgR!gv&zSx#n!ogip0^? z1un!2#_5lPIP0|FUqNY47g9G{V4133U-cEre!DA{O|D(;_wl7*i6Sjm?{fNfs?Bci z7RlSR?*F=p?+7gqnCW9->k(C`#}<}~YTjpjv0m5YpR|9UNvt+_B^IB1qKs-Oo%k~O zMbSB54jvBkz+paya;d4Iho=pK3YNMaTpA+aTizK9Ew`F}rb*sE_A>f>8H8f^htrs~ zSGLB5Gf-Pks0D~-c$wtM8E`Ie{p#J7|E=>Qb^q&b9UIUfxD9lNA!`_QcK4Kd^gF%3 z6W*-&b$%k=<^EALe*(sGmoDXzRmq)v0>nxGvq!6Er?UCM9DE_v_1;-!Cn2uPS zE6~$XZT{v{qc=4nolsAQNrAAT=4%k5f#awSj}Nt3$IIg6O>vo1-nvBWAF%s*YXLB8 zKa%FX)q%90BFEI!QttQah28QfoObTL&E<^CiE>%H9XSpU!;8pSZSghsNR+Fb(Xoc6 zjHB&TJsM;E!ggblxCz>=?PX=-=NAS<9CT39FH-sqFG0l$Y3o{jr5Prdb_Rn&tkI`C z2-7_sOOoXq?56&Q17TIj(u=-HvQE?89X7_SDZk(-3yL^PloKzzy_%NfG||Dr&A6+^c=2Gw}McFbl1@=HQK3s*G|HWGrVw-N1&3`r@~Z z4hlRvEIRrxqPFs70q5<_J;A(eqGJhjTXQ zyrs}|R>fvP5)HPH#O;#3X`1V~TESSqF)N+6EV0p`!XOkH)D%1S*O0Y)tum0LZAf6C zZ&mOg&8}Im@YiU@egsbGb6(H4{k5h34vWuf^crHG6|q9cziacm?g&jbJ@e!?7pt24 z_s)3@^y_kJ)BAK}MD0N|7Mb8v+8~u=Re4sm^fY+Fy|B+ec``K?SnNt zc(`|I5$YQbrRB+ar|hkoxLuvp-gUwWRS^*(+I(=?^0UXAM!eaui^1i^^tNPV zba#wAyG`vhw zs6Zx+x*B$aji{IEi;PKLBfmi83v;>8NLBXDA|`uwt3e(kh3O>CnBz)iw$MsX-R+(< zr}g+!=qno45YOxBZGN0w`F)a6H91OA%x!5)! zVscV*XYMXz?}A;(=%{3tjcT=6SNKj8k#KjG?)_#c3PZz!5amR@>h{ zSr*#;AX=1y;RGd6j!=kMiT zR5mt&{cw@yE^>K!xUBSOuy#=IKoxgadTqan>qb`w)5}(L@mzs^`yfDx813J zNOd@j9Z#m3^y0l1#z(6rFRvO*N(daD*jLh&aG8Oe1}wm)76IdK?!1+jJ%f%y^+eur z(R5)RK_k+u(udv`kb~?(Tr?Q_C20D(fh=6N&cydugI@gB9NqAtM(7*P%WD}7# z|K^?l<%O+2&8+sK+c1b5VrR2KedA+2>P7=$Z7h-YzQyD`Hs-QLt=4__Z=iJ0a0kh+riXB0mP0u#Cdqc|1aIhLK`!;wzwfxHVh(WHOo7b`r()P0N-Yru=!2j3+ z4;(gtgw~jx0^9BEZ=;yjz$dy6c87cXQR82yf6l1#+&N0I&g1+|R~AL8?f@Wg4Np1@1qtVmIY? 
zNE^3-PMjO9 zykKpWknI#z`kUdbu(zG$=C>g_d;#??;*226~o zP9v!|-q}SB(>Wp~hy2G}FH1*!&ybCt>fNLl`JW3}lRYPcI<+^Yu0S#MsBK$y78$Wc zE}#l4!|14((Y)YtwewY#pSHJBF0lp&;cc3s`Y1_I9r+Blj*hI8wWAm&gdhi3=yksO zFL#ZyHb>QWUWq{-iU%OSH>veo72q-KL+`5L{O;OA=0Chhv9`Ccn>KQlnz*rtT{HT6Z& z;NjW>Tw|x%!B`tcc=vbjCj*Q?7oXvknI0v@m1a%X8dktaOuSC`*U^_wNa}b!nL(tS zvz4ihYdVQ8R?uIxKSGmMXl5mi4v8CB)I8+P-a6h3(ll-Jt7U!uTxO*k@h46Q93N@K zCTx;e@In*L$kCxigXNkR2_vr7cB)((h@};AFX?O>E%4zXopOQu(Ljb^R2c^ z>$*ha|%W?A)KDbvVb*Tl6wp(HhU&Y}v{IUf<6#T-%AsS3D!NFdz&QTXS~19ijV z2G{qwdPJhj5z6zARKD5^VJegkc+^ax{C5~m7FAE1;@kA!@TOWA3QlfqVBzc#&;BMI znf-YG1d77yzyOJSVlW~Cj+QHCM_|0(+`I)61}hD@S|;BJ*@2Z6R(-l#diw=x=O$@1 zvazN@#a-S~gnIQ9&mrGx{Q;X1iR^B?p&m=A#v}}h&sXcB@K@f%#3b#T7NyajXmo4i$gl~pMYFgkR zZ8Dg}^<9m&dB1nk&M*AE2$WOWRXM*2D90WaebsPhZA{}ys(gBhuVM5d|77YCpCXa% zS8maFaWrg`QnmJ=xK!G~+PGo5CKt_W9R8g2k%4kF!6%yReKVxSW%eFc!DlX(7Z-XB zm#fT&O)nZr#DL0Y90d85mBug!-O`E6=I&X2*)v$*eJw4k4U7yB$gI~`+HHuycwiK` z6c^3T7@tr?ZjTMmTC%A~IfE!u1ig<}pXV|*mo%b&>&&qdG?L>I7@&x@+;`5{E@HvEb{H* zLz*9`=v-D@dVwNmmPFQb8s}{kL zh))-@U7~qSY-CAfXiC*W7!i&*nE2R|hT!Tg4_12hUV(a)t_ZpLb%E~L>h@5fSyF@qj$-L4%TORT;XZ1FB;+ST z27=tDN&AV+g=O4KRCtt2`%}h5c*0zNr`oy{RO!v^;||07iAAg|6&@)&NH@Zt9a|$^ zDFbdv{ZQw@BuJlb%qGsfqMuJWmsqFi!{-(bTc#TEAs=zXG)8nX?|Bk*R8sof z*byK9LED<5&n@o~Z&Kc)T1>;mLnBd~9o!H>UTnyNRZ|sk-K1`Y-Spc>!h9dxZajgg z+O8ncKFBoCPnpqp3vq-(Q7Y|xueu6k6xIuc`a|CzBE!aGo+sUOBA;+PZEEI$fB(^G zr1KX6SEUnjf8INJIK9(AGx=QiUoqZYA2yi|GLh9q4ZQ38^*`4c;j7F4dhcnuyjZRw zO+h-Xxr~Bpi(8ah5u}^W&2xux!{!u`xx0aFrp%^1q{x+?zV<=X`>oy-FS0To72;Gr z+Fy77cD`!@V}_IyxtMX}udEU>@2`>iq$~pJyt9&>&>uYHjbU#dLzWS->fB}|w;x+| zYOqj0zwW?Yn@Ha}1fu*!vBC=#mU-~7+W2J7OpUi`2i>NPrB9tIUqAJ8Bac6=s)4BDVAM(OA_mwAJx-JR3&_XY{A! zR&rjko*iPTuT9DlE@5d`<~g_QKCeTX}xTB!7!3yUt=T%TX&KfU24v^XF+M2XlL z46TqJ%H4$vTe&P8>H*R#93}ph0Utvq$VDeCML$kV>dhq4CwFsyuLIJ{y~w)&zcKIz z2>*hEZ~LMHChCQCz=;&PR;C1ztS`PF=^vM(;OI7Hm!5liKH#HI|B;6dNT1b_NJcf` zw8+?iOjhp*fx}WVYg1oy#;!p-rO{xSvxg3pO>SKB48h!MUG;3ngJVZkiKau~tne+b zbz>C*n?x>FdQ_U<5^XYraLy)>GKRnnsHj>9qe792j++y7s;agj%r4N>*aji@j_?MhCE!a&+d~HKpgm{Qb-6U04Ln-IAXB~LGvx6(uU!l_ zvDhDnt3?8tnA_qry+vY31QCodJgSl$V%)ZEfk`1Cc+{M(i~7fmfHDJC&1tle_T=q~Zoi>KuvfNX2jhdo++BZ)u z*Q#JkQtFc9+xkH0y8a`~6tZOk4qR8EpbX0ca`Xo%HS%E#dUD1uDiKO_7tnehZ&H~( z^t<28^6q&jX~#$hC$5yam6Lu5Pzm_GLDW^+2~y)Tq5EM}W@P_}J98;hjempdy4P4- zP#p^Mg&Bw&W!!T!3;Z)^IWl65XsQ4N8fg>v5dtU0aY&*(`_J5iJxe=h>haB?|5907Fk~|G5g~Q}<$bw|SQB%Ad24 zqmp~m$vP&qsZrm}j9fhtlU|!h!aRJPFwwqq(fu({>b6`RwMUtTwPWA$zlXVXoh%0_ zNG0$yv>Q{P%6!866GYy+$@bQBdem+)Y^Hf*bHQNrv4|;*J38AAtu2f080_?81Z_4hV|U16dGwZv>B&oe zgXN0+^*TOtb6B~#j9xgHz)G2TgYiYkv_x6H@<;CzC=;8wc>cgj=b8DxKl&vOR$BDP zt>w{Hmg~g)!uW7w5b&MqeK#By;@q=~HMDqJI70^bGI$R6bXx~3{eHOvNzpJdi6^IR*3K1v zc=KU?uIv0_8_@+zWH*=L?B}Wa@J~d=bf~qomd!v!a&qGme3wz_sX14SNXAK9E@qk| zln>Mm${7GjX5Rr~5lz3zJhiBe*a_eks#GIjMG3?VgI40Cb1RRw;#K1jtl4MboOJK6 zPf8J{8Ei7gMr znq5U>$I26mVF2A4)+3$;^G&94KD4I7)uhh7PZGmzVA7bfPSl{_B1F_M7>)lcYLAD7 z`cpQi;?jN#pL}UI(PU(Ig%)Zf_Fz)4fq`aK5ZDjo8R46c3($5p#8Chp^Dt*-s7_VH z6z;oy$uslVSK2r0!GzgFqXFq!|B)jG9O?MUc_;X)afKLjdb>VnCgP4}x_|agkv>H@ zUvzn@GAVy=;Jm7+=~$|N9Je-MpC2KsMl-#SX-vZqnV6x*urY>fur|e|~ zVUbj{s!xkeT&P#KHH;~^XIvC68|4WNnYsCXf6ZY|IMranh=cMJf#y1q6K5oP6GR@c=GTKF@65&$RcZripS2DytUK*m-I-hW80>AL!K z*!je$8ycngINMN0o@%Fdj`YX8p+YAXJsu63*qok5|F*-5noW&TOi^)4p<6Ax$|f^B z53rV&9~$24ppPOEc>W(?uyn%mf-GesG+gXhGH)WV?q(jGzZYe20QNCqYb8_BXDmhz z8nUx1`vcN*?jCHuMkNfbKDlY(Uow4gS$rrw(Mac5Oe^HFZC`Q1mQ4xv;YKca_sju7 zP*n71n-8{41D^OSI&FM%Tq4}pa4u2zzOO?wsgXSlCY+F(2r!=L&o0H1)=FN#X^@Mi6rp7_!?wUY$zBJ_xSjxj=<=nW;u3z zF#^~0^(IVa#{)=GKF8*Ig!WXcykg3K+H@(udu>i zbLshBF&UpVP(MAN%g)Am6?Q1?8+~K~9u{|6lt~;s6ptl+RchP4b7o5pjN7sdHKn>s 
zT@m<0<-bs2KL7uRF6IyDP?nQO0G9JJ=J@t4d3pIwnG2IgCbCX9iDo98t0~X~(3_tE zIGAYz)?;b^!&Cc?UyH45j2|DFOoh1~h|nkd4|X>d7NxSC-HaadxmlmGdW#r{qW?;idFb?V2sM`Ky^C^y!q91VvIk8b=ry; zMNLmu2q11;BK>yeKp1Rqi)jYKV^&&7VdMOnWd0ddU;t`hI%L}yv)Kmgm=<4TMzNXK z@45zVeG`PfV*Wx3{MWLF8U!MT-Au)zzN~#Yr22gfK<|D_nv3M1VVH0G!Lb!S%#?Sq zohDd&-To)%PK+K?t=kDw{DfX(n3{Ee{(NC)Y;PO8%K1e=V98@S7|ILIJ6w{Jle4y- z3cPfcBk|7P;Pkl?2`Gh9<>;qjdK~fX#^b^T)X?cz4FsjM=xH2iwZqh=y%M z?-&ToicP3=vXY^4qL7U{7`6{A0TDUxgkpc$p^fdBy8@lvfyD1u(ngGC(p7FOPl-1z z73*dNcj%#1K1(MY<{UFz`6{2M#;Mz_+!Lo8@tj=^bTVXsR(HiCkb@I6CH_d=HZt$q z;?63a?z3;S>QQ*Yw;1To-Gza&e9iUs^~nm2f3F|pgzIw1K$5gh$~)kPB>s(dF}Cun z9D0EOe=vF@clF=PsM4r}tH}`J*nV_;2~2GDzm`+328OsIk;vgl-pH~YcHIue{`H?R zygRfUwW@>Fh?+)5DMaluu|_5P!dLa>_gYUXiLm&M3J|NB&8EeXS{4Yq36ezrk`DzoOL{ie z?WFFVFUfnhQ@6+4yAEUTwXSuUsWvHfh0Q>iR4e;hoKA2pZ% zrm*mAN!~5e*4&3>9(IqnECyW`3+!|7oUE{ZHo*~SOHoO9sbn{$W&K70k2&7~*k(Ex6y?`DcVxNa>Jn2;5|U3VT3UTFN&r$@Hy{@!Je@%>2KBb<)u zf-7UYi%57#NXWIzB=lmnLezF8)&RfXUIN%a)Jvoe->I57{j*<@W!^T824kBY;ebUw zJy&qXGaKch*tiWS1U&QRH9(kbwtk~**sbs=JTgk<?@0T{M;_~5~Vjkda-HjK!fyFcaRQ)7JY?%r0(o@1;a$>XOXv3Gkuo^@C>{NukL z;7{-E;PXnF4wr=Qy4;UG`oqY4K8l{niCMYT$fR&9NWZFK|7m$?cXxN0jR3wqA$+M` zaP{ctouA0%JHNk7{-U%srBf%F4B7I62&zH<-Qb5BTkQ(8c>kOD^`8KIv%l`#`9lu7 zvy52;kolj0d>Alr`;*;CKzrEUJk}Tng8}Vzexml3pF%gELtD%ghRQrGU4MC%@l)7S zE~vj%!c}ElTi*Vm41lK%{(1UAwaMV_QVdj6O7zvgUr^c*BT83#4b%fxPyGN>-!~51 z+vQE?4H$9z;r;>p+5xm2FaYrZU!C$-tzL$AApgutEM+NLwLRi;!h8NImQVk=<2PkK zul@TjAZV9H08N1U{Z7O`U+^b(Nf3J&3n+r;bQ(;(QCd!}V_LxC^go@hfcIF~cC$mg zO*$5$x-S#(Xni+)bFh2cfDHZBTWGU9!imxW9^o{i$F@gFDc1v=j@vcv^2Wd}E-rz+Hg>?S zi#I}Z8{e5{T;HaI3*AI@=&}<3Z%G8_=H>=+2Rx@M5_~pSCeFisQ zcPIcD9{9}v&P@X@|DW*oAE*Bh==cBJ{%=VE|CyHmZ%#|w!8Fp1sa*~hD+9skPUx&0 zJWI#^G>k`~s~XVfb~mq?n3zl`U{@62yIbqPuKk`9BNva|h{48=Z3u&Zc{xL%$|ZnY z!|VAaH=811h6fh*TGxNLZppnL>{7S(TW16DIyMnVC{e{M4UhbE9crsmRAF)~7AtS1 z@1qkd8@;Iba^I9$tWFc*hYl=|0F3M*7x_ow7g`^M-}deQbRq(cmWQO$ob|0tf0Jx^ z*H-pk_h0jSc15ozcCI%{s%9P$6qRcStZOPP3gxXi^}};7C@lP8#fB;o$@>LG8rhX& za^S7bm=(vf#(`s?Y5IyToD=F!=}k{Nz(pRCq|cX zNI*r!?*Trk&(1n*FLXGn7WT_otDj2jLvhX7xw$z&t6U?nVSs;*nGnN&^(puhaE}St zJ%0E^oWcZYwWG}k#9WzdgJUox>(QD}pW$G-QQrB_nBSM*>5Bhwvya0EO5JmO06go} zw8Mz_RXP-!4MhC0W33*xfbVwEq|_Usy^GQQsA_vNkn+mh)4kgzPwy=4;XnwCdt}L} z4q#Aj94^qrZslX8=!UB|ypMHQ?lfBd82SW7Rc=3+zZ8>NwBK?tj6HbV6CWA2Glk%L zxL!WBnNYWp0Dw%W(cu@m0CoQ}mSeg&aL5`N)&6gA2ZPiyqE2PR=XZ-Y7@no44XhD; zx1zjp*UR)W785qMn00bLau|2E=uO}M4E zlSel+V>mlY{Rwo+eBQmD2B}^046)08Jiy@$E17Oobw% zfV6xs+HJteu6ZRr0Y3Dw0cDi`PaJpy;EVb(D~-27yrv(BuxECwgGKg3w+6FiVR{mD zk+3U~u6rKZ+|Me6@eil~ok4IGekMNZ@``x$&xlRe0RZ$kKT93^d_BK>RPPm?j5}u! 
zGDj`bd^f5uqZD=YW#DjB+UnLfxu<}z^?}K|B38KdPdO0k)qS?;Sc|rDa7>#QDn?>Z&Vo|^uP-c z5LK@6bW-G|5NY{mw0Hk60F>KgAE^IQx|L#h%!OIeSj@mg5t63cUVU{zq&5;lWfGR{ zR%@;SP*~|88mxmkJBi^7oBSNJ@@IP?4uG!54<80$|8gnP+nZRJe_X|VS3I-LwdcQz zmPEf(4U1gqvcG*>45W>S6$VO$HG}Sih2hY^buhQ_Gp3~``* zjQ8YsyKh_`7;NvJxv;h~-b+{<_N`7SNqni4(7j?$TreUm7Wy&Pzxr4TQSpGcl~J4| z46fGg%d9;;lntUjwhS|&=jP_?C9?8%g;Es?!yNqz!&_X%JO!xKcZkQ>vVTuTXY!lhB1KV zi`w0~C9S5fpUzaK6c!f7QB1V`fMB=8eB^2l00%p7stKy8vpXY$azp00hx+AgPAz_U zkn{A*)7*guSN&TOl7n{ro6F6YLb?5e>+I-Ac`>?LrF`^$ypM8^YmaTek=m^@kdZ*s zgLJOHs-`UnR`-)Ad7A=ly0A^~Gu(Sm4p7@$PLoxXAyvFD_qQ6c?jJ>axsesf0Bdr% zvaB=6Dl})S6!#@372s^iD3{~cYQ*MYl>~C9(UR}@)n!kZo@02}9b`@@yAke*?PywE zU^*TW1|YKRmOjo)Qt9y1hUi60t?zB%=&5W)15`5`+NRA1sF&3{gb&gHMbk$qWVlbA z-IFZ~r{XJXK9R9R`IAG4N-fa6GODd7)sFY=DJcq|v@K$!>8$?ij}9Xao^Xxc6Pwsn zX#B8%y+n#X-T1-xlHsH9m*%lL&&4j2*vN8msXl+a)@wAeQsUR+;IW(i{LZ+XmKL12 z$msCRg6$s47vE$<{myKt=P@aNh%V=o``dgy7}vKoYEcXIgFf0f7dQR2Z}&4mTAh{A zv+hSZ90M@6plsR)3r(#?H1P+~F*+rC7VHupfy`0FV%FCMQ#Aid?{Dsn?i;P{Vm|#c z>T{Y%y>?>stsT(qSQ+Z%njx`R<7*#*u<>fH`3sk#=ej+Ua|tcyE1o})9PN2%tRV$T@G~m1sVNlHTp?@3D^`xVqOz%rs4lo}CFr8DWJ!DfZIn>7m(U#H)Ndf4>%`L<;H$+^lT@fY-*IHS_+ zG72|bXn0i0;B@8u)(B&_`f)+4r{xPHu4rXrN>`K*Zc&wJyz3w~rq@zPW1vifZ6%Vj zX-pI5Ge6e9_2X2})=QjMx?b(iwauIRs-s5Sh;mvA!>lD^Lf@T zVYOxy^A19E*6uDS`MrYU1fo)(+L*emFXULO&p1^scS$}hxTf`?Z!$0TX=>^j*9erd z$@S>seuEM#oR*c-L2t{EnWdcAw)xz1czj>55o0UP1r^{VQ&Fj;O_yl@$r(#F;tnQj zchrns|34XHW-nfzJMc3b`l;aFy!7hR;HEooJqH=#HmT*!6uM^0my*GCS;9Q)IaHlI zxyg*ZtZvxwzB0kx7i&3uUhh?L#O0z4`Eo}B#b^sXlsow>#f!sEmIcv{%W9c)3r_2Y zu50S$Yl|{~#i~?I@)H%Jt4n=6d(L+e*C(pdwZbe)Op~XN&uo29nQ7q}H_*ro^`QbX z6%DuZw!Uix@1dz5L|h6DBU#9yYu%QHQ0ftA#CDBM9s<9pWwcfObwQ+6W2RBF$ZoaVDSQHrYKSx&d5RH~|`z-lVwLa&YXVz0Ua z-v`-}Ux~XFggkE-!10o6OCc-txm2jsa7C$Z*?x{Z!s1Mo@0PGSAH-(pqhv`F!eW z3#yA$BDl$oE30$x5%-}tEWRv34?FWrS~ji$MQ{T!%R_a2;E~N)nu0VAD1tZlnsXFK zu2u1>NfGZ2nPMu^h!)dc8|qXx&Y}1rayWeRb$x0KYpL=8JV35{$tmLvi-_wg&+#vw zm!w4f<5qU*pKwvsk-N+y_r}x7OXCE zeS{G@Ijo?G|2}NS8k`?`d+KKa%=q`moKHSaL(J`RZS)m@1LlZVUUZ}d^27q z=vGiC@7QNcXQc-veD7VBkj^+dKqg1XF+ZbM0W2pZ9ol!K7OXEp7soX}m8Uo7Pyp&@ z>Egy2G9$Y`=rVmFR|)ofSz;$wg)lj~lLi3zIp*DCcf)q7A4p*MjgfXQ{rSW*|IEW1 zgLJIQJ;Tbd>xVT0yZ<)hWL8Z%toM-_Xdjj`Qqkq)*j(aE{EU%$VqsxYjoC^ND;2X1KMveAfUp4>+HdIx8wrpyUyWZR~|4r)( zHPwS6GWv}?4^W~e(eHi`5-bjkYRH&{Dm`KTQ{u39rVCS-og8{hy9M(T0ak-uxooZX z#oOg$Jv-?LId!g=U5HW5D|IJ7=kRM$m3?m>*#7lEx)k zMP5&OPL)*`CTKSj;ncxj`V@6AoRoedbR;nKA6+BTpuXbfn%Ql?Ho{Cx{gBJnaC+wq z_g6N_&aUEkU5jK7Ghd^p0@kLRHVgGXoK~gz?&QfI;e0Sc9sBm13W1g0GRW^=zc8@U zan~1Xx1j6e-o88MJxeMQijNKd*D|yIk(lhaZK#vK65xBwN zl}-3u?&MS-96GdlXYUvPTjwRnNFihMq}pzv5-xzSf5;dyP^-J}eYtpMW9|xJd}aqUBU6kQ!9+8`!Ljz0lBQPbXcMe27wOAtdl3Uhacq3 zTmD?#_CkKJEMa`SuV0<{{xKX_G_9fQ`K`7HHg1uwbp!3m*fJYr_1jD@vT^)@UrMJV z+=d#Y-*6H>qKr2SKD)iPmId`$h&6unPo&MSu4MCd!@I720U%t#n(5LW5`y6t!68PS za!F13a_qQ1=~}m$$vA|0Nq8BquRqiZxlrYmULj|2G;iU<_nSh{RFJ3odB-pW(IxsiH%ByKgh!gxVCr$|9?ZbC5#;p9z_(ycVxh}D?ip`fB$$Nj&wS=yp5nxJ259(oFJPFKwYOr8as_LJPpl>_xzkfMi?&yPm>Am zEERP%+H;)Ipk%6RAK1x%NM?P$@(SAfMbDiOAZ&V=3s>LzTI_`_Hqy2nRbT(Lf-pJ7 z-5CsXVJzo85KF1mT`3G+$Mpxfx~eA?Rt&+%S}=F-DSu{cOso_=Nl;@qta=KF zcXoPqcsQOOqOcBfhi~K$-zn>(nNnTcudb257cAiNYP~A*ONV@y!ZH$s9Z+=6(c;o1 z0&_@g#T=1Sevb3-bL-V0Ebx81F}mbUS?D0s2?cNsXxE)C0HZj4R^jg9lg2t02$;Cz z){=oR5s=^JXZX%R&qqfCaFtT8z)SrxCYe8AcC#v^#?o2%#nmT zQRa$L5`XKsys68M85ERozp4tVbJrUy8{aw}Gk*vkda@zvR7l=JTb*5aV1ExPvkbzG zw&aHS^Y-#rq>y+P1{N`O`Swq>%qb?-r_Y1zj;pownmo(OdS8!Ez*BOaSUM&`>#>9T zaRENQ-sJj{{Ls-t`8Q!vXA44yv`P~m$pQ6)Mx7`VY`sV*t9nUrt(H@=Rs`+8q2W}keoe7hIVa(k*O2a(mCd+F` zj960%O_DmTPN=?rz^a@?nh0v|t*m}IRbGuB^pkP4plFoDQCL+u(?{sP+SFoJcq8{@ 
zS$HYMqTUa>kUDOT&#kc0%0yrCcN`rc+EjyPI9~Pdj_=XQOlEot*Y@=tBbytBu}G1( z3h`dT2K;cV4?>%;d?FmS?X3Shfl|``-WwQFI)dQO3gx{8)M|+XvZbl1gpbC+6u*VI zPLmA^cyw$m2@!7rY68$86Lb(8%a+hjUug?N$d$3T6z;DwB_iam{nZg_Pq$1J3U6*a zQF0zro8Iy<bT@ z)D$#m-$oDe(ltX(#dEvAWHh$wc|O_F>vL?s=up|Zv($xwV+Q-`A3~3P*M?J$)2@Fy zVS|x>(B`(9kUBn%2!M2)0yv=Vmf7LRp0oYjTKxV7kS-%m4DIR_{e@` zLWuvy@7-|jc|7nPA{DYhqZ`^rfK5@m;`9?9s(85@$esmv->h3sKyoo;LM^DlfmPI^6^fmdypyDz?@^ny_qyFt0&XMz-yIC&)wYB^F zz6mA4at#ujL0FrdOCiYRCljr3*2bT$-jWklkHuPYgm3xNZO5}u071n8KCflYQE!|7 z*m~ds;Ir$TUivC+YxiJ0qQ)+E1y#PVsNG)lA%6oOa&b}r2aD>h#NiOgjXU+BdA0wD zE;iSrvK6M{p&g+k87t$R#A6h?#&ej8Rft^_I&|M;z|4rbMvB^9kRpxJr$HMw4`~Ig zCw*{9E^u(I1vJrgCcTOP=;JSw(4bTU466BZgN%5^w>iZouD2X0$92avGXFQ#I}`9_dJw`6UZ{G$te zX1jOwc52>homBOWhPQVDob;c7KIP-EC;LcDG8jL?FvQ1fw z+#Wa-7J)>i;tV={y7A zI1vMLF+>nr31IUaG4rl)OX3Z~=m0iTrwC*lOyYE&-&7vEKP?Zg<2|B@+44yN99)8X zwvV*L`i^q#QQrVy0>OY|ZEHeHOG~3PhtB-M8!KSI`$>yQ^TmW^!zvLc znQW~e!eyj@A5qhu4=?p*z)GCd)LBvo1&6xQ^61Mj(nbKLUs6rrcXYLo>N}SXeo)s^S}a#KPxBjK zeiZKjMplVgzaEWm`J^_Y;-@!P0Phtkc5NdDLgb8)7v6*Oy>51_K%nR_JF@7{v%6L<^fmR?sPkd1uX0Y3@qaiRiX<}L(j zv;I{lB)Ie9jSD77;Xwvy<=*<4(hzWDi~~)k+sU{PNg1NSnL*&LA547p>Xo%@yi$}& zj{uM-qsVrqSjGOk8EAp^zmin0Eb&hj+tZkaG23mWjoKq&YFfrGvVN3$M4GZ%S`Oio z$i>C2&e*%@xaeZ>Opbuy&ofIqdqJW13|zR=8M{K(Pys0p8!nl}DC8uCi1z#2DCKO` tHgzr1J7CxURl|1V>;I<%@V=M99s3>u~fyncd0wzXFJ~>P7$n literal 2709 zcmd^>drZ?;6vqn+I^`jx3GGc{dTM%1DTT7JJM4<90g+8p{3POQWq0mv<--^q!zj4OI?2nV1dv9)X^ZlOB zIVV3NEYNfVbOQteF%1s#kAy%B2>QrbZ>&E-eFlA^UkuJf2Kqv1JzFO9o3+S;@PiOY zeeuTSWF!6llT$%Q&Ojh$Eh}NrfvHY`KsJ0C?0+!&EMj_ySCF8x=}?aMba#t~p1O8b z8@0uy?f`-vsYVtecvQyLKF8gCX??7qJ40|QZ0xRW8GM&3yH8sAIcHcZc72<$(<0KS zCV!J#8}s_Z!mz7lTcPe=jXO#d&8LutBmr8T-ZQt0cUyB}Y@xl}(6KX`ZMeyQ*}yR0 z=YXSf-TKdUkc+Jjj@Q&{)@~+TI529yDQI<<5g750i}JCi@d`9<7?C9_6|!-dX>n5n z@P?yZ5TON^)|t~1qRPE8QPV7pl*>R<g6SKCj05oH8pRVR7=46 zOyCKLRJJETMG!3$2wIo$*z+>m;a(U)cTR#80m@!j)|(Uw0n`GW477j-Ej$F2T{SHL zM@QNy+W`yIb0WoBezB|u4wQ^DSkcWAELW%L)c_?EON0x7#r9OcJTnHI{5UD%*|TR~ zZ!r@&@t!>K?dI{qy>;BV_aMl-OZj15s@lu3&1ty}EiCNpQDQoP=f;|=W?tuP1T0=J z$#FkSYK2-^w#ca;rr< zDO?(zhmh}lm{)p77z1kufO|3nP^rq+Ze^5WdRv7aj@uaG94V`NWCeY7}! z6>uGe!6Bx5n0mn#?3=SPL?L)P0erDt;;znGG_l7NL}-T*cJy_6TQkl_^Q5fkaTml% z;y=K)wG_kz3b*m=U;;ujY>*tv%WxeDNl+t9i@{hnayEL=#EcP8owFDVCdx&OZ`8IE zmYMZM)2uk~=u)C6G*_{@YV&b9wes7^LfCYUM9kGJ;1$-27>UTL#(ttFF-_d=HT>vj zR~2e`p4FV|S^A^;nW6@E?lppVWwEYMb+?LR_24cz+EmzZw1_tc$3Ro@Yvj`@v(Gi(Kuof;An zf))%94>Qo)crT54Be^?cX#?}~8Gw}g2uSbl>U#KgU|CaFzuaQF+IeY}X@8pnu0a-^ WqyFmJl&&8oA;AG*{xsjDpZ)^GP0?)t diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index cf851bbdf1..b84fdc1b99 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -5,28 +5,27 @@ "id": "8581f0e4", "metadata": {}, "source": [ - "# Speeding up the Hugging Face Gemma model generation with Cuda Graphs and THD attention with FP8 precision\n", + "# CUDA Graphs, THD Attention, and FP8 Weight Calibration\n", "\n", - "As it can be seen in the [tutorial for Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) or [tutorial for Gemma](./tutorial_accelerate_hf_gemma_with_te.ipynb), transformer models can be accelerated by using Transformer's Engine `TransformerLayer`. 
In this tutorial we want to present few more advanced features, namely\n", + "In tutorials such as [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) and [Gemma](./tutorial_accelerate_hf_gemma_with_te.ipynb), we've demonstrated how transformer models can be accelerated using the Transformer Engine's `TransformerLayer`. This tutorial introduces a few more advanced features:\n", "1. THD attention layout.\n", - "2. FP8 weight calibration - for doing inference in FP8 precisions for models, which were trained in higher precisions.\n", + "2. FP8 weight calibration - enabling inference in FP8 precision for models originally trained in higher precisions.\n", "3. CUDA Graphs API.\n", + "We will explore how these features enhance the performance of the Gemma model during generation tasks.\n", "\n", - "We will compare generation time at 3 benchmarks:\n", - "- long input sequences (max 256 tokens), short generation part (max 128 tokens),\n", - "- short input sequences (max 64 tokens), long generation (max 100 tokens),\n", + "#### Benchmarking\n", "\n", - "All benchmarks above run with batch size 64 and on the dataset \"timdettmers/openassistant-guanaco\".\n", + "We'll evaluate the generation time across three benchmarks:\n", + "- Long input sequences (up to 256 tokens) with short generation (up to 128 tokens),\n", + "- Short input sequences (up to 64 tokens) with long generation (up to 1000 tokens).\n", "\n", - "
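For reference, a wall-clock measurement of this kind only needs a few lines of PyTorch. The sketch below is illustrative rather than the exact helper used for the reported numbers; `model` and `inputs` are assumed to be the CUDA-resident model and the tokenized batch prepared elsewhere in this notebook.

```python
import time
import torch

def time_generation(model, inputs, max_new_tokens):
    """Return the generated tokens and the wall-clock time of one generate() call."""
    torch.cuda.synchronize()              # finish any pending GPU work first
    start = time.time()
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    torch.cuda.synchronize()              # wait until all generation kernels complete
    return outputs, time.time() - start
```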
\n", + "All benchmarks are conducted with a batch size of 64 using the dataset \"timdettmers/openassistant-guanaco\".\n", "\n", + "
\n", "Note\n", " \n", - "This tutorial aims to demonstrate features of TransformerEngine mentioned above on the example of generation. It's important to note though, that NVIDIA offers other library to use for inference - namely [TensorRT](https://developer.nvidia.com/tensorrt), which should be used in such cases.\n", - "\n", - "
\n", - "\n", - "\n" + "This tutorial focuses on showcasing the mentioned features of Transformer Engine in the context of generation. It's important to note, however, that NVIDIA provides another library, [TensorRT](https://developer.nvidia.com/tensorrt), which is optimized for inference tasks and should be considered for such use cases.\n", + "
" ] }, { @@ -52,26 +51,6 @@ " - This directory contains the images used in the following tutorial." ] }, - { - "cell_type": "markdown", - "id": "84bfbe6c", - "metadata": {}, - "source": [ - "## Table of contents" - ] - }, - { - "cell_type": "markdown", - "id": "f09c29e7", - "metadata": {}, - "source": [ - "1. [Baseline] Running Hugging Face generation with Gemma model\n", - "2. [Improvement 1] Speeding up generation by using Transformer Engine THD attention.\n", - "3. [Improvement 2] Running generation of the model trained in hign precision in FP8.\n", - "4. [Improvement 3] Speeding up generation with CudaGraphs.\n", - "5. Conclusions." - ] - }, { "cell_type": "markdown", "id": "e8dfabbf", @@ -102,14 +81,12 @@ "# Import necessary packages and methods\n", "from utils import *\n", "\n", - "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", - "\n", "# Init the model and accelerator wrapper\n", "model = init_baseline_model(hyperparams).cuda()\n", "model = model.to(torch.bfloat16)\n", @@ -120,25 +97,13 @@ "inputs['input_ids'] = inputs['input_ids'].cuda()\n", "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n", "\n", - "start_time = time.time()\n", - "\n", - "outputs = model.generate(\n", - " **inputs,\n", - " max_new_tokens=1000\n", - ")\n", - "\n", - "end_time = time.time()\n", - "duration = end_time - start_time\n", - "\n", - "print(duration)\n", - "\n", - "# Decode the output tensor to text\n", + "outputs = model.generate(**inputs, max_new_tokens=100)\n", "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", - "\n", - "# Display the generated text\n", "for text in generated_texts:\n", " print(text)\n", - " print(\"=\" * 100)" + " print(\"=\" * 100)\n", + "\n", + "benchmark_generation(model)" ] }, { @@ -146,7 +111,7 @@ "id": "b3698dc6", "metadata": {}, "source": [ - "We will put these times into the table for later comparison.\n", + "We put these times into the table for later comparison.\n", "\n", "| Models | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 | \n", "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", @@ -158,19 +123,43 @@ "id": "2bbf3d47", "metadata": {}, "source": [ - "## [Improvement 1] Speeding up generation by using Transformer Engine THD attention\n", + "## [Improvement 1] Speeding up generation by using Transformer Engine with THD attention\n", "\n", - "Similarly to the Gemma tutorial, we substitute `GemmaDecoderLayer` with `TransformerLayer` from Transformer Engine. Since initial sequences have different lengths, we have following choices:\n", - "1. Use padding from the beginning and then use standard attention with `\"bshd\"` or `\"sbhd\"` layout.\n", - "2. Do not pad from the beginning and use THD attention.\n", + "Similarly to the Gemma tutorial, we substitute `GemmaDecoderLayer` with `TransformerLayer` from Transformer Engine. \n", "\n", - "In this tutorial we will show the second option. We illustrate THD attention idea on the two pictures below.\n", + "Input sequences can have various lengths. 
The most common approach is to use the padding and attention masks in such situations. We will use a more straightforward method - using the THD attention layout with offsets. \n",
     "\n",
     "
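To make the layout concrete before the pictures below, here is a small PyTorch-only sketch (it does not call Transformer Engine) that packs a padded `[b, s, h, d]` batch into the `[t, h, d]` form and builds the cumulative sequence lengths; the shapes and variable names are invented for illustration.

```python
import torch

b, s, h, d = 3, 5, 2, 4                      # batch, max seq length, heads, head dim
lengths = torch.tensor([5, 2, 4])            # true length of every sequence
padded = torch.randn(b, s, h, d)             # padded batch in the [b, s, h, d] layout

# Keep only the valid tokens of each sequence and stack them into one [t, h, d] tensor.
packed = torch.cat([padded[i, :lengths[i]] for i in range(b)], dim=0)

# cu_seqlens holds cumulative sequence lengths: sequence i occupies
# rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor.
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        lengths.cumsum(0).to(torch.int32)])

print(packed.shape)    # torch.Size([11, 2, 4]), since t = 5 + 2 + 4
print(cu_seqlens)      # tensor([ 0,  5,  7, 11], dtype=torch.int32)
```

The same bookkeeping is what the offset tensors shown in the figures below express for the query, key and value layers.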
\n", - "\"Logo\n", - "\"Logo\n", + "\n", + "\n", + "Query layer \n", + "\"Logo\n", + "\n", + "\n", + "Key layer and value layer \n", + "\"Logo\n", + "\n", + "\n", + "cu_seqlens_q = [0, 1, 3, 7, 9, 12]
\n", + "cu_seqlens_kv = [0, 1, 3, 6, 8, 10]
\n", + "seq_offsets_q = [0, 5, 10, 15, 20, 25] * h * d
\n", + "seq_offsets_k = [0, 7, 14, 21, 28, 35] * h * d
\n", + "seq_offsets_v = [0, 7, 14, 21, 28, 35] * h * d
\n", "
\n", - "\n" + "\n", + "The class `transformer_engine.DotProductAttention` supports this format. One need to pass the following things as the arguments to the forward:\n", + "- `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` - which represents the offsets of the beginnings of the next sequences,\n", + "- `cu_seqlens_q`, `cu_seqlens_kv` - cumulative sum of the lengths of the sequences of query and values,\n", + "- `max_seqlen_q` - maximum sequence length in query layer,\n", + "- `max_seqlen_kv` - maximum sequence length in key-value layer.\n", + "\n", + "
\n", + "\n", + "Note\n", + "Currently, the THD attention for `TransformerLayer` is supported only for inference.\n", + "
\n", + "\n", + "Let's look how using TransformerEngine with THD attention impacts the speed of generation:" ] }, { @@ -192,7 +181,6 @@ "\n", "# Init the model and accelerator wrapper\n", "model = init_te_gemma_model(hyperparams).cuda()\n", - "#accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", "\n", "model = model.to(torch.bfloat16).cuda()\n", "\n", @@ -202,31 +190,17 @@ "inputs['input_ids'] = inputs['input_ids'].cuda()\n", "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n", "\n", - "import time\n", - "\n", - "# PoczÄ…tek pomiaru czasu\n", - "start_time = time.time()\n", - "\n", + "# Method .generate is overriden in the file te_gemma.py - look there for the implementation.\n", "outputs = model.generate(\n", " **inputs,\n", " max_new_tokens=40\n", ")\n", - "\n", - "# Koniec pomiaru czasu\n", - "end_time = time.time()\n", - "\n", - "# Obliczamy czas trwania operacji\n", - "duration = end_time - start_time\n", - "print(f\"Generation time: {duration} seconds\")\n", - "\n", - "\n", - "# Decode the output tensor to text\n", "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", - "\n", - "# Display the generated text\n", "for text in generated_texts:\n", " print(text)\n", - " print(\"=\" * 100)" + " print(\"=\" * 100)\n", + "\n", + "benchmark_generation(model)" ] }, { @@ -247,7 +221,7 @@ "id": "e6b171a0", "metadata": {}, "source": [ - "## [Improvement 2] Running generation of the model trained in high precision in FP8" + "## [Improvement 2] Running generation in FP8 of the model trained in higher precision " ] }, { @@ -255,13 +229,15 @@ "id": "1a80288b", "metadata": {}, "source": [ - "Now we want to run FP8 generation with Gemma model. But it's not that simple! Since model was trained in BF16 precision, the FP8 scaling factors are not computed. Running the model with such low precision without proper scaling will lead to serious numerical divergence, which will lead to wrong output.\n", + "We are now preparing to execute FP8 generation using the Gemma model. However, this process is not straightforward. Since the model was originally trained with BF16 precision, the FP8 scaling factors have not been computed. Operating the model at such low precision without the correct scaling could result in significant numerical errors, which in turn would produce incorrect results.\n", + "\n", + "We highly recommend familiarizing yourself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the necessity of scaling.\n", "\n", - "##### Weight calibration\n", + "##### Weight Calibration\n", "\n", - "The wieght calibration is solution of the problem mentioned above. We will run few forward iterations on BF16 precision within context `te.fp8_autocast(enabled=False, calibration=True)`. This means that the forward pass will be done in higher precision, but we will store `amax_history`, which will be used to compute FP8 scaling factors. \n", + "To address the issue outlined above, we will implement weight calibration. This involves running several forward iterations at BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. 
This setup allows the forward pass to operate at higher precision, while we simultaneously collect `amax_history` and other parameters related to the FP8 precision, which is essential for calculating the FP8 scaling factors.\n", "\n", - "In the code below, we initialize BF16 model, run few iterations of forward passes within mentioned context. Then, we save model - we will also use these weights in the next chapter. " + "The code below outlines the steps to initialize the BF16 model and conduct several forward iterations within the specified context. After these iterations, we save the model, and these weights will be utilized in subsequent chapters." ] }, { @@ -274,9 +250,6 @@ "# Import necessary packages and methods\n", "import transformer_engine.pytorch as te\n", "from utils import *\n", - "import accelerate\n", - "from transformer_engine.pytorch import fp8_model_init\n", - "from transformer_engine.common.recipe import Format, DelayedScaling\n", "import torch\n", "\n", "\n", @@ -284,8 +257,6 @@ "hyperparams.fuse_qkv_params = True\n", "model = init_te_gemma_model(hyperparams, fp8_model_init=False).cuda()\n", "model = model.to(torch.bfloat16)\n", - "\n", - "\n", "accelerator = Accelerator(\n", " log_with=\"wandb\",\n", " gradient_accumulation_steps=hyperparams.gradient_accumulation_steps,\n", @@ -308,7 +279,6 @@ " max_new_tokens=10\n", " )\n", " generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", - " print(generated_texts[0][:50])\n", "print(\"calibration_finished\")\n", "\n", "print(\"scale_fwd computation started\")\n", @@ -327,7 +297,6 @@ "model_fp8.load_state_dict(model.state_dict())\n", "print(\"Weights casted\")\n", "\n", - "\n", "print(\"Saving model...\")\n", "torch.save(model_fp8.state_dict(), 'model_fp8_state_dict.pth')\n", "print(\"Model saved!\")" @@ -338,6 +307,8 @@ "id": "b6dcd135", "metadata": {}, "source": [ + "#### Generation in FP8\n", + "\n", "Now we are ready to run FP8 inference." ] }, @@ -353,14 +324,9 @@ "#restart_jupyter_notebook()\n", "import transformer_engine.pytorch as te\n", "\n", - "import os\n", "from torch.cuda.amp import autocast\n", "\n", - "\n", - "# Import necessary packages and methods\n", "from utils import *\n", - "\n", - "from transformer_engine.pytorch import fp8_model_init\n", "from transformer_engine.common.recipe import Format, DelayedScaling\n", "\n", "\n", @@ -379,12 +345,6 @@ "inputs['input_ids'] = inputs['input_ids'].cuda()\n", "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n", "\n", - "import time\n", - "\n", - "\n", - "\n", - "start_time = time.time()\n", - "\n", "fp8_format = Format.HYBRID\n", "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n", "torch.manual_seed(1234)\n", @@ -399,15 +359,12 @@ " )\n", "\n", "\n", - "end_time = time.time()\n", - "duration = end_time - start_time\n", - "\n", "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", - "for text in generated_texts[:12]:\n", + "for text in generated_texts[:2]:\n", " print(\"-\" * 50)\n", " print(text)\n", "\n", - "print(f\"Duration = {duration}\")\n" + "benchmark_generation(model)" ] }, { @@ -429,7 +386,7 @@ "id": "21a89d9c", "metadata": {}, "source": [ - "## [Improvement 3] Speeding up generation with CudaGraphs" + "## [Improvement 3] Speeding up generation with CUDA Graphs" ] }, { @@ -437,11 +394,26 @@ "id": "e2d53e7b", "metadata": {}, "source": [ - "The inference code is run by CPU which starts GPU kernels. 
When GPU kernels are fast enough, it can result in overhead caused by the CPU. That's where Cuda Graphs come in. When some series of kernels starts is repeatable, then it can be recorded and then repeated without usage of the CPU. You can read more about the Cuda Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n",
+    "The speed of GPUs is increasing at a very fast pace. It turns out that sometimes the kernels' runtime is shorter than the time it takes the CPU to submit them. This can result in serious overhead, as we can see in the two pictures below.\n",
     "\n",
     "
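As a minimal illustration of the mechanism (kept separate from the Transformer Engine wrapper used later in this tutorial), the sketch below graphs a toy module with plain PyTorch's `torch.cuda.make_graphed_callables`; the module and shapes are invented for the example and a CUDA-capable GPU is assumed.

```python
import torch

# A tiny stand-in module; the tutorial itself graphs the per-token generation step.
module = torch.nn.Linear(256, 256).cuda()
sample_input = torch.randn(16, 256, device="cuda")

# Capture the module's forward pass into a CUDA graph (warmup iterations run internally).
# Every later call must use tensors of exactly the same shape and dtype.
graphed_module = torch.cuda.make_graphed_callables(module, (sample_input,))

for _ in range(5):
    x = torch.randn(16, 256, device="cuda")
    y = graphed_module(x)   # replays the recorded kernels instead of relaunching them one by one
```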
\n", + "\"Logo\n", + "
\n", + "Generation without CUDA Graphs\n", + "
\n", + "\n", + "\"Logo\n", + "
\n", + "Generation with CUDA Graphs\n", + "
\n", "\n", - "Pytorch supports Cuda Graphs with `torch.cuda` API. Neverthless, there are some requirements for sequence of tensor operations to be able of being captured and repeated correctly. Namely, all the operations need to be static - meaning that tensors should not \"move\" between iterations. Pytorch offers also simpler way of using cuda graphs - the method `torch.cuda.make_graphed_callables`. We can easily record every pytorch module.\n", + "CUDA Graphs were developed to address this issue. When certain kernels are executed repeatedly, this tool enables us to record and replay them without CPU involvement.\n", "\n", - "Transformer Engine from version 1.6 also `make_graphed_callables` API. In the following code I run generate method from `te_gemma.py`. This is the code responsible for making graphed part:\n", + "We recommend reading further about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n", + "\n", + "PyTorch supports CUDA Graphs through the `torch.cuda` API. However, there are specific requirements for a sequence of tensor operations to be captured and replayed correctly. Specifically, all operations must be static, meaning that tensors should not change locations between iterations.\n", + "\n", + "PyTorch also provides a simpler method for utilizing CUDA Graphs: the `torch.cuda.make_graphed_callables`. This allows easy recording of any PyTorch module. Starting from version 1.5, the Transformer Engine also supports the `make_graphed_callables` API. Below is the code that executes the generate method from `te_gemma.py`, which is responsible for creating the graphed part:\n", "\n", "```\n", "graphed_generator = TeGraphed(...)\n", @@ -455,15 +427,16 @@ " fp8_enabled=True, \n", " fp8_recipe=fp8_recipe, \n", " allow_unused_input=True,\n", - " num_warmup_iters=10\n", + " num_warmup_iters=3\n", " )\n", " \n", " for i in range(max_new_tokens):\n", " next_tokens = graphed_layers(*args) if use_cuda_graphs else graphed_generator(*args)\n", " output_tokens.append(next_tokens.clone())\n", "```\n", + "If you want to use CUDA Graphs with the Transformer Engine (TE), we recommend looking into the `TeGraphed` class. This class is similar to `TEGemmaDecoderLayer`, but it includes specific functionalities required to make CUDA Graphs work effectively.\n", "\n", - "Now, let's see how big the speedup is going to be." 
+ "Now, let's proceed to measure the speedup provided by CUDA Graphs:" ] }, { @@ -473,22 +446,13 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "\n", - "os.environ['CUDNN_LOGLEVEL_DBG'] = '3'\n", - "os.environ['CUDNN_LOGDEST_DBG'] = 'backlog.txt'\n", "#Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", "#restart_jupyter_notebook()\n", - "import transformer_engine.pytorch as te\n", "\n", + "import transformer_engine.pytorch as te\n", "from torch.cuda.amp import autocast\n", - "\n", - "\n", - "# Import necessary packages and methods\n", "from utils import *\n", - "\n", - "from transformer_engine.pytorch import fp8_model_init\n", "from transformer_engine.common.recipe import Format, DelayedScaling\n", "\n", "\n", @@ -507,10 +471,6 @@ "inputs['input_ids'] = inputs['input_ids'].cuda()\n", "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n", "\n", - "import time\n", - "\n", - "start_time = time.time()\n", - "\n", "fp8_format = Format.HYBRID\n", "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n", "torch.manual_seed(1234)\n", @@ -523,16 +483,12 @@ " max_new_tokens=10,\n", " use_cuda_graphs=True\n", " )\n", - "\n", - "end_time = time.time()\n", - "duration = end_time - start_time\n", - "\n", "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", "for text in generated_texts[:12]:\n", " print(\"-\" * 50)\n", " print(text)\n", "\n", - "print(f\"Duration = {duration}\")\n" + "benchmark_generation(model)\n" ] }, { @@ -563,12 +519,12 @@ "id": "7bb2452d", "metadata": {}, "source": [ - "In this tutorial we showed three features of Transformer Engine:\n", - "1. Support of THD attention layout,\n", - "2. FP8 weights calibration.\n", - "3. Support of Cuda Graphs.\n", + "In this tutorial, we've explored three features of the Transformer Engine:\n", + "1. Support for the THD attention layout,\n", + "2. FP8 weights calibration,\n", + "3. Integration with CUDA Graphs.\n", "\n", - "Each one of them can be used in different context, here we showed how to use them to obtain fast inference. We remind though, that this is not the fastest possible way of doing inference - for doing do we reccommend looking at the [TensorRT](https://developer.nvidia.com/tensorrt) library from NVIDIA." + "Each of these features can be applied in various contexts, and here we demonstrated their use for achieving fast inference. However, it's important to note that this isn't the fastest possible method for performing inference. For achieving optimal speed, we recommend exploring NVIDIA's [TensorRT](https://developer.nvidia.com/tensorrt) library." 
] } ], From 7259dc956aa92c9cdbe5d9be0682885a7ac709ac Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 8 May 2024 07:29:42 -0700 Subject: [PATCH 093/244] HF finetuing introcution Signed-off-by: Pawel Gadzinski --- .../tutorial_accelerate_hf_gemma_with_te.ipynb | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb index c6a236a366..e436593901 100644 --- a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb @@ -11,7 +11,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In the previous [tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) we have demonstrated how to accelerate HF Llama models using Transformer Engine. Now, we will make similar thing with Gemma model. " + "In the previous [tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), we demonstrated how to accelerate HF Llama models using the Transformer Engine library. We replaced `GemmaDecoderLayer` with `TransformerLayer` from the Transformer Engine, achieving a **25%** speedup. Furthermore, we conducted the finetuning in FP8 precision, which yielded an additional **39%** speedup from the baseline.\n", + "\n", + "Now, we will undertake a similar enhancement for the Google's [Gemma](https://blog.google/technology/developers/gemma-open-models/) model." ] }, { @@ -41,10 +43,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Thr Llama and the Gemma are very similar models - both are based on Transformer Decoder architecture. The differences between them are the following:\n", - "1. The Gemma uses RMSNorm with zero centered gamma parameter, and Llama uses stardard RMSNorm.\n", - "2. The Gemma uses different head dimension than embedding dimension, but in Llama this numbers are equal.\n", - "3. The Gemma uses GeGlu activation function, the Llama uses SwiGlu." + "Thr Llama and the Gemma are very similar models - both are based on Transformer Decoder architecture. The most important architectural differences between them are the following:\n", + "\n", + "\n", + "| Feature | Llama | Gemma |\n", + "|----------------------------------------------|------------------------------------|--------------------------------------------|\n", + "| **Norm Layer** | Standard RMSNorm
$ y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\varepsilon}} * \\gamma + \\beta $ | RMSNorm with zero centered gamma parameter
$ y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\varepsilon}} * (\\textcolor{red}{1 +} \\gamma) + \\beta $ |\n", + "| **Embedding Dimension/Head Dimension** | 4096/4096 | 3072/4096 |\n", + "| **Activation Function** | SwiGlu | GeGlu |\n" ] }, { From be68a5d4fae11594b83002dd82cd98c00e8562d7 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 8 May 2024 11:05:39 -0700 Subject: [PATCH 094/244] HF finetuing introcution Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 11 +- ...tutorial_accelerate_hf_gemma_with_te.ipynb | 22 +- .../te_gemma/tutorial_fp8_model_init.ipynb | 0 .../tutorial_generation_gemma_with_te.ipynb | 243 +++++++----------- 4 files changed, 112 insertions(+), 164 deletions(-) create mode 100644 docs/examples/te_gemma/tutorial_fp8_model_init.ipynb diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 3d96a97934..67522ef6d3 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -87,7 +87,7 @@ def forward(self, self_attn_mask_type=self_attn_mask_type ),) -class TeGraphed(torch.nn.Module): +class GemmaGenerator(torch.nn.Module): def __init__(self, model, lm_head, inference_params, dtype, generation_config): super().__init__() self.model = model @@ -106,7 +106,6 @@ def forward(self, hidden_states, unfinished_sequences): self_attn_mask_type='padding', attention_mask=None )[0]) - self.inference_params.seq_len.copy_(self.inference_params.seq_len + 1) @@ -286,7 +285,7 @@ def generate( ) - graphed_generator = TeGraphed( + generator = GemmaGenerator( lm_head=self.lm_head, model=self.model, inference_params=inference_params, @@ -300,8 +299,8 @@ def generate( if use_cuda_graphs: fp8_format = Format.HYBRID fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") - graphed_layers = te.pytorch.make_graphed_callables( - graphed_generator, + graphed_generator = te.pytorch.make_graphed_callables( + generator, args, fp8_enabled=True, fp8_recipe=fp8_recipe, @@ -314,7 +313,7 @@ def generate( inference_params.seq_len.copy_(lengths.to(torch.int32)) for i in range(max_new_tokens): - next_tokens = graphed_layers(*args) if use_cuda_graphs else graphed_generator(*args) + next_tokens = graphed_generator(*args) if use_cuda_graphs else generator(*args) output_tokens.append(next_tokens.clone()) result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) diff --git a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb index e436593901..3dca60e093 100644 --- a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb @@ -59,13 +59,13 @@ "source": [ "# [Baseline] Running HF `GemmaModel` (Precision: `BF16`)\n", "\n", - "Similarly to the Llama tutorial, we begin the experiments by running baseline training in BF16 precision.\n", + "Similarly to the Llama tutorial, we begin the experiments by running baseline Hugging Face Gemma model finetuning in BF16 precision.\n", "\n", "
\n", "\n", "Note\n", " \n", - "This tutorial loads and trains a Llama 2 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", + "This tutorial loads and trains a Gemma 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", "\n", "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n", "\n", @@ -89,8 +89,8 @@ "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", - "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "\n", @@ -120,7 +120,7 @@ "source": [ "# [Improvement 1] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", "\n", - "Now we substitute *GemmaDecoderLayer* with highly tuned *TransformerLayer*. Let's see how this will impact the speed of the mode." + "We replace *GemmaDecoderLayer* with the highly tuned *TransformerLayer*, similarly to our approach in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb). Let's observe the impact this change has on the model's speed." ] }, { @@ -140,8 +140,8 @@ "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", - "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "\n", @@ -192,8 +192,8 @@ "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", - "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/llama/weights\"\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", "hyperparams.mixed_precision = \"fp8\"\n", "\n", "\n", @@ -226,7 +226,7 @@ "source": [ "# Conclusion\n", "\n", - "We can see, that similar to the Llama model, using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `GemmaDecoderLayer` provides a speedup over Hugging Face's native Gemma implementation." + "As shown in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), using the `TransformerLayer` module from Transformer Engine to replace Hugging Face's `GemmaDecoderLayer` results in a speedup compared to Hugging Face's native Gemma implementation." ] }, { @@ -235,7 +235,7 @@ "source": [ "## See more\n", "\n", - "We also prepared [tutorial](./tutorial_generation_gemma_with_te.ipynb) covering CUDA graphs and THD attention which we use to speedup Gemma generation." + "We also prepared [tutorial](./tutorial_generation_gemma_with_te.ipynb) in which we will show how to speedup the Gemma model generation using Transformer Engine." ] } ], diff --git a/docs/examples/te_gemma/tutorial_fp8_model_init.ipynb b/docs/examples/te_gemma/tutorial_fp8_model_init.ipynb new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index b84fdc1b99..c1b93ae885 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -5,13 +5,44 @@ "id": "8581f0e4", "metadata": {}, "source": [ - "# CUDA Graphs, THD Attention, and FP8 Weight Calibration\n", + "# Accelerating Generation of the Hugging Face Gemma Model with Transformer Engine\n", "\n", - "In tutorials such as [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) and [Gemma](./tutorial_accelerate_hf_gemma_with_te.ipynb), we've demonstrated how transformer models can be accelerated using the Transformer Engine's `TransformerLayer`. This tutorial introduces a few more advanced features:\n", - "1. THD attention layout.\n", - "2. FP8 weight calibration - enabling inference in FP8 precision for models originally trained in higher precisions.\n", - "3. CUDA Graphs API.\n", - "We will explore how these features enhance the performance of the Gemma model during generation tasks.\n", + "Generative AI has made remarkable strides in recent years, with Large Language Models (LLMs) like ChatGPT at the forefront. These models have revolutionized how we interact with machine-generated content, providing capabilities that range from writing assistance to complex decision support. The core functionality of these models is the generation process, which involves predicting the next token in a sequence based on the preceding text. 
This task is critical for applications such as automated content creation, translation, and more, emphasizing the importance of efficient implementation.\n", + "\n", + "For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n", + "\n", + "In our previous tutorials on [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) and [Gemma](./tutorial_accelerate_hf_gemma_with_te.ipynb), we demonstrated how finetuning can be accelerated using the Transformer Engine's `TransformerLayer`. Building on this foundation, our current objective is to enhance the generation speed of the Gemma model.\n", + "\n", + "This tutorial will introduce and explain several advanced features of the Transformer Engine that contribute to this goal:\n", + "\n", + "##### 1. THD Attention Layout.\n", + "\n", + "Addressing the challenge of computing attention for sequences with varying lengths, a common method is to pad these sequences and apply an attention mask. The Transformer Engine, however, offers a more optimized approach—by specifying the lengths and offsets of the sequences, attention can be computed directly. Instead of passing the matrix and mask with the shape `[b, s, h, d]`, one can pass a matrix of the shape `[t, h, d]` along with tensors detailing sequence lengths and offsets to run the attention optimized for this case. This specific attention layout is referred to as the **THD layout**.\n", + "\n", + "
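For intuition, the cumulative lengths that describe such a packed `[t, h, d]` tensor can be derived directly from an ordinary padding mask. The short PyTorch sketch below is an illustration only and is not the exact code used later in this tutorial.

```python
import torch
import torch.nn.functional as F

# A padding mask for a batch of 3 sequences (1 = real token, 0 = padding).
attention_mask = torch.tensor([[1, 1, 1, 1, 0],
                               [1, 1, 0, 0, 0],
                               [1, 1, 1, 0, 0]])

lengths = attention_mask.sum(dim=1, dtype=torch.int32)   # tensor([4, 2, 3], dtype=torch.int32)

# Cumulative sequence lengths: sequence i spans rows cu_seqlens[i]:cu_seqlens[i + 1]
# of the packed tensor, and max_seqlen bounds the attention computation.
cu_seqlens = F.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))
max_seqlen = int(lengths.max())

print(cu_seqlens)   # tensor([0, 4, 6, 9], dtype=torch.int32)
print(max_seqlen)   # 4
```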
\n", + "\"\"
\n", + "Fig. 1. The sequences and the mask for standard attention layout - padding from the end.

\n", + "\"\"
\n", + "Fig. 2. The sequences and the mask for standard attention layout - padding from the beginning.

\n", + "\"\"
\n", + "Fig. 3. An attention with thd layer.

\n", + "
\n", + "\n", + "##### 2. FP8 Weight Calibration.\n", + "\n", + "Assuming that we have a model trained in FP32/BF16 precision and we wish to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, we can compute the FP8 saling parameters. This calibration allows the model to operate correctly in FP8 precision.\n", + "\n", + "We highly recommend familiarizing yourself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n", + "\n", + "##### 3. CUDA Graphs API.\n", + "\n", + "The speed of GPUs is increasing at a rapid pace. It turns out that sometimes the runtime of kernels is shorter than the time it takes for the CPU to submit them, which can lead to significant overhead. CUDA Graphs were developed to address this issue. When certain kernels are executed repeatedly, this tool allows us to record and replay them without CPU involvement. This becomes particularly useful in applications like text generation, where a `TransformerLayer` is run for every token that needs to be generated.\n", + "\n", + "We recommend reading further about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n", + "\n", + "PyTorch exposes graphs via a raw `torch.cuda.CUDAGraphclass` and two convenience wrappers, `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. More information about the cuda graphs in Pytorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).\n", + "\n", + "Transformer Engine supports cuda graphs from version 1.5.\n", "\n", "#### Benchmarking\n", "\n", @@ -64,7 +95,7 @@ "id": "59560bff", "metadata": {}, "source": [ - "Hugging Face Transformers library offers generation API. We will treat this as our baseline." + "HuggingFace Transformers library offers generation API. We will use HuggingFace generation for the Gemma model as our baseline." ] }, { @@ -83,26 +114,13 @@ "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", - "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", - "# Init the model and accelerator wrapper\n", - "model = init_baseline_model(hyperparams).cuda()\n", - "model = model.to(torch.bfloat16)\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)\n", - "inputs = tokenizer([\"Some random initial str \", \"Another string ... 
\"] * 32, return_tensors=\"pt\", padding=True)\n", - "\n", - "inputs['input_ids'] = inputs['input_ids'].cuda()\n", - "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n", - "\n", - "outputs = model.generate(**inputs, max_new_tokens=100)\n", - "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", - "for text in generated_texts:\n", - " print(text)\n", - " print(\"=\" * 100)\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", + "generate_sample_text(model)\n", "benchmark_generation(model)" ] }, @@ -174,32 +192,15 @@ "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n", - "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "hyperparams.fuse_qkv_params = False\n", "\n", "# Init the model and accelerator wrapper\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", - "\n", - "model = model.to(torch.bfloat16).cuda()\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)\n", - "inputs = tokenizer([\"I love when \", \"I \"] * 32, return_tensors=\"pt\", padding=True)\n", - "\n", - "inputs['input_ids'] = inputs['input_ids'].cuda()\n", - "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n", - "\n", - "# Method .generate is overriden in the file te_gemma.py - look there for the implementation.\n", - "outputs = model.generate(\n", - " **inputs,\n", - " max_new_tokens=40\n", - ")\n", - "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", - "for text in generated_texts:\n", - " print(text)\n", - " print(\"=\" * 100)\n", + "model = init_te_gemma_model(hyperparams).to(torch.bfloat16).cuda()\n", "\n", + "generate_sample_text(model)\n", "benchmark_generation(model)" ] }, @@ -322,49 +323,18 @@ "#Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", "#restart_jupyter_notebook()\n", - "import transformer_engine.pytorch as te\n", - "\n", - "from torch.cuda.amp import autocast\n", "\n", "from utils import *\n", - "from transformer_engine.common.recipe import Format, DelayedScaling\n", - "\n", "\n", "hyperparams.model_name = \"../../../../gemma-weights\"\n", "hyperparams.fuse_qkv_params = True\n", "model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format=\"thd\").cuda()\n", "\n", - "print(\"Loading model\")\n", - "model_state_dict = torch.load('model_fp8_state_dict.pth')\n", - "model.load_state_dict(model_state_dict)\n", - "print(\"Model loaded\")\n", - "\n", - "tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)\n", - "inputs = tokenizer([\"Some random initial str \", \"Another string ... 
\"] * 32, return_tensors=\"pt\", padding=True)\n", - "\n", - "inputs['input_ids'] = inputs['input_ids'].cuda()\n", - "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n", - "\n", - "fp8_format = Format.HYBRID\n", - "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n", - "torch.manual_seed(1234)\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", - " with autocast(dtype=torch.bfloat16, cache_enabled=False):\n", - " with torch.no_grad():\n", - " model.eval()\n", - " outputs = model.generate(\n", - " **inputs,\n", - " max_new_tokens=40,\n", - " use_cuda_graphs=False\n", - " )\n", - "\n", - "\n", - "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", - "for text in generated_texts[:2]:\n", - " print(\"-\" * 50)\n", - " print(text)\n", + "# Load weights of the model with the proper scaling factors.\n", + "model.load_state_dict(torch.load('model_fp8_state_dict.pth'))\n", "\n", - "benchmark_generation(model)" + "generate_sample_text(model, fp8=True)\n", + "benchmark_generation(model, fp8=True)" ] }, { @@ -394,49 +364,37 @@ "id": "e2d53e7b", "metadata": {}, "source": [ - "The speed of the GPU is increasing at very fast pace. It turns out that sometimes kernels runtime is shorter that time it takes CPU to submit them. It can result in serious overhead as we can see at the two pictures below.\n", - "\n", - "
\n", - "\"Logo\n", - "
\n", - "Generation without CUDA Graphs\n", - "
\n", - "\n", - "\"Logo\n", - "
\n", - "Generation with CUDA Graphs\n", - "
\n", - "\n", - "CUDA Graphs were developed to address this issue. When certain kernels are executed repeatedly, this tool enables us to record and replay them without CPU involvement.\n", - "\n", - "We recommend reading further about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n", - "\n", - "PyTorch supports CUDA Graphs through the `torch.cuda` API. However, there are specific requirements for a sequence of tensor operations to be captured and replayed correctly. Specifically, all operations must be static, meaning that tensors should not change locations between iterations.\n", - "\n", - "PyTorch also provides a simpler method for utilizing CUDA Graphs: the `torch.cuda.make_graphed_callables`. This allows easy recording of any PyTorch module. Starting from version 1.5, the Transformer Engine also supports the `make_graphed_callables` API. Below is the code that executes the generate method from `te_gemma.py`, which is responsible for creating the graphed part:\n", - "\n", + "TransformerEngine includes a function `transformer_engine.pytorch.make_graphed_callables`, which functions similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from `te_gemma.py`:\n", "```\n", - "graphed_generator = TeGraphed(...)\n", - "(...)\n", - " if use_cuda_graphs:\n", - " fp8_format = Format.HYBRID\n", - " fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n", - " graphed_layers = te.pytorch.make_graphed_callables(\n", - " graphed_generator, \n", + " generator = GemmaGenerator(\n", + " lm_head=self.lm_head,\n", + " model=self.model, \n", + " inference_params=inference_params, \n", + " generation_config=generation_config, \n", + " dtype=hidden_states.dtype,\n", + " )\n", + "\n", + " (...)\n", + " if use_cuda_graphs:\n", + " fp8_format = Format.HYBRID\n", + " fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n", + " graphed_generator = te.pytorch.make_graphed_callables(\n", + " generator, \n", " args, \n", " fp8_enabled=True, \n", " fp8_recipe=fp8_recipe, \n", " allow_unused_input=True,\n", - " num_warmup_iters=3\n", + " num_warmup_iters=10\n", " )\n", " \n", - " for i in range(max_new_tokens):\n", - " next_tokens = graphed_layers(*args) if use_cuda_graphs else graphed_generator(*args)\n", - " output_tokens.append(next_tokens.clone())\n", + " (...)\n", + "\n", + " for i in range(max_new_tokens):\n", + " next_tokens = graphed_generator(*args) if use_cuda_graphs else generator(*args)\n", + " output_tokens.append(next_tokens.clone())\n", "```\n", - "If you want to use CUDA Graphs with the Transformer Engine (TE), we recommend looking into the `TeGraphed` class. This class is similar to `TEGemmaDecoderLayer`, but it includes specific functionalities required to make CUDA Graphs work effectively.\n", "\n", - "Now, let's proceed to measure the speedup provided by CUDA Graphs:" + "Let us now proceed to evaluate the performance improvement offered by CUDA Graphs." 
] }, { @@ -450,45 +408,17 @@ "from utils import restart_jupyter_notebook\n", "#restart_jupyter_notebook()\n", "\n", - "import transformer_engine.pytorch as te\n", - "from torch.cuda.amp import autocast\n", "from utils import *\n", - "from transformer_engine.common.recipe import Format, DelayedScaling\n", - "\n", "\n", "hyperparams.model_name = \"../../../../gemma-weights\"\n", "hyperparams.fuse_qkv_params = True\n", "model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format=\"thd\").cuda()\n", "\n", - "print(\"Loading model\")\n", - "model_state_dict = torch.load('model_fp8_state_dict.pth')\n", - "model.load_state_dict(model_state_dict)\n", - "print(\"Model loaded\")\n", + "# Load weights of the model with the proper scaling factors.\n", + "model.load_state_dict(torch.load('model_fp8_state_dict.pth'))\n", "\n", - "tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)\n", - "inputs = tokenizer([\"Some random initial str \", \"Another string ... \"] * 32, return_tensors=\"pt\", padding=True)\n", - "\n", - "inputs['input_ids'] = inputs['input_ids'].cuda()\n", - "inputs['attention_mask'] = inputs['attention_mask'].cuda()\n", - "\n", - "fp8_format = Format.HYBRID\n", - "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n", - "torch.manual_seed(1234)\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", - " with autocast(dtype=torch.bfloat16, cache_enabled=False):\n", - " with torch.no_grad():\n", - " model.eval()\n", - " outputs = model.generate(\n", - " **inputs,\n", - " max_new_tokens=10,\n", - " use_cuda_graphs=True\n", - " )\n", - "generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", - "for text in generated_texts[:12]:\n", - " print(\"-\" * 50)\n", - " print(text)\n", - "\n", - "benchmark_generation(model)\n" + "generate_sample_text(model, fp8=True, use_cuda_graphs=True)\n", + "benchmark_generation(model, fp8=True, use_cuda_graphs=True)" ] }, { @@ -506,6 +436,25 @@ "| THD attention + FP8 + Cuda Graphs with TE | - | - | " ] }, + { + "cell_type": "markdown", + "id": "a2bd87e6", + "metadata": {}, + "source": [ + "We can also see how use of graphs reduced CPU overhead. Here are two screenshots from the profiler:\n", + "\n", + "
\n", + "\"Logo\n", + "
\n", + "Generation without CUDA Graphs\n", + "
\n", + "\n", + "\"Logo\n", + "
\n", + "Generation with CUDA Graphs\n", + "
" + ] + }, { "cell_type": "markdown", "id": "c6e87275", @@ -524,7 +473,7 @@ "2. FP8 weights calibration,\n", "3. Integration with CUDA Graphs.\n", "\n", - "Each of these features can be applied in various contexts, and here we demonstrated their use for achieving fast inference. However, it's important to note that this isn't the fastest possible method for performing inference. For achieving optimal speed, we recommend exploring NVIDIA's [TensorRT](https://developer.nvidia.com/tensorrt) library." + "Each of these features can be applied in various contexts, and here we demonstrated their use for achieving fast inference. It's important to note that the fastest possible inference speeds can be achieved using NVIDIA's inference-optimized [TensorRT](https://developer.nvidia.com/tensorrt) library." ] } ], From 1bfc9b7f27f46a8123e31ec8391f4a3e6520b26f Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 8 May 2024 11:09:09 -0700 Subject: [PATCH 095/244] HF finetuing introcution Signed-off-by: Pawel Gadzinski --- .../examples/te_gemma/tutorial_generation_gemma_with_te.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index c1b93ae885..fc3b840b61 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -151,11 +151,11 @@ "\n", "\n", "Query layer \n", - "\"Logo\n", + "\"\"\n", "\n", "\n", "Key layer and value layer \n", - "\"Logo\n", + "\"\"\n", "\n", "\n", "cu_seqlens_q = [0, 1, 3, 7, 9, 12]
\n", From 894c6456a873b3ba1bd2ff753dbf49c3ac7569ce Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 8 May 2024 13:14:28 -0700 Subject: [PATCH 096/244] Fused attn temporary fix Signed-off-by: Pawel Gadzinski --- transformer_engine/common/fused_attn/fused_attn.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 82bc8375e4..64b8b865d1 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -135,11 +135,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) - && (max_seqlen_q % 64 == 0) - && (max_seqlen_kv % 64 == 0) && ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || (cudnn_runtime_version >= 8907)) - && ((head_dim <= 128) && (head_dim % 8 == 0)) && ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || ((cudnn_runtime_version >= 8906) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS @@ -162,7 +159,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( && ((qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_THD) || (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) { - flag_arb = true; + flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { From e1e5fa8514226ceac910e8263417ddc746e7a53e Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 9 May 2024 10:05:56 -0700 Subject: [PATCH 097/244] Bug fix Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 20 ++++-- transformer_engine/pytorch/attention.py | 91 +++++++++++++++---------- 2 files changed, 69 insertions(+), 42 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 67522ef6d3..7b4a3baa6d 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -228,11 +228,10 @@ def _generate_context_phase( inference_params=inference_params )[0] - hidden_states = self.model.norm(hidden_states) logits = self.lm_head(hidden_states) logits = logits.float() - logits = logits[torch.arange(logits.size(0)), inference_params.seq_len - 1, :] + logits = logits[torch.arange(logits.size(0)), inference_params.incoming_seq_len - 1, :] next_tokens = torch.argmax(logits, dim=1) # Sequences, which are finished should contain padding - taken from huggingface transformers. @@ -240,11 +239,13 @@ def _generate_context_phase( output_tokens.append(next_tokens) unfinished_sequences = unfinished_sequences & ~(next_tokens == eos_token_id) + + hidden_states = self.model.embed_tokens(next_tokens).unsqueeze(1) for k, v in inference_params.key_value_memory_dict.items(): - key_layer = v[0].permute((1, 0, 2, 3)).contiguous().cuda() - value_layer = v[1].permute((1, 0, 2, 3)).contiguous().cuda() + key_layer = v[0].contiguous().cuda() + value_layer = v[1].contiguous().cuda() inference_params.key_value_memory_dict[k] = (key_layer, value_layer) return hidden_states, output_tokens @@ -271,7 +272,9 @@ def generate( # lengths is a tensor of shape [s] representing lengths of sequences. 
lengths = torch.sum(input_ids.ne(generation_config.pad_token_id), dim=-1).squeeze() - inference_params.seq_len = lengths.to(torch.int32).clone().cuda() + inference_params.seq_len = torch.zeros_like(lengths).to(torch.int32).clone().cuda() + inference_params.incoming_seq_len = lengths.to(torch.int32).clone().cuda() + inference_params.max_incoming_seq_len = input_ids.shape[1] TEGemmaForCausalLM._padding_to_beginning(input_ids, lengths) @@ -284,6 +287,12 @@ def generate( unfinished_sequences ) + + + inference_params.seq_len.copy_(inference_params.incoming_seq_len) + inference_params.incoming_seq_len.copy_(torch.ones_like(inference_params.incoming_seq_len)) + inference_params.max_incoming_seq_len = 1 + generator = GemmaGenerator( lm_head=self.lm_head, @@ -316,6 +325,7 @@ def generate( next_tokens = graphed_generator(*args) if use_cuda_graphs else generator(*args) output_tokens.append(next_tokens.clone()) + result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) return result diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index bfba0d5e29..74b2707485 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3266,41 +3266,44 @@ def forward( """ inference_params.seq_len - lengths of processed sequences """ - bs = query_layer.shape[0] + batch_size = query_layer.shape[0] + tex.attention_copy( inference_key_memory, inference_params.seq_len, inference_params.incoming_seq_len, key_layer, - inference_params.max_incoming_seqence_length, + inference_params.max_incoming_seq_len, inference_params.max_sequence_length, - bs, + batch_size, self.channels) tex.attention_copy( inference_value_memory, inference_params.seq_len, inference_params.incoming_seq_len, value_layer, - inference_params.max_incoming_seqence_length, + inference_params.max_incoming_seq_len, inference_params.max_sequence_length, - bs, + batch_size, self.channels) + - max_seqlen_q = inference_params.max_incoming_seqence_length + max_seqlen_q = inference_params.max_incoming_seq_len max_seqlen_kv = inference_params.max_sequence_length - cu_seqlens_q = self.alloc(bs + 1, dtype=torch.int32, device="cuda") - cu_seqlens_kv = self.alloc(bs + 1, dtype=torch.int32, device="cuda") - seq_offsets_q = self.alloc(bs + 1, dtype=torch.int32, device="cuda") - seq_offsets_k = self.alloc(bs + 1, dtype=torch.int32, device="cuda") - seq_offsets_v = self.alloc(bs + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_kv = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") + seq_offsets_q = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") + seq_offsets_k = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") + seq_offsets_v = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") - cu_seqlens_q.copy_(torch.arange(0, bs + 1, dtype=torch.int32, device="cuda")) - cu_seqlens_kv[1:].copy_(torch.cumsum(inference_params.seq_len + 1, dim=0)) + cu_seqlens_q[1:].copy_(torch.cumsum(inference_params.incoming_seq_len, dim=0)) + cu_seqlens_kv[1:].copy_(torch.cumsum(inference_params.seq_len + inference_params.incoming_seq_len, dim=0)) - seq_offsets_q.copy_(torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_q) - seq_offsets_k.copy_(torch.arange(0, bs + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv) + seq_offsets_q.copy_(torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * 
self.channels * max_seqlen_q) + seq_offsets_k.copy_(torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv) seq_offsets_v.copy_(seq_offsets_k) + query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16) key_layer = inference_key_memory.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]).to(torch.bfloat16) @@ -3601,7 +3604,6 @@ def forward( cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, is_first_microbatch=is_first_microbatch) - out = self.fused_attention( query_layer, key_layer, @@ -4173,12 +4175,20 @@ def forward( if self.layer_number not in inference_params.key_value_memory_dict: inf_max_seq_len = inference_params.max_sequence_length inf_max_batch_size = inference_params.max_batch_size - inference_key_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, hidden_states.dtype - ) - inference_value_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, hidden_states.dtype - ) + if self.qkv_format == "thd": + inference_key_memory = self._allocate_memory( + inf_max_batch_size, inf_max_seq_len, hidden_states.dtype + ) + inference_value_memory = self._allocate_memory( + inf_max_batch_size, inf_max_seq_len, hidden_states.dtype + ) + else: + inference_key_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size, hidden_states.dtype + ) + inference_value_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size, hidden_states.dtype + ) inference_params.key_value_memory_dict[self.layer_number] = ( inference_key_memory, inference_value_memory, @@ -4333,30 +4343,37 @@ def forward( rotary_pos_emb = ((rotary_pos_emb,) * 2) if self.qkv_format == "thd" and inference_params is not None: - b, d = query_layer.shape[0], query_layer.shape[-1] + key_layer = key_layer.contiguous() + query_layer = query_layer.contiguous() + batch_size, hidden_dim = query_layer.shape[0], query_layer.shape[-1] - q_pos_emb = self.alloc((b, 1, 1, d), torch.float32, "cuda") - k_pos_emb = self.alloc((b, 1, 1, d), torch.float32, "cuda") + q_pos_emb = self.alloc((batch_size, inference_params.max_incoming_seq_len, 1, hidden_dim), torch.float32, "cuda") + k_pos_emb = self.alloc((batch_size, inference_params.max_incoming_seq_len, 1, hidden_dim), torch.float32, "cuda") q_freq, k_freq = rotary_pos_emb - + tex.get_values( - q_freq, - inference_params.seq_len + 1, - inference_params.incoming_seq_len, - q_pos_emb, - d, - b + q_freq, # [max_pos_emb, s, 1, d] + inference_params.seq_len, # [b] + inference_params.incoming_seq_len, # [b] + q_pos_emb, # [b, 1, 1, d] + inference_params.max_incoming_seq_len, + batch_size, + hidden_dim ) tex.get_values( k_freq, - inference_params.seq_len + 1, + inference_params.seq_len, inference_params.incoming_seq_len, k_pos_emb, - d, - b + inference_params.max_incoming_seq_len, + batch_size, + hidden_dim ) - query_layer.copy_(apply_rotary_pos_emb(query_layer, q_pos_emb, "bshd", fused=True)) - key_layer.copy_(apply_rotary_pos_emb(key_layer, k_pos_emb, "bshd", fused=True)) + + for i in range(batch_size): + key_layer[i,].copy_(apply_rotary_pos_emb(key_layer[i,:].unsqueeze(0), k_pos_emb[i,:].unsqueeze(1), "bshd", fused=True)[0,:]) + query_layer[i,:].copy_(apply_rotary_pos_emb(query_layer[i,:].unsqueeze(0), q_pos_emb[i,:].unsqueeze(1), "bshd", fused=True)[0,:]) + else: q_pos_emb, k_pos_emb = rotary_pos_emb From 79af381cb71d5dacbaf3fd6e43a554d27568f716 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 9 May 2024 10:21:17 -0700 Subject: 
[PATCH 098/244] .h file ifx Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/csrc/extensions.h | 4 +- .../pytorch/csrc/extensions/attention.cu | 66 +++++++++++-------- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f49a68cd50..916908d3ef 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -182,8 +182,8 @@ std::vector fused_attn_bwd( at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); -void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int max_seq_len, int b, int s); -void get_values(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int d, int b); +void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len, torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b, int s); +void get_values(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len, torch::Tensor B, int max_incoming_seq_len, int b, int d); /*************************************************************************************************** * GEMM diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 5637166753..9be4fd3d35 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1610,58 +1610,68 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { } +// Kernel used to update KV chache when attention layout is "thd". extern "C" -__global__ void attn_copy(__nv_bfloat16* A, int* seq_len, __nv_bfloat16* B, int max_seq_len, int b, int s) { +__global__ void attention_copy_kernel( + __nv_bfloat16* cache_tensor, + int* seq_len, + int* incoming_seq_len, + __nv_bfloat16* hidden_tensor, + int max_incoming_seq_len, + int max_seq_len, + int b, + int s + ) { for(int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int per_block = s / blockDim.x; - int remainder = s % blockDim.x; - int copy_block_offset_begin = per_block * threadIdx.x + min(threadIdx.x, remainder); - + int to_copy = s * incoming_seq_len[batch_idx]; int offset = seq_len[batch_idx]; - __nv_bfloat16* begin_A_copy = A + max_seq_len * s * batch_idx + s * offset; - __nv_bfloat16* begin_B_copy = B + s * batch_idx; + __nv_bfloat16* begin_cache_copy = cache_tensor + max_seq_len * s * batch_idx + s * offset; + __nv_bfloat16* begin_hidden_copy = hidden_tensor + s * batch_idx * max_incoming_seq_len; - int limit = copy_block_offset_begin + per_block + (threadIdx.x < remainder ? 1 : 0); - - for(int i = copy_block_offset_begin; i < limit; i++) { - *(begin_A_copy + i) = *(begin_B_copy + i); + for(int i = threadIdx.x; i < to_copy; i += blockDim.x) { + *(begin_cache_copy + i) = *(begin_hidden_copy + i); } } } +// Kernel used in positional encoding application. 
extern "C" -__global__ void gv(float* src, int* seq_len, float* dst, int d, int b) { +__global__ void get_values_kernel( + float* src, + int* seq_len, + int* incoming_seq_len, + float* dst, + int max_incoming_seq_len, + int b, + int d + ) + { // src [s, 1, 1, d] // dst [b] for(int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int per_block = d / blockDim.x; - int remainder = d % blockDim.x; - int copy_block_offset_begin = per_block * threadIdx.x + min(threadIdx.x, remainder); - + int to_copy = d * incoming_seq_len[batch_idx]; int offset = seq_len[batch_idx]; float* begin_src_copy = src + d * offset; - float* begin_dst_copy = dst + d * batch_idx; + float* begin_dst_copy = dst + d * max_incoming_seq_len * batch_idx; - int limit = copy_block_offset_begin + per_block + (threadIdx.x < remainder ? 1 : 0); - - for(int i = copy_block_offset_begin; i < limit; i++) { + for(int i = threadIdx.x; i < to_copy; i += blockDim.x) { *(begin_dst_copy + i) = *(begin_src_copy + i); } } } - - -void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int max_seq_len, int b, int s) { - attn_copy<<<16, 32, 0, at::cuda::getCurrentCUDAStream()>>>(reinterpret_cast<__nv_bfloat16*>(A.data_ptr()), +void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len, torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b, int s) { + attention_copy_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(reinterpret_cast<__nv_bfloat16*>(A.data_ptr()), seq_len.data_ptr(), - reinterpret_cast<__nv_bfloat16*>(B.data_ptr()), max_seq_len, b, s); + incoming_seq_len.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(B.data_ptr()), max_incoming_seq_len, max_seq_len, b, s); } -void get_values(torch::Tensor A, torch::Tensor seq_len, torch::Tensor B, int d, int b) { - gv<<<16, 32, 0, at::cuda::getCurrentCUDAStream()>>>(A.data_ptr(), +void get_values(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len, torch::Tensor B, int max_incoming_seq_len, int b, int d) { + get_values_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(A.data_ptr(), seq_len.data_ptr(), - B.data_ptr(), d, b); + incoming_seq_len.data_ptr(), + B.data_ptr(), max_incoming_seq_len, b, d); } From ef70a25ccf31d64032998b58c747ccfa508cbe24 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 9 May 2024 10:22:22 -0700 Subject: [PATCH 099/244] generate_sample_text() add Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index 6ccce22f9a..a52e8daaa9 100644 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -194,3 +194,19 @@ def restart_jupyter_notebook(): import warnings warnings.simplefilter("ignore") torch.set_warn_always(False) + +def generate_sample_text(model): + tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) + inputs = tokenizer(["Some random initial str ", "Another string ... 
"] * 32, return_tensors="pt", padding=True) + + inputs['input_ids'] = inputs['input_ids'].cuda() + inputs['attention_mask'] = inputs['attention_mask'].cuda() + + outputs = model.generate(**inputs, max_new_tokens=100) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + for text in generated_texts: + print(text) + print("=" * 100) + +def benchmark_generation(model): + pass \ No newline at end of file From 53a50fb6967978464d31f73dfecbef47bb88ed87 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 9 May 2024 15:30:40 -0700 Subject: [PATCH 100/244] Removed files Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/eval_bf16.py | 42 ------------ docs/examples/te_gemma/eval_fp8.py | 64 ------------------- docs/examples/te_gemma/generate.py | 53 --------------- docs/examples/te_gemma/generate_baseline.py | 55 ---------------- .../examples/te_gemma/generate_cuda_graphs.py | 63 ------------------ docs/examples/te_gemma/generate_fp8.py | 63 ------------------ 6 files changed, 340 deletions(-) delete mode 100644 docs/examples/te_gemma/eval_bf16.py delete mode 100644 docs/examples/te_gemma/eval_fp8.py delete mode 100644 docs/examples/te_gemma/generate.py delete mode 100644 docs/examples/te_gemma/generate_baseline.py delete mode 100644 docs/examples/te_gemma/generate_cuda_graphs.py delete mode 100644 docs/examples/te_gemma/generate_fp8.py diff --git a/docs/examples/te_gemma/eval_bf16.py b/docs/examples/te_gemma/eval_bf16.py deleted file mode 100644 index bfeeb8fa45..0000000000 --- a/docs/examples/te_gemma/eval_bf16.py +++ /dev/null @@ -1,42 +0,0 @@ -from utils import * -import torch -from tqdm import tqdm # For progress bar - -# Default hyperparams, also defined in `utils.py` in class `Hyperparameters` -## !!! `model_name` attr must point to the location of the model weights !!! -## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ -hyperparams.model_name = "../../../../gemma-weights" # <== Add model weight location here e.g. 
"/path/to/downloaded/llama/weights" -hyperparams.fuse_qkv_params = True - -# Init the model and accelerator wrapper -model = init_te_gemma_model(hyperparams, fp8_model_init=False).cuda() - -dataset = load_dataset(hyperparams.dataset_name, split="train") -tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) -accelerator = Accelerator( - log_with="wandb", - gradient_accumulation_steps=hyperparams.gradient_accumulation_steps, - mixed_precision=hyperparams.mixed_precision, - kwargs_handlers=[FP8RecipeKwargs(backend="te")] - ) -train_dataloader = enumerate(get_dataloaders(accelerator, hyperparams)) - -model.eval() # Set the model to evaluation mode -total_correct = 0 -total_samples = 0 - -with torch.no_grad(): # No need to compute gradients during evaluation - for _, batch in tqdm(train_dataloader, desc="Evaluating"): - input_ids = batch["input_ids"].cuda() - - labels = input_ids[:, 1:].contiguous() - input_ids = input_ids[:, :-1].contiguous() - outputs = model(input_ids=input_ids, labels=labels, use_cache=False) - - predictions = torch.argmax(outputs.logits, dim=-1) - - total_correct += (predictions == labels).sum().item() - total_samples += labels.numel() - -accuracy = total_correct / total_samples -print(f"Accuraccy = {accuracy}") \ No newline at end of file diff --git a/docs/examples/te_gemma/eval_fp8.py b/docs/examples/te_gemma/eval_fp8.py deleted file mode 100644 index 99948c2be9..0000000000 --- a/docs/examples/te_gemma/eval_fp8.py +++ /dev/null @@ -1,64 +0,0 @@ -from utils import * -import torch -from tqdm import tqdm # For progress bar -import transformer_engine.pytorch as te - - -# Import necessary packages and methods -from utils import * -import accelerate - -from transformer_engine.pytorch import fp8_model_init -from transformer_engine.common.recipe import Format, DelayedScaling - -# Default hyperparams, also defined in `utils.py` in class `Hyperparameters` -## !!! `model_name` attr must point to the location of the model weights !!! 
-## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ - -hyperparams.model_name = "../../../../gemma-weights" -hyperparams.fuse_qkv_params = True -model = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda() - - -print("Loading model") -model_state_dict = torch.load('model_fp8_state_dict.pth') -model.load_state_dict(model_state_dict) -print("Model loaded") - - -dataset = load_dataset(hyperparams.dataset_name, split="train") -tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) - -accelerator = Accelerator( - log_with="wandb", - gradient_accumulation_steps=hyperparams.gradient_accumulation_steps, - mixed_precision=hyperparams.mixed_precision, - kwargs_handlers=[FP8RecipeKwargs(backend="te")] - ) -train_dataloader = enumerate(get_dataloaders(accelerator, hyperparams)) - - -model.eval() # Set the model to evaluation mode -total_correct = 0 -total_samples = 0 - -fp8_format = Format.HYBRID -fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") -with torch.no_grad(): # No need to compute gradients during evaluation - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - for _, batch in tqdm(train_dataloader, desc="Evaluating"): - input_ids = batch["input_ids"].cuda() - - labels = input_ids[:, 1:].contiguous() - input_ids = input_ids[:, :-1].contiguous() - outputs = model(input_ids=input_ids, labels=labels, use_cache=False) - - predictions = torch.argmax(outputs.logits, dim=-1) - - total_correct += (predictions == labels).sum().item() - total_samples += labels.numel() - -accuracy = total_correct / total_samples -print(f"Accuraccy = {accuracy}") - - diff --git a/docs/examples/te_gemma/generate.py b/docs/examples/te_gemma/generate.py deleted file mode 100644 index ae63777438..0000000000 --- a/docs/examples/te_gemma/generate.py +++ /dev/null @@ -1,53 +0,0 @@ -# Restart the notebook (to flush the GPU memory) -from utils import restart_jupyter_notebook -#restart_jupyter_notebook() - - -# Import necessary packages and methods -from utils import * -import accelerate - -# Default hyperparams, also defined in `utils.py` in class `Hyperparameters` -## !!! `model_name` attr must point to the location of the model weights !!! -## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ -hyperparams.model_name = "../../../../gemma-weights" # <== Add model weight location here e.g. 
"/path/to/downloaded/llama/weights" -hyperparams.mixed_precision = "bf16" -hyperparams.fuse_qkv_params = False - -# Init the model and accelerator wrapper -model = init_te_gemma_model(hyperparams).cuda() -#accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams) - -model = model.to(torch.bfloat16).cuda() - -tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) -inputs = tokenizer(["I love when ", "I "] * 32, return_tensors="pt", padding=True) - -inputs['input_ids'] = inputs['input_ids'].cuda() -inputs['attention_mask'] = inputs['attention_mask'].cuda() - -import time - -# Początek pomiaru czasu -start_time = time.time() - -outputs = model.generate( - **inputs, - max_new_tokens=40 -) - -# Koniec pomiaru czasu -end_time = time.time() - -# Obliczamy czas trwania operacji -duration = end_time - start_time -print(f"Generation time: {duration} seconds") - - -# Decode the output tensor to text -generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) - -# Display the generated text -for text in generated_texts: - print(text) - print("=" * 100) \ No newline at end of file diff --git a/docs/examples/te_gemma/generate_baseline.py b/docs/examples/te_gemma/generate_baseline.py deleted file mode 100644 index cb6fa86bf0..0000000000 --- a/docs/examples/te_gemma/generate_baseline.py +++ /dev/null @@ -1,55 +0,0 @@ -# Restart the notebook (to flush the GPU memory) -from utils import restart_jupyter_notebook -#restart_jupyter_notebook() - - -# Import necessary packages and methods -from utils import * -import torch - - -# Default hyperparams, also defined in `utils.py` in class `Hyperparameters` -## !!! `model_name` attr must point to the location of the model weights !!! -## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ -hyperparams.model_name = "../../../../gemma-weights" # <== Add model weight location here e.g. "/path/to/downloaded/llama/weights" -hyperparams.mixed_precision = "bf16" - - -# Init the model and accelerator wrapper -model = init_baseline_model(hyperparams).cuda() -model = model.to(torch.bfloat16) - -tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) -inputs = tokenizer(["Some random initial str ", "Another string ... 
"] * 32, return_tensors="pt", padding=True) - -inputs['input_ids'] = inputs['input_ids'].cuda() -inputs['attention_mask'] = inputs['attention_mask'].cuda() - - -# Początek pomiaru czasu -start_time = time.time() - -import pdb -pdb.set_trace() -outputs = model.generate( - **inputs, - max_new_tokens=1000 -) - -# Koniec pomiaru czasu -end_time = time.time() - -# Obliczamy czas trwania operacji -duration = end_time - start_time - - - -print(duration) - -# Decode the output tensor to text -generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) - -# Display the generated text -for text in generated_texts: - print(text) - print("=" * 100) \ No newline at end of file diff --git a/docs/examples/te_gemma/generate_cuda_graphs.py b/docs/examples/te_gemma/generate_cuda_graphs.py deleted file mode 100644 index 694dabfd91..0000000000 --- a/docs/examples/te_gemma/generate_cuda_graphs.py +++ /dev/null @@ -1,63 +0,0 @@ -import os - -os.environ['CUDNN_LOGLEVEL_DBG'] = '3' -os.environ['CUDNN_LOGDEST_DBG'] = 'backlog.txt' -#Restart the notebook (to flush the GPU memory) -from utils import restart_jupyter_notebook -#restart_jupyter_notebook() -import transformer_engine.pytorch as te - -from torch.cuda.amp import autocast - - -# Import necessary packages and methods -from utils import * - -from transformer_engine.pytorch import fp8_model_init -from transformer_engine.common.recipe import Format, DelayedScaling - - -hyperparams.model_name = "../../../../gemma-weights" -hyperparams.fuse_qkv_params = True -model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format="thd").cuda() - -print("Loading model") -model_state_dict = torch.load('model_fp8_state_dict.pth') -model.load_state_dict(model_state_dict) -print("Model loaded") - -tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) -inputs = tokenizer(["Some random initial str ", "Another string ... 
"] * 32, return_tensors="pt", padding=True) - -inputs['input_ids'] = inputs['input_ids'].cuda() -inputs['attention_mask'] = inputs['attention_mask'].cuda() - -import time - - - -start_time = time.time() - -fp8_format = Format.HYBRID -fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") -torch.manual_seed(1234) -with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - with autocast(dtype=torch.bfloat16, cache_enabled=False): - with torch.no_grad(): - model.eval() - outputs = model.generate( - **inputs, - max_new_tokens=1000, - use_cuda_graphs=True - ) - - -end_time = time.time() -duration = end_time - start_time - -generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) -for text in generated_texts[:12]: - print("-" * 50) - print(text) - -print(f"Duration = {duration}") diff --git a/docs/examples/te_gemma/generate_fp8.py b/docs/examples/te_gemma/generate_fp8.py deleted file mode 100644 index 3ff07adf18..0000000000 --- a/docs/examples/te_gemma/generate_fp8.py +++ /dev/null @@ -1,63 +0,0 @@ -import os - -os.environ['CUDNN_LOGLEVEL_DBG'] = '3' -os.environ['CUDNN_LOGDEST_DBG'] = 'backlog.txt' -#Restart the notebook (to flush the GPU memory) -from utils import restart_jupyter_notebook -#restart_jupyter_notebook() -import transformer_engine.pytorch as te - -from torch.cuda.amp import autocast - - -# Import necessary packages and methods -from utils import * - -from transformer_engine.pytorch import fp8_model_init -from transformer_engine.common.recipe import Format, DelayedScaling - - -hyperparams.model_name = "../../../../gemma-weights" -hyperparams.fuse_qkv_params = True -model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format="thd").cuda() - -print("Loading model") -model_state_dict = torch.load('model_fp8_state_dict.pth') -model.load_state_dict(model_state_dict) -print("Model loaded") - -tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) -inputs = tokenizer(["Some random initial str ", "Another string ... 
"] * 32, return_tensors="pt", padding=True) - -inputs['input_ids'] = inputs['input_ids'].cuda() -inputs['attention_mask'] = inputs['attention_mask'].cuda() - -import time - - - -start_time = time.time() - -fp8_format = Format.HYBRID -fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") -torch.manual_seed(1234) -with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): - with autocast(dtype=torch.bfloat16, cache_enabled=False): - with torch.no_grad(): - model.eval() - outputs = model.generate( - **inputs, - max_new_tokens=1000, - use_cuda_graphs=False - ) - - -end_time = time.time() -duration = end_time - start_time - -generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) -for text in generated_texts[:12]: - print("-" * 50) - print(text) - -print(f"Duration = {duration}") From 9dbbdd453fdc2ecea1b93d977ed9c867a60936a4 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 9 May 2024 15:32:09 -0700 Subject: [PATCH 101/244] Removed files Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/generate_convert.py | 61 ---------------------- 1 file changed, 61 deletions(-) delete mode 100644 docs/examples/te_gemma/generate_convert.py diff --git a/docs/examples/te_gemma/generate_convert.py b/docs/examples/te_gemma/generate_convert.py deleted file mode 100644 index 3bd9250b7d..0000000000 --- a/docs/examples/te_gemma/generate_convert.py +++ /dev/null @@ -1,61 +0,0 @@ -# Import necessary packages and methods -import transformer_engine.pytorch as te -from utils import * -import accelerate -from transformer_engine.pytorch import fp8_model_init -from transformer_engine.common.recipe import Format, DelayedScaling -import torch - - -hyperparams.model_name = "../../../../gemma-weights" -hyperparams.fuse_qkv_params = True -model = init_te_gemma_model(hyperparams, fp8_model_init=False).cuda() -model = model.to(torch.bfloat16) - - -accelerator = Accelerator( - log_with="wandb", - gradient_accumulation_steps=hyperparams.gradient_accumulation_steps, - mixed_precision=hyperparams.mixed_precision, - kwargs_handlers=[FP8RecipeKwargs(backend="te")] - ) -train_dataloader = get_dataloaders(accelerator, hyperparams) - -tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) - -print("Calibration started") -with te.fp8_autocast(enabled=False, calibrating=True): - model.train() - train_dataloader = enumerate(train_dataloader) - - for i in range(100): - step, batch = next(train_dataloader) - batch["input_ids"] = batch["input_ids"].cuda() - outputs = model.generate( - **batch, - max_new_tokens=10 - ) - generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) - print(generated_texts[0][:50]) -print("calibration_finished") - -print("scale_fwd computation started") -with te.fp8_autocast(enabled=True): - for i in range(10): - step, batch = next(train_dataloader) - batch["input_ids"] = batch["input_ids"].cuda() - outputs = model.generate( - **batch, - max_new_tokens=1 - ) -print("scale_fwd_computation ended") - -print("Casting weights...") -model_fp8 = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda() -model_fp8.load_state_dict(model.state_dict()) -print("Weights casted") - - -print("Saving model...") -torch.save(model_fp8.state_dict(), 'model_fp8_state_dict.pth') -print("Model saved!") \ No newline at end of file From 2e3bebda5348b5bab3d0eac5fffa2c5e7893a00b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 9 May 2024 15:32:42 -0700 Subject: [PATCH 102/244] Removed files Signed-off-by: Pawel Gadzinski --- 
docs/examples/te_gemma/tutorial_fp8_model_init.ipynb | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 docs/examples/te_gemma/tutorial_fp8_model_init.ipynb diff --git a/docs/examples/te_gemma/tutorial_fp8_model_init.ipynb b/docs/examples/te_gemma/tutorial_fp8_model_init.ipynb deleted file mode 100644 index e69de29bb2..0000000000 From b12416b9813b2629280698d3441e99f98417d10c Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 9 May 2024 15:34:37 -0700 Subject: [PATCH 103/244] whitespace fix Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_save_load.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index e29a986dd5..85ec7685b3 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -65,7 +65,6 @@ def __init__(self, precision, use_bias): self.inp_type = tex.DType.kFloat8E4M3 self.weights_type = tex.DType.kFloat8E4M3 self.outp_type = precision - def forward(self, inp, weight): inp_fp8 = cast_to_fp8( From 306b94b406f518612f22692f961a6911131d5a02 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 9 May 2024 15:40:37 -0700 Subject: [PATCH 104/244] Attention pictures Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/media/atn1.png | Bin 0 -> 4602 bytes docs/examples/te_gemma/media/atn2.png | Bin 0 -> 4561 bytes docs/examples/te_gemma/media/atn3.png | Bin 0 -> 2487 bytes 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 docs/examples/te_gemma/media/atn1.png create mode 100644 docs/examples/te_gemma/media/atn2.png create mode 100644 docs/examples/te_gemma/media/atn3.png diff --git a/docs/examples/te_gemma/media/atn1.png b/docs/examples/te_gemma/media/atn1.png new file mode 100644 index 0000000000000000000000000000000000000000..4c3f5e2fa5a2d56dfed8137e407cfd671405f9e1 GIT binary patch literal 4602 zcmeHL`CC(08m6TktbmL<9g3jTg3!v=2@*=!T0s$k#3>50t85}mG-1mE(Q%McL?9F} zgoII=WE_MD0b*F|0vNVbq9P^)*@*!HNf1JS%(*(nA>FB&;zwLM6AKvU`F1WmhJ?wjO8~72o#e4&v zO;V4XP7OrIrDmMP#OlPSpp#?mu+f;<*c9w}bgEpp(?dsR=XGE2uTExGE{FoIN9bjB zkFVn6Hk`ITvYxzj?64X6%ZuK&?|^m33%#aFa%03S}$!cfeA!(OyDbaD9y`I1N0yeApH|eYw93^!@Gkz)_tO-^7CU z)5kE-{^6VnXt&oJfHvo6Gyl`eXZoFM6ZtNeCcK$URz-!*(m!P1mmpa})%Bn~`}PgC zj%_X)7#JLK&k5#_Kgz{!ZKzfFb|$M4$xJk0cQ`|$fkEIR&R7&j#uzk^LC zEb^>7w^rDW2t>t$2i=Qu?hp?<3v+fZr>3-g7S^1UfJfHY+SyI7El={M%fetV*pRC7 zg^9Vj;N{EHb*+mCXJ?ISLn(;IU|?X-mDW0xjg5_oscB!lhmt~6Rhvg}C*kW<7N@#o zP!Ay;mEE`$HE69qN_OFMzN{CrvV0Q`@Aj*gC$ zPPZyaaU^Wy*|%k($1zs?3*=mA-5lSb5-w?SL7IvPU?-GPq!c&OxLUk6S774e21p<% zBqCyfmxslrEb-h_Mf%(*eSLlD3O*zxBmuvjeLIp}USf@k!{ZUA@XU|u%pKa64l`LS z`P%hMlO3*HyDh(H`;YbtwiF@=bW6XX@dvU+_nwFu`#DMMgctlK@#%I zJC+y*JK*UbG!8jrFXt0CWl@42?fN86B$B)UK2RzudLJ2iR^GaGi@1obN(5{N=-F*m zjK#gaq8H9#e6!>nq&lnz0$?*RbE^;_J)K%N~GYrogHq7h7a? 
zZL{ujh@&e#Ngg0>$GM_`u(3A}7kt1CI9w1kBs@ix;mSThLt49LEcSO~XE9-;{M>5C0KW2&!{Jb<6@Kzg z0v{7-6MEqwzur>cx8uc98NA}ykY9%a6~+JI+7;$-WA9_-ZVK@dOuJmAEe>QMp}d%x znZc;$o~E{r5*HHf#o2djYh%xy`=mY6p{AxLKrntPeeTCw0A7IBoZT6x=$R9Fl^$6L zBX`w_?6;gMXb2{g33GL&s?};U58{0&{~8;=%gf74dUCF6r~9 z2kGu22t!nApF)ALk7#LxFl3z>7iXn(RPysIu>rN`?=afTh&8_Z5;Ofw2M)BW6gkyV zOG#Ze^<~#R9qD{yD{AD>L1;k2+a!E@P(y-{*)I^VWtU9zrz2X^V|$zcjd758(n#WkqpmcRLhfsSfY%ABqf9cZ zM55dNSaX6jL#{+Fs%r=eASy>x6+dwow{vp=#2XSm0fK~5z`0SIZ-Q}45A?D!GYbH9 z663oV~d;>|lCC1Itpt%#-?QkzIFGy(u9s-IFp{uKF zVrJIAG8`qh3a&!Z-6;PeDObYVd1&*143CM8jUCm-^mzRpNUdMKdiXyv{kK)y)q;-o zP_p(oB9L}8zH!ihNOQy?#2=x3EnZomd_OE>C=2&;^VglqRdo4uuIJiWB>f8rDt*NDpe{ydtI6&zF!X+jX%t}d-v|t%lc>+?2JEz ztFveKgA|ozi@VR^K#7O$kPeJoO7Kyf477PAeKYaAlImU_5C)%KXVsyBeDg>8$4R>x zVui|1nkr73>v_(o2{8-v2S+(5${@A@Jj&ZzmkFh*m94>ZUgqm3gu63K7sfh0J2$2s zl=B<`+?|mIR2HYDr2@#r>o;ydzLu4jbFj06X0GYBKvY5S9v>fX17)8enLBD#ra1(p zludA5Ols;ONacY3;^JcHc#kL?0gA3Y&3(druAkyasjN(zV1%J@I2U7w$o}pE(^0$Z z_alkRUE%F`C+E@BKoGxp&>PH-958ei7lASsV)aP8lwoUY8v|k>;-Z18HikQP%mDSG zsTRA;poGhlJ%A;llDPutXb{$B^cWowBlK=ia4>dJ%meH>oyI0afUw2`0h^eZ$XjR@ zwMDHaq@>tGo5#blSP7B0fMaP3)!jdR4A?rU&JO@U|0O>V{Lc{IG8W`qLkfIu83Ipq Oe2*Z#>kgm!%Rd3Ym~jyR literal 0 HcmV?d00001 diff --git a/docs/examples/te_gemma/media/atn2.png b/docs/examples/te_gemma/media/atn2.png new file mode 100644 index 0000000000000000000000000000000000000000..7e9c471511285a917b8d0a8014ac84e7f7c5f7be GIT binary patch literal 4561 zcmdUzdsvcp8^>+!tu|M#9de~fE7H2GY)YQdbuhyeYcwLkX9dtL9JujAr+o*(?4=XbcD`~KYD z`?-Ds>A(KnE$>=bSgc1JJMx8v#cHmF#VV(DYrsE8JMuSx)7ykYh|}x9Cu3doHSl|L zvfr8HQ;F#0v~&0v3rs>HE@m&`0zM`tfe@RRtX$3Xv9Q>90deHu>C0sbLGU$H8l!iP zJwFs^RNF7tG^}4&^}=xt^u%`^FMbkydvL#H(*1+)KWMoY`SJPP$9CszTcv#6a`H{; zmydASm0_4i_h;Cfz5us z{j>F;!|(hSaJ=#VyCJqc*RQl1!G+=Wj1z;n(XORC5$O!nAUB!aBU0%d~opa&;G@V+tp9bdPfF1ub4HOYrNIIvp;ZwO8wx_xftiUY8N1!!b!o zpK)akxe_mhf-W9bMlEKh&I*gY5}1{WBu{B~A!_Kd_i&9#!QFS-ua0ROi`mLjt8PlgK1~dwR zz^>+$QB}yC?8vcf*3Y&Q^a?s1j+8E>GlZGi@n(}zPfs1M&gpH>HcUVWGcTWdYlVd3 zio`0Z7uMTO${DD7JX8l2@XM>OP#suj$)emtE%9YA$Dni*=eWUi#^rc@K8$A12)8B0 zOiWTKX_6dk)cah>wkSP+Of?u~4(q2N?CXXzY$E-37@uqOr@d) zO}MZ#XWnprigmZOM>HjhXWRCcu+Y4x7yR27s+=Sv{*ZA_V|Z)8jIi0o%GYpd!&cAV zM%Pi2mlkB+3w?)mVJ1vmdGeK$nj`sZ$y1H?VX>s5+a7l(1&H}erR6gMx)U(oMbbSaQ70*+;RH?Is=gh%Ne5pZFs6!+t&_ZZ zc7-_ev={&C-`_O#%F0MehTB>zyP?%q-gm}Z(||E+Y(4PME6=uFwSo*Am-S|oBg@?H zR{G~P)i?Fpt>^$*fib&cmo$=Lb8{F?8$#0pYq+<1CtQ#I@wh)aA3gfm)IEwTXVjK5 zw1h4C&;vo7E)lM0k6d5yAg(0D{~Q*VA+ph08kUWzN+IS&pgkEXC&enKU2ik`qtBae zplNEGy_u)oC`iCsN9{XLx)xL_d_h6MpZ>?gRNyXrH+(iCVyNnP;izvjMgW4bfM$2V zM2V(K!dj*jP%Hw4sAqsSx+r>ge?S?%^l4Ebh8xk0-w83kxhg3Cey?+JptJ3U`O)yi zn1tbfwZ2wK2Ad@X;VP-G$-~+S6;CQd5&>15DM*)3-V;_PbX;;D8W;#D&o=7f1x-%? 
z(IhN4v6*PKb?qgCg-&MAFN2d@L+em1FJ&7|g4LF6`oOq2H$gPa_0HG$F7F1R#h;dV z;Q&Ym27b8ZGW>)tEV;xc;nzSg+o$*T*1@m#_mP8r%v_l7`Y?F%*;>dD zkBf2NY z@$z>o{iJiu{TDDQ2DU{4lnfE% z?)pr<2IS*=-QC@Psh}&*el$-!RO}LlLJhyD4W(x&Le4KOspZX==gtF=5f7e$f#EcQ zM2sh8%1f+6nnNqc(z-r9^MH6UFA>UUs*x385h8@D0%5E*^E7`_0S(HZ-^=Fsbv_tV zzVpbc==V!Sr)rRr1sPYST@E?h4VK(9kyKDvm}yxX>sneBSJF1uV`FzDz9ZgsFgSaO zYgig%riwO2wj`UGQxjYqp&H$^FXQ=wk5=5D$Z6z*;~7 z8D?rxX~WB_u+<0@aw%#S5C)D5?%c$uVS;_;i`mftKD5#(t&qCp zpR}YYK}3cp1r6KM=y6C2WDm4I-j$YSx|Wn*t@zXvtr4*iA6!MT780}gttDYb<;#iH zajvB?o)n|WbVtiV%N>HafI+brm?hOo6^-u%(Xt2!?D7;*BRx-R@}}*{2rQPqi3=#^ zkurU{JZnkTbo7r0RaI4Ox~Wo--vC~QDPDzWA(Ra*#*#{I6L<7`A&}AVhdO=a;F2-& z`0IuzB${GX>>yB0!(|@ioSd8{f7!j8)=pFWMr;@mPN2i=>(kR0k5-;UQamLGyUeu* zsHj|Y-A-9|w|b=lPybWS)YjJO#+6xxV^P_Rzi`rwmjpNhb(8CVEc_2)GeHNcup+Th zLl0*x_~zv1Jsy)cdLy*)hTfM-{ranqc<}|*aU7*~E6LA>G$vQul1dxWlp@2cwZ45o zLkTsz;_voCvb{knPpofh8rAHgaFtqsVKl8+o^rBoY_Ox_(D$A9~SiE+zWpWE+YSS8eeT?|r*s5?K;W6jr1Ap(Ec**}M)PZrf5vNp{}O&C a%agg84kr%7KLWe3K=>h#P!E0e?SBAShj~N* literal 0 HcmV?d00001 diff --git a/docs/examples/te_gemma/media/atn3.png b/docs/examples/te_gemma/media/atn3.png new file mode 100644 index 0000000000000000000000000000000000000000..16d707719e9664245fb8772d419bbdb8cbba77be GIT binary patch literal 2487 zcmd^>e^k<07QpGaN89NKJ6oG$l+HG&ZJM+iDv45!rdZM>PNt|?`J=R86af*Fnd$Uv z8ZsqxQNbF+hEf@&R3v2DQe)J#Fwr4oRfG)DK~xkF_RH-4xo6ML-}^n^^X~h8-+lMJ zd(XX}_pX12iQe$R)(=1+&<1o&WIPDu(hKx)Yh8ghlJ1@dh~E(pq7&8vPx0D}YrwZh zLDY!?EFr6aaw8t+t;sWp>QKARevG!WnbSBA75WR z7b}uTB;OvYbDJo<`fMDgR+6pzrAulL8ja?$@#?uYCBwnl&Nwt}wgP>2Rn-F;W@?tx zBbUqBTy7H$!#h-W40r)s5cvg!uzr2*5pRFiOwA|0dm+S zl}ERfqWXVUbEuvvWeM zUk~cd@#}AFOAY<`Uy3seY&<#OD%i{E zRgomk_0m%pWBNbRkeIkQxZQFasuU}e)G0^Bs+S81LbWnYTG?o5ET|7-hX>-!HL;7< z)$@3KR!0y=G%5l-$l2?eNua%%;Il!Eb;_SHe!dy>5N2_e2AngN zk81ESx$8_5SeUkt9YhlFhqT(P2>bFR((=VJNB!kRG3>^+5kN@5THmo?Ll!hM8J zV^&(>_~i~o!u!8pUYBG~N8sopmaRmUyQ`I{ftB0vOGL8>mSsVH)U>1|=oY5s2)g}O=b6$~bVSGUv3Ui0%lkBJ zGd(uhF&JyqMa|PlrTmd8D4qET(;9DU9|z~Sq3sBn+IK}<)%F(b&OAdW`3oFIJ$3O- zqCO4F*DC}%+lq-g&8wX6h7g>EAxVRIl`NkcE7CVQOr(B0_##m0Ux5CXg8$8g_&=Ud zzI1^81F4Rc5*xS(NipqCXxM`GJ-3?em&B*Vbv5}$Nn`XmejbYO`l7C^vf+5Ym87k7 z&$g0RKP`2Woi!qp$((rVg-QF;DYu5{ZRIOK$~Cmn89A!+-tKx&tYjFLJoCAZ^kJlg&=#~shfPK)F(^EGx;OA2D-a?>8M&ir+ z?~x>|0>D?5Ol12pOF?PW+0?S{d!X*BCs?Ve!Q zOK;orCKEfGKHnT=)Z3>}%XMCwZY^VT<4J3@lo(97kTvAN6A$)3j}%I;sYgf9^tM`2 zm*@xv9`=2ChF*KzN9z+P=bW?Alyl}FaZq(nDT@0^^i5}m`uX`KO$+XYrtGyp-T&8s z7GVpM{|w7cs{w#H<$CErmYWbE`C=cDNOZE@Vkuf$S_(v=YT0Zy5U~deedd`npKvg~ zD!ly0E)mIkJi=-b9)V%9Pq&QMhKH@roS(fXH)sa$e(Ic`PPk=Y7>@K~re2I^(O+(f zDzhI#U3z=~K&fR@Z3fcQmnP#+W2u_5XlFnHnhQEDyHxl#W35nT7d>fj{~KkDs)L*) z-Z@d?STzg9Vyk1TIn^lht#z`7oy|RG+tB;^b$oDe@ci>lM>W5UKLZwiyQo(L!QrgQ zp3>~1j?uf#9UWD`5s9?U2$ND$8l0p|lT<0&2)cK#>KIZ4WaHJ9i5sxMkPyYP89oYr zVK5kK5JVqeOwIRz^6>C*OKWRIS67$6zrQnzI2Oik_JSueH;w8=sllcGr4p*?;3lPgpO4((AoCoZ@pan c2XxroMyRD1rR~U7z#<48g^6rBnEv%Y0soto*8l(j literal 0 HcmV?d00001 From b8f25fd5df8f145de1c63b32c8d955759bfaca64 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 10 May 2024 14:22:02 -0700 Subject: [PATCH 105/244] temp fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/_common.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index ab6455649c..32885d51e2 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -112,7 +112,8 @@ def forward( dim: int, *tensors: Tuple[torch.Tensor, ...], ) -> torch.Tensor: - + 
#print("rrr") + #import pdb; pdb.set_trace() # Check first tensor if not tensors: raise ValueError("Attempted to concatenate 0 tensors") @@ -154,24 +155,15 @@ def forward( strides = tensors[0].stride() data_ptr_stride = strides[dim] * tensors[0].element_size() data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * data_ptr_stride + import pdb + pdb.set_trace() for tensor in tensors[1:]: - if ( - tensor.dtype != dtype - or tensor.device != device - or tensor.stride() != strides - or tensor.data_ptr() != data_ptr - ): + if True: return torch.cat(tensors, dim=dim) data_ptr += tensor.size(dim) * data_ptr_stride - # No-op concatenation out = tensors[0].new() - out.set_( - tensors[0].untyped_storage(), - tensors[0].storage_offset(), - out_shape, - strides, - ) + out.set_(tensors[0].untyped_storage(),tensors[0].storage_offset(),out_shape,strides,) out.requires_grad = any(tensor.requires_grad for tensor in tensors) return out From 394f736368da32f67b17710f9de9574544f9a875 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 10 May 2024 14:28:16 -0700 Subject: [PATCH 106/244] temp fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/_common.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 32885d51e2..871e9d0bbf 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -154,9 +154,7 @@ def forward( device = tensors[0].device strides = tensors[0].stride() data_ptr_stride = strides[dim] * tensors[0].element_size() - data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * data_ptr_stride - import pdb - pdb.set_trace() + data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * for tensor in tensors[1:]: if True: return torch.cat(tensors, dim=dim) From eb689ce08c526d921c5fca286a82bb95835bb09d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 10 May 2024 14:29:05 -0700 Subject: [PATCH 107/244] temp fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 871e9d0bbf..0037f84315 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -154,7 +154,7 @@ def forward( device = tensors[0].device strides = tensors[0].stride() data_ptr_stride = strides[dim] * tensors[0].element_size() - data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * + data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * data_ptr_stride for tensor in tensors[1:]: if True: return torch.cat(tensors, dim=dim) From 036ed5a4658a5f0c7321e85de135002d4f6d5719 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 10 May 2024 14:38:48 -0700 Subject: [PATCH 108/244] zero centered gamma Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 7b4a3baa6d..a05e256e79 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -62,7 +62,8 @@ def __init__(self, config, layer_idx, *args, **kwargs): attn_input_format=config.qkv_format, num_gqa_groups=config.num_key_value_heads, attention_hidden_size=4096, - layer_number=(layer_idx+1) + layer_number=(layer_idx+1), + zero_centered_gamma=True ) 
te_rope = RotaryPositionEmbedding(256) self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() @@ -287,8 +288,6 @@ def generate( unfinished_sequences ) - - inference_params.seq_len.copy_(inference_params.incoming_seq_len) inference_params.incoming_seq_len.copy_(torch.ones_like(inference_params.incoming_seq_len)) inference_params.max_incoming_seq_len = 1 @@ -346,8 +345,7 @@ def replace_params(hf_state_dict, te_state_dict, config, fp8_init=False): # When loading weights into models with less number of layers, skip the # copy if the corresponding layer doesn't exist in HF model if layer_prefix + 'input_layernorm.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].copy_(1 + hf_state_dict[layer_prefix + 'input_layernorm.weight']) - + te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].copy_(hf_state_dict[layer_prefix + 'input_layernorm.weight']) if fp8_init: dst = te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.weight'] @@ -380,7 +378,7 @@ def replace_params(hf_state_dict, te_state_dict, config, fp8_init=False): if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.k_proj.weight']) - + if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.v_proj.weight']) @@ -389,7 +387,7 @@ def replace_params(hf_state_dict, te_state_dict, config, fp8_init=False): te_state_dict[layer_prefix + 'self_attention.proj.weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.o_proj.weight']) if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].copy_(1 + hf_state_dict[layer_prefix + 'post_attention_layernorm.weight']) + te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].copy_(hf_state_dict[layer_prefix + 'post_attention_layernorm.weight']) if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:GATE_PROJ_SIZE].copy_(hf_state_dict[layer_prefix + 'mlp.gate_proj.weight']) From 9c7880cd3fc092509b4a9ad6079da900df47f22a Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 10 May 2024 15:05:10 -0700 Subject: [PATCH 109/244] refactor of replace_params() Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 91 ++++++++++-------------------- 1 file changed, 29 insertions(+), 62 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index a05e256e79..f67c96100e 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -19,9 +19,8 @@ from transformer_engine.common.recipe import Format, DelayedScaling import transformers -from transformers.models.gemma.modeling_gemma import GemmaModel, GemmaForCausalLM, GemmaRMSNorm, GemmaConfig +from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model -from transformers.utils import WEIGHTS_INDEX_NAME from transformers.utils.hub import get_checkpoint_shard_files @contextmanager @@ -266,7 +265,7 @@ def generate( unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) # inference_params object is a cache, where keys and values 
of previous tokens are stored - inference_params = te.pytorch.InferenceParams( + inference_params = InferenceParams( max_batch_size=batch_size, max_sequence_length=input_ids.shape[1] + max_new_tokens ) @@ -324,7 +323,6 @@ def generate( next_tokens = graphed_generator(*args) if use_cuda_graphs else generator(*args) output_tokens.append(next_tokens.clone()) - result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) return result @@ -337,67 +335,36 @@ def replace_params(hf_state_dict, te_state_dict, config, fp8_init=False): m = re.match(layer_prefix_pat, param_key) if m is not None: all_layer_prefixes.add(m.group()) - - GATE_PROJ_SIZE=24576 - - for layer_prefix in all_layer_prefixes: - # When loading weights into models with less number of layers, skip the - # copy if the corresponding layer doesn't exist in HF model - if layer_prefix + 'input_layernorm.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].copy_(hf_state_dict[layer_prefix + 'input_layernorm.weight']) + def copy_from_ht_to_te(te_name, hf_name, start=None, end=None): + # When loading weights into models with less number of layers, skip the + # copy if the corresponding layer doesn't exist in HF model + if layer_prefix + hf_name in hf_state_dict: + te_state_dict[layer_prefix + te_name].data[start:end].copy_(hf_state_dict[layer_prefix + hf_name]) + + copy_from_ht_to_te('self_attention.layernorm_qkv.layer_norm_weight', 'input_layernorm.weight') + copy_from_ht_to_te('self_attention.proj.weight', 'self_attn.o_proj.weight') + copy_from_ht_to_te('layernorm_mlp.layer_norm_weight', 'post_attention_layernorm.weight') + copy_from_ht_to_te('layernorm_mlp.fc2_weight', 'mlp.down_proj.weight') + copy_from_ht_to_te('layernorm_mlp.fc1_weight', 'mlp.gate_proj.weight', end=config.intermediate_size) + copy_from_ht_to_te('layernorm_mlp.fc1_weight', 'mlp.up_proj.weight', start=config.intermediate_size) + if fp8_init: dst = te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.weight'] - - if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict: - q = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'] - for head_nr in range(config.num_attention_heads): - dst_offset = head_nr * config.head_dim * 3 - # copy query - dst[dst_offset:(dst_offset + config.head_dim), :] = \ - q[(head_nr * config.head_dim):(head_nr * config.head_dim + config.head_dim), :] - - if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict: - k = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'] - for head_nr in range(config.num_attention_heads): - dst_offset = head_nr * config.head_dim * 3 - # copy query - dst[( dst_offset + config.head_dim):(dst_offset + 2 * config.head_dim), :] = \ - k[(head_nr * config.head_dim):(head_nr * config.head_dim + config.head_dim), :] - - if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict: - v = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'] - for head_nr in range(config.num_attention_heads): - dst_offset = head_nr * config.head_dim * 3 - dst[(dst_offset + 2 * config.head_dim):(dst_offset + 3 * config.head_dim), :] = \ - v[(head_nr * config.head_dim):(head_nr * config.head_dim + config.head_dim), :] + def copy_interleave(hf_name, x): + if layer_prefix + hf_name in hf_state_dict: + q = hf_state_dict[layer_prefix + hf_name] + for head_nr in range(config.num_attention_heads): + dst_offset = head_nr * config.head_dim * 3 + # copy query + dst[( dst_offset + x * config.head_dim):(dst_offset + (x + 1) * config.head_dim), :] = \ + 
q[(head_nr * config.head_dim):(head_nr * config.head_dim + config.head_dim), :] + copy_interleave('self_attn.q_proj.weight', 0) + copy_interleave('self_attn.k_proj.weight', 1) + copy_interleave('self_attn.v_proj.weight', 2) else: - - if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.q_proj.weight']) - - if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.k_proj.weight']) - - - if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.v_proj.weight']) - - if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'self_attention.proj.weight'].copy_(hf_state_dict[layer_prefix + 'self_attn.o_proj.weight']) - - if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].copy_(hf_state_dict[layer_prefix + 'post_attention_layernorm.weight']) - - if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:GATE_PROJ_SIZE].copy_(hf_state_dict[layer_prefix + 'mlp.gate_proj.weight']) - - if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[GATE_PROJ_SIZE:].copy_(hf_state_dict[layer_prefix + 'mlp.up_proj.weight']) - - if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].copy_(hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:]) - - + copy_from_ht_to_te('self_attention.layernorm_qkv.query_weight', 'self_attn.q_proj.weight') + copy_from_ht_to_te('self_attention.layernorm_qkv.key_weight', 'self_attn.k_proj.weight') + copy_from_ht_to_te('self_attention.layernorm_qkv.value_weight', 'self_attn.v_proj.weight') return all_layer_prefixes \ No newline at end of file From b05cfa62a83141948c8ecfa22f04952722ae8fdb Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 10 May 2024 15:36:04 -0700 Subject: [PATCH 110/244] refactor of replace_params() Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 66 ++++++++++++++++++------------ 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index f67c96100e..41ded00405 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -150,7 +150,6 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, fp8 config.qkv_format = qkv_format with fp8_model_init(fp8_init): vanilla_model = cls(config) - is_local = os.path.isdir(pretrained_model_name_or_path) subfolder = "" variant = None if os.path.isfile( @@ -162,7 +161,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, fp8 ) is_sharded = True - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + resolved_archive_file, _ = get_checkpoint_shard_files( pretrained_model_name_or_path, archive_file, ) @@ -172,17 +171,16 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, fp8 assert not isinstance(resolved_archive_file, list) resolved_archive_file = [resolved_archive_file] + total_dict = {} for shard_file in 
resolved_archive_file: state_dict = load_state_dict(shard_file) - replace_params(state_dict, vanilla_model.state_dict(), config, fp8_init=config.fuse_qkv_params) - _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") - - # Force mem release. Taken from huggingface code - del state_dict - gc.collect() - - + total_dict = total_dict | state_dict + replace_params(total_dict, vanilla_model.state_dict(), config, qkv_fused_and_interleaved=config.fuse_qkv_params) + _load_state_dict_into_model(vanilla_model, total_dict, start_prefix="") # Copy parameters like embedding. + # Force mem release. Taken from huggingface code + del total_dict + gc.collect() return vanilla_model @staticmethod @@ -326,21 +324,29 @@ def generate( result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) return result - -def replace_params(hf_state_dict, te_state_dict, config, fp8_init=False): - # collect all layer prefixes to update +def _get_all_layer_prefixes_to_update(hf_state_dict): + """ + There are many parameters in hf_state_dict, whose name start with model.layers.[number]. + This function extracts all strings like "model.layers.[number]." that are starting strings of keys in hf_state_dict. + """ all_layer_prefixes = set() for param_key in hf_state_dict.keys(): layer_prefix_pat = 'model.layers.\d+.' m = re.match(layer_prefix_pat, param_key) if m is not None: all_layer_prefixes.add(m.group()) + return all_layer_prefixes + +def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False): + """ + Replaces params from TE TransformerLayer state_dict with corresponding parameters + from HuggingFace GemmaModel state_dict. + """ + all_layer_prefixes : List[str] = _get_all_layer_prefixes_to_update(hf_state_dict) + for layer_prefix in all_layer_prefixes: def copy_from_ht_to_te(te_name, hf_name, start=None, end=None): - # When loading weights into models with less number of layers, skip the - # copy if the corresponding layer doesn't exist in HF model - if layer_prefix + hf_name in hf_state_dict: - te_state_dict[layer_prefix + te_name].data[start:end].copy_(hf_state_dict[layer_prefix + hf_name]) + te_state_dict[layer_prefix + te_name].data[start:end].copy_(hf_state_dict[layer_prefix + hf_name]) copy_from_ht_to_te('self_attention.layernorm_qkv.layer_norm_weight', 'input_layernorm.weight') copy_from_ht_to_te('self_attention.proj.weight', 'self_attn.o_proj.weight') @@ -349,16 +355,22 @@ def copy_from_ht_to_te(te_name, hf_name, start=None, end=None): copy_from_ht_to_te('layernorm_mlp.fc1_weight', 'mlp.gate_proj.weight', end=config.intermediate_size) copy_from_ht_to_te('layernorm_mlp.fc1_weight', 'mlp.up_proj.weight', start=config.intermediate_size) - if fp8_init: - dst = te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.weight'] - def copy_interleave(hf_name, x): - if layer_prefix + hf_name in hf_state_dict: - q = hf_state_dict[layer_prefix + hf_name] - for head_nr in range(config.num_attention_heads): - dst_offset = head_nr * config.head_dim * 3 - # copy query - dst[( dst_offset + x * config.head_dim):(dst_offset + (x + 1) * config.head_dim), :] = \ - q[(head_nr * config.head_dim):(head_nr * config.head_dim + config.head_dim), :] + if qkv_fused_and_interleaved: + """ + When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor + in TE TransformerLayer. Moreover they are interleaved within each head. + Let q_i, k_i and v_i be query, key and value layers for i-th head respectively. 
+ Then TE stores weight tensor in the form: + [q1 k1 v1 q2 k2 v2 ...] + This is done to maximally optimize performance time. + """ + te_qkv_layer = te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.weight'] + def copy_interleave(hf_name, idx): + src = hf_state_dict[layer_prefix + hf_name] + for head_nr in range(config.num_attention_heads): + dst_offset = head_nr * config.head_dim * 3 + te_qkv_layer[(dst_offset + idx * config.head_dim):(dst_offset + (idx + 1) * config.head_dim), :] = \ + src[(head_nr * config.head_dim):(head_nr * config.head_dim + config.head_dim), :] copy_interleave('self_attn.q_proj.weight', 0) copy_interleave('self_attn.k_proj.weight', 1) copy_interleave('self_attn.v_proj.weight', 2) From ee698e7f2100befe2a33316be8b2675bd7c45841 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 13 May 2024 09:56:57 -0700 Subject: [PATCH 111/244] Minor refactors of te_gemma.py Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 98 ++++++++++++++---------------- 1 file changed, 45 insertions(+), 53 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 41ded00405..4040998f00 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -69,11 +69,9 @@ def __init__(self, config, layer_idx, *args, **kwargs): def forward(self, hidden_states, - *args, attention_mask, inference_params=None, - self_attn_mask_type='causal', - **kwargs): + self_attn_mask_type='causal'): """ Custom forward to make sure we only pass relevant arguments to the forward pass of the `TransformerLayer`. Also, make sure the output @@ -87,33 +85,48 @@ def forward(self, self_attn_mask_type=self_attn_mask_type ),) -class GemmaGenerator(torch.nn.Module): - def __init__(self, model, lm_head, inference_params, dtype, generation_config): +class StaticGemma(torch.nn.Module): + def __init__(self, model, inference_params, dtype, mask, lm_head): super().__init__() self.model = model self.inference_params = inference_params - self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) - self.generation_config = generation_config + self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) + self.mask = mask self.lm_head = lm_head + + def forward(self, hidden_states): - def forward(self, hidden_states, unfinished_sequences): hidden_states.data[:] = hidden_states.data[:] * self.normalizer - for decoder_layer in self.model.layers: hidden_states.copy_(decoder_layer( - hidden_states, - inference_params=self.inference_params, - self_attn_mask_type='padding', - attention_mask=None - )[0]) - - self.inference_params.seq_len.copy_(self.inference_params.seq_len + 1) + hidden_states, + attention_mask=None, + self_attn_mask_type=self.mask, + inference_params=self.inference_params + )[0]) hidden_states.copy_(self.model.norm(hidden_states)) logits = self.lm_head(hidden_states) logits = logits.float() + return logits + + +class GemmaGenerator(torch.nn.Module): + def __init__(self, model, lm_head, inference_params, dtype, generation_config): + super().__init__() + self.model = model + self.inference_params = inference_params + self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) + self.generation_config = generation_config + self.lm_head = lm_head + self.gemma_layers = StaticGemma(model, inference_params, dtype, 'padding', lm_head) + + def forward(self, hidden_states, unfinished_sequences): + logits = self.gemma_layers(hidden_states) logits = logits[:, -1, :] - 
next_tokens = torch.argmax(logits, dim=-1) + next_tokens = torch.argmax(logits, dim=1) + + self.inference_params.seq_len.copy_(self.inference_params.seq_len + 1) # Sequences, which are finished should contain padding - taken from huggingface transformers. next_tokens = next_tokens * unfinished_sequences + self.generation_config.pad_token_id * (1 - unfinished_sequences) @@ -134,7 +147,6 @@ class is monkey-patched with `TEGemmaDecoderLayer` class before def __new__(cls, config: GemmaConfig): with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): - # trzeba wstawis layer number do tego czegos w jakis sposob gemma_for_causal_lm = GemmaForCausalLM(config) gemma_for_causal_lm.generate = TEGemmaForCausalLM.generate.__get__(gemma_for_causal_lm, GemmaForCausalLM) @@ -145,7 +157,8 @@ def __new__(cls, config: GemmaConfig): def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, fp8_init=False, qkv_format="bshd", **kwargs): """ Custom method adapted from `from_pretrained` method in HuggingFace - Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 + Transformers repo: + https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 """ config.qkv_format = qkv_format with fp8_model_init(fp8_init): @@ -184,7 +197,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, fp8 return vanilla_model @staticmethod - def _padding_to_beginning(inputs, lengths): + def _padding_to_end(inputs, lengths): """ Gets the tensor with sequence padded from the beginning and return tensor padded from its end. @@ -206,47 +219,24 @@ def _padding_to_beginning(inputs, lengths): def _generate_context_phase( self, + gemma_layers, input_ids, inference_params, pad_token_id, eos_token_id, unfinished_sequences ): - hidden_states = self.model.embed_tokens(input_ids) - normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) - - output_tokens = [] - hidden_states = hidden_states * normalizer - for decoder_layer in self.model.layers: - hidden_states = decoder_layer( - hidden_states, - attention_mask=None, - self_attn_mask_type="padding_causal", - inference_params=inference_params - )[0] - - hidden_states = self.model.norm(hidden_states) - logits = self.lm_head(hidden_states) - logits = logits.float() + logits = gemma_layers(hidden_states) logits = logits[torch.arange(logits.size(0)), inference_params.incoming_seq_len - 1, :] next_tokens = torch.argmax(logits, dim=1) # Sequences, which are finished should contain padding - taken from huggingface transformers. 
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - output_tokens.append(next_tokens) unfinished_sequences = unfinished_sequences & ~(next_tokens == eos_token_id) - - hidden_states = self.model.embed_tokens(next_tokens).unsqueeze(1) - - for k, v in inference_params.key_value_memory_dict.items(): - key_layer = v[0].contiguous().cuda() - value_layer = v[1].contiguous().cuda() - inference_params.key_value_memory_dict[k] = (key_layer, value_layer) - - return hidden_states, output_tokens + return hidden_states, [next_tokens] @torch.no_grad() @@ -254,18 +244,18 @@ def generate( self, input_ids: Optional[torch.Tensor] = None, generation_config: Optional[GenerationConfig] = None, - max_new_tokens = 0, - use_cuda_graphs = False, + max_new_tokens: int = 0, + use_cuda_graphs: bool = False, **kwargs, ): - batch_size, _ = input_ids.shape + batch_size, max_input_sequence_len = input_ids.shape generation_config, _ = self._prepare_generation_config(generation_config, **kwargs) unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - # inference_params object is a cache, where keys and values of previous tokens are stored + # InferenceParams is a cache, where keys and values of previous tokens are stored. inference_params = InferenceParams( max_batch_size=batch_size, - max_sequence_length=input_ids.shape[1] + max_new_tokens + max_sequence_length=max_input_sequence_len + max_new_tokens ) # lengths is a tensor of shape [s] representing lengths of sequences. @@ -274,10 +264,13 @@ def generate( inference_params.incoming_seq_len = lengths.to(torch.int32).clone().cuda() inference_params.max_incoming_seq_len = input_ids.shape[1] - TEGemmaForCausalLM._padding_to_beginning(input_ids, lengths) + TEGemmaForCausalLM._padding_to_end(input_ids, lengths) + + context_phase_layers = StaticGemma(self.model, inference_params, torch.float32, 'padding_causal', self.lm_head) hidden_states, output_tokens = TEGemmaForCausalLM._generate_context_phase( self, + context_phase_layers, input_ids, inference_params, generation_config.pad_token_id, @@ -289,7 +282,6 @@ def generate( inference_params.incoming_seq_len.copy_(torch.ones_like(inference_params.incoming_seq_len)) inference_params.max_incoming_seq_len = 1 - generator = GemmaGenerator( lm_head=self.lm_head, model=self.model, From 9ec603a7c73fe5058bc4f4c81ad0dbe4f50fdf49 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 13 May 2024 17:58:00 -0700 Subject: [PATCH 112/244] Refactored te_gemma.py Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/media/atn1.png | Bin 4602 -> 0 bytes docs/examples/te_gemma/media/atn2.png | Bin 4561 -> 0 bytes docs/examples/te_gemma/media/atn3.png | Bin 2487 -> 0 bytes docs/examples/te_gemma/media/pic1.png | Bin 19382 -> 0 bytes docs/examples/te_gemma/media/pic2.png | Bin 25116 -> 0 bytes docs/examples/te_gemma/te_gemma.py | 399 ++++++++---------- .../te_gemma/te_gemma_loading_weights.py | 106 +++++ transformer_engine/pytorch/attention.py | 16 +- 8 files changed, 287 insertions(+), 234 deletions(-) delete mode 100644 docs/examples/te_gemma/media/atn1.png delete mode 100644 docs/examples/te_gemma/media/atn2.png delete mode 100644 docs/examples/te_gemma/media/atn3.png delete mode 100644 docs/examples/te_gemma/media/pic1.png delete mode 100644 docs/examples/te_gemma/media/pic2.png create mode 100644 docs/examples/te_gemma/te_gemma_loading_weights.py diff --git a/docs/examples/te_gemma/media/atn1.png b/docs/examples/te_gemma/media/atn1.png deleted file mode 100644 
index 4c3f5e2fa5a2d56dfed8137e407cfd671405f9e1..0000000000000000000000000000000000000000
GIT binary patch
[... binary patch data for the deleted images atn1.png, atn2.png, atn3.png, pic1.png and pic2.png omitted ...]

diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py
index 4040998f00..44a8a8eb1c 100644
--- a/docs/examples/te_gemma/te_gemma.py
+++ b/docs/examples/te_gemma/te_gemma.py
@@ -2,9 +2,6 @@
 #
 # See LICENSE for license information.
 
-import os
-import re
-import gc
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -15,26 +12,10 @@ import torch
 import transformer_engine as te
 from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding
-from transformer_engine.pytorch.fp8 import fp8_model_init
 from transformer_engine.common.recipe import Format, DelayedScaling
 
 import transformers
-from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig
-from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model
-from transformers.utils.hub import get_checkpoint_shard_files
-
-@contextmanager
-def replace_decoder(te_decoder_cls):
-    """
-    Replace `GemmaDecoderLayer` with custom `TEGemmaDecoderLayer`.
- """ - original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer - transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls - try: - yield - finally: - transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls - +from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel class TEGemmaDecoderLayer(te.pytorch.TransformerLayer): """ @@ -46,7 +27,7 @@ class TEGemmaDecoderLayer(te.pytorch.TransformerLayer): args: positional args (for compatibility with `GemmaDecoderLayer`) kwargs: keyword args (for compatibility with `GemmaDecoderLayer`) """ - def __init__(self, config, layer_idx, *args, **kwargs): + def __init__(self, config : GemmaConfig, layer_idx : int, *args, **kwargs): super().__init__( hidden_size=config.hidden_size, ffn_hidden_size=config.intermediate_size, @@ -61,81 +42,96 @@ def __init__(self, config, layer_idx, *args, **kwargs): attn_input_format=config.qkv_format, num_gqa_groups=config.num_key_value_heads, attention_hidden_size=4096, - layer_number=(layer_idx+1), + layer_number=(layer_idx+1), # Layer numbers in TE starts from 1, not from 0 like in the HF. zero_centered_gamma=True ) - te_rope = RotaryPositionEmbedding(256) - self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() + self.te_rope_emb = RotaryPositionEmbedding(256)(max_seq_len=config.max_position_embeddings).cuda() - def forward(self, - hidden_states, - attention_mask, - inference_params=None, - self_attn_mask_type='causal'): - """ - Custom forward to make sure we only pass relevant arguments to the - forward pass of the `TransformerLayer`. Also, make sure the output - format matches the output of the HF's `GemmaDecoderLayer`. - """ - return (super().forward( - hidden_states, - attention_mask=attention_mask, - rotary_pos_emb=self.te_rope_emb, - inference_params=inference_params, - self_attn_mask_type=self_attn_mask_type - ),) - -class StaticGemma(torch.nn.Module): - def __init__(self, model, inference_params, dtype, mask, lm_head): + def forward(self, *args, **kwargs): # We need to pass positional encoding. + return super().forward(*args, rotary_pos_emb=self.te_rope_emb, **kwargs) + + +class StaticGemmaModel(torch.nn.Module): + """ + StaticGemma is based of HF GemmaModel class. + It is adjusted to work properly with CUDA Graphs. 
+ """ + def __init__( + self, + model : GemmaModel, + dtype : torch.dtype, + mask : torch.Tensor, + lm_head : torch.nn.Module, + inference_params : InferenceParams + ): super().__init__() self.model = model - self.inference_params = inference_params self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) self.mask = mask self.lm_head = lm_head + self.inference_params = inference_params - def forward(self, hidden_states): - - hidden_states.data[:] = hidden_states.data[:] * self.normalizer + def forward(self, hidden_states : torch.Tensor): + hidden_states.data[:] = hidden_states.data[:] * self.normalizer # static operation - for CUDA graphs for decoder_layer in self.model.layers: hidden_states.copy_(decoder_layer( hidden_states, attention_mask=None, self_attn_mask_type=self.mask, inference_params=self.inference_params - )[0]) + )[0]) # static copy - for CUDA graphs - hidden_states.copy_(self.model.norm(hidden_states)) + hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs logits = self.lm_head(hidden_states) logits = logits.float() return logits class GemmaGenerator(torch.nn.Module): - def __init__(self, model, lm_head, inference_params, dtype, generation_config): + """ + GemmaGenerator gets one layer of embeddins, + makes forward pass and returns next tokens. + """ + def __init__(self, model : GemmaModel, lm_head: torch.nn.Module, inference_params : InferenceParams, dtype : torch.dtype): super().__init__() self.model = model + self.gemma_layers = StaticGemmaModel(model, dtype, 'padding', lm_head, inference_params) self.inference_params = inference_params - self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) - self.generation_config = generation_config - self.lm_head = lm_head - self.gemma_layers = StaticGemma(model, inference_params, dtype, 'padding', lm_head) - def forward(self, hidden_states, unfinished_sequences): + def forward(self, hidden_states : torch.Tensor): logits = self.gemma_layers(hidden_states) + + assert logits.shape[0] == hidden_states.shape[0] # b + # logits.shape[1] = number of tokens + assert logits.shape[2] == hidden_states.shape[2] # hidden_dim logits = logits[:, -1, :] next_tokens = torch.argmax(logits, dim=1) - self.inference_params.seq_len.copy_(self.inference_params.seq_len + 1) - - # Sequences, which are finished should contain padding - taken from huggingface transformers. - next_tokens = next_tokens * unfinished_sequences + self.generation_config.pad_token_id * (1 - unfinished_sequences) - unfinished_sequences.copy_(unfinished_sequences & ~(next_tokens == self.generation_config.eos_token_id)) hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1)) + # self.inference_params contains for example kv_cache + # This needs to be called before every pass, + # to update the information of sequence lengths. + # Here we increase sequence offsets by one, + # because we generated one token for every sequence. + self.inference_params.set_before_new_input(hidden_states, offsets_change="+1") + return next_tokens -class TEGemmaForCausalLM: +@contextmanager +def replace_decoder(te_decoder_cls): + """ + Replace `GemmaDecoderLayer` with custom `TEGemmaDecoderLayer`. 
+ """ + original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls + try: + yield + finally: + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls + + +class TEGemmaForCausalLM(GemmaForCausalLM): """ Causal LM created with `GemmaModel`. The underlying `GemmaDecoderLayer` class is monkey-patched with `TEGemmaDecoderLayer` class before @@ -145,57 +141,28 @@ class is monkey-patched with `TEGemmaDecoderLayer` class before config: GemmaConfig """ - def __new__(cls, config: GemmaConfig): + def __init__(self, config: GemmaConfig): with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): - gemma_for_causal_lm = GemmaForCausalLM(config) - - gemma_for_causal_lm.generate = TEGemmaForCausalLM.generate.__get__(gemma_for_causal_lm, GemmaForCausalLM) - - return gemma_for_causal_lm - - @classmethod - def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, fp8_init=False, qkv_format="bshd", **kwargs): - """ - Custom method adapted from `from_pretrained` method in HuggingFace - Transformers repo: - https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 - """ - config.qkv_format = qkv_format - with fp8_model_init(fp8_init): - vanilla_model = cls(config) - subfolder = "" - variant = None - if os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant)) - ): - # Load from a sharded PyTorch checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant) - ) - is_sharded = True - - resolved_archive_file, _ = get_checkpoint_shard_files( - pretrained_model_name_or_path, - archive_file, + super().__init__(config) + self.hidden_states = None + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + @torch.no_grad() + def _model_generation_phase(self, hidden_states : torch.Tensor, inference_params : InferenceParams=None): + generator = GemmaGenerator( + lm_head=self.lm_head, + model=self.model, + inference_params=inference_params, + dtype=hidden_states.dtype, ) + return generator(hidden_states,) + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + @torch.no_grad() + def _model_context_phase(self, hidden_states : torch.Tensor, inference_params : InferenceParams=None): + layers = StaticGemmaModel(self.model, torch.float32, 'padding_causal', self.lm_head, inference_params) + return layers(hidden_states) - # If the checkpoint is not sharded, it's a trivial sharding case - if not is_sharded: - assert not isinstance(resolved_archive_file, list) - resolved_archive_file = [resolved_archive_file] - - total_dict = {} - for shard_file in resolved_archive_file: - state_dict = load_state_dict(shard_file) - total_dict = total_dict | state_dict - replace_params(total_dict, vanilla_model.state_dict(), config, qkv_fused_and_interleaved=config.fuse_qkv_params) - _load_state_dict_into_model(vanilla_model, total_dict, start_prefix="") # Copy parameters like embedding. - - # Force mem release. 
Taken from huggingface code - del total_dict - gc.collect() - return vanilla_model - @staticmethod def _padding_to_end(inputs, lengths): """ @@ -217,158 +184,130 @@ def _padding_to_end(inputs, lengths): new_input_ids[i,lengths[i]:] = inputs[i, 0:(max_seq_len-lengths[i])] inputs.copy_(new_input_ids) + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _create_hidden_states_buffer(self, input_ids : torch.Tensor): + return torch.empty_like(input_ids, device="cuda", dtype=torch.float32) + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _create_inference_params(self, max_batch_size : int, max_sequence_length : int): + return InferenceParams(max_batch_size, max_sequence_length) + def _generate_context_phase( self, - gemma_layers, - input_ids, - inference_params, - pad_token_id, - eos_token_id, - unfinished_sequences + input_ids : torch.Tensor, + inference_params : InferenceParams ): - hidden_states = self.model.embed_tokens(input_ids) - logits = gemma_layers(hidden_states) + hidden_states = self._create_hidden_states_buffer(input_ids) + hidden_states.data[:] = self.model.embed_tokens(input_ids) + + logits = self._model_context_phase(self.hidden_states, inference_params) + + # We choose logits coresponding with last token in each sequence, + # which have various lengths - they are stored in (inference_params.incoming_seq_len - 1) Tensor. logits = logits[torch.arange(logits.size(0)), inference_params.incoming_seq_len - 1, :] next_tokens = torch.argmax(logits, dim=1) - # Sequences, which are finished should contain padding - taken from huggingface transformers. - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + # self.hidden_states have shape [b, s, hd]. + # We return hidden state for the last token - output has shape [b, 1, hd] + self.hidden_states.data[:, 0, :] = self.model.embed_tokens(next_tokens) + return self.hidden_states[:, 0, :].unsqueeze(1), [next_tokens] - unfinished_sequences = unfinished_sequences & ~(next_tokens == eos_token_id) - hidden_states = self.model.embed_tokens(next_tokens).unsqueeze(1) - return hidden_states, [next_tokens] - - @torch.no_grad() def generate( self, input_ids: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - max_new_tokens: int = 0, - use_cuda_graphs: bool = False, - **kwargs, + pad_token_id: int = 0, + max_new_tokens: int = 0 ): batch_size, max_input_sequence_len = input_ids.shape - generation_config, _ = self._prepare_generation_config(generation_config, **kwargs) - unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # [s] # InferenceParams is a cache, where keys and values of previous tokens are stored. - inference_params = InferenceParams( + # Moreover it stores length of both already generated and input sequences. + inference_params = self._create_inference_params( max_batch_size=batch_size, max_sequence_length=max_input_sequence_len + max_new_tokens ) - # lengths is a tensor of shape [s] representing lengths of sequences. 
- lengths = torch.sum(input_ids.ne(generation_config.pad_token_id), dim=-1).squeeze() - inference_params.seq_len = torch.zeros_like(lengths).to(torch.int32).clone().cuda() - inference_params.incoming_seq_len = lengths.to(torch.int32).clone().cuda() - inference_params.max_incoming_seq_len = input_ids.shape[1] - - TEGemmaForCausalLM._padding_to_end(input_ids, lengths) + # We need to update offsets before every forward pass to make cache work properly. + inference_params.set_before_new_input(input_ids, padding_token=pad_token_id, offsets_change="all_zero") - context_phase_layers = StaticGemma(self.model, inference_params, torch.float32, 'padding_causal', self.lm_head) - + # Context phase + TEGemmaForCausalLM._padding_to_end(input_ids, lengths) hidden_states, output_tokens = TEGemmaForCausalLM._generate_context_phase( self, - context_phase_layers, input_ids, - inference_params, - generation_config.pad_token_id, - generation_config.eos_token_id, - unfinished_sequences - ) - - inference_params.seq_len.copy_(inference_params.incoming_seq_len) - inference_params.incoming_seq_len.copy_(torch.ones_like(inference_params.incoming_seq_len)) - inference_params.max_incoming_seq_len = 1 - - generator = GemmaGenerator( - lm_head=self.lm_head, - model=self.model, - inference_params=inference_params, - generation_config=generation_config, - dtype=hidden_states.dtype, + self.inference_params ) - args = (hidden_states, unfinished_sequences) - - saved_args = [arg.clone() for arg in args] # Warmup iterations of graph will change the arguments, we want to revert that. - if use_cuda_graphs: - fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") - graphed_generator = te.pytorch.make_graphed_callables( - generator, - args, - fp8_enabled=True, - fp8_recipe=fp8_recipe, - allow_unused_input=True, - num_warmup_iters=10 - ) - - for i in range(len(saved_args)): - args[i].copy_(saved_args[i]) - inference_params.seq_len.copy_(lengths.to(torch.int32)) - - for i in range(max_new_tokens): - next_tokens = graphed_generator(*args) if use_cuda_graphs else generator(*args) + # Generation phase. + self.inference_params.set_before_new_input(hidden_states, offsets_change=None) + for _ in range(max_new_tokens): + next_tokens = self._model_generation_phase(hidden_states, self.inference_params) output_tokens.append(next_tokens.clone()) result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) return result -def _get_all_layer_prefixes_to_update(hf_state_dict): +class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM): """ - There are many parameters in hf_state_dict, whose name start with model.layers.[number]. - This function extracts all strings like "model.layers.[number]." that are starting strings of keys in hf_state_dict. + TEGemmaForCausalLMCudaGraphs is the version of the class TEGemmaForCausalLM using CUDA Graphs to speed it up. + We need to make one trade-off. Namely, batch_size, max_seq_len and max_context_seq_len need to be static. + It is necessary to run generation with the same value of these variables that we recorded graph on. """ - all_layer_prefixes = set() - for param_key in hf_state_dict.keys(): - layer_prefix_pat = 'model.layers.\d+.' 
- m = re.match(layer_prefix_pat, param_key) - if m is not None: - all_layer_prefixes.add(m.group()) - return all_layer_prefixes - -def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False): + def __init__(self, config : GemmaConfig, batch_size : int, max_seq_len : int, max_context_seq_len : int): + super.__init(config) + + # Preparation of the static buffers. + self.batch_size = batch_size + self.max_seq_len = max_seq_len + self.hidden_states_buffer = torch.empty((batch_size, max_context_seq_len, self.config.hidden_dim)).cuda() + self.inference_params = InferenceParams(max_batch_size=batch_size, max_sequence_length=max_seq_len) + + # Here "the trick" happens. We override methods from TEGemmaForCausalLM + # with their recorded version. After invocation of each of them, + # captured graph will be replayed with minimal usage of CPU, + # what will lead to huge speedup. + self._model_generation_phase = self.record_graph(super()._model_generation_phase) + self._model_context_phase = self.record_graph(super()._model_context_phase) + """ - Replaces params from TE TransformerLayer state_dict with corresponding parameters - from HuggingFace GemmaModel state_dict. + Functions _create_hidden_states_buffer and _create_inference_params from base class are overriden + to make hidden_states and inference_params static + - not changing their position in memory between every invocation. """ - all_layer_prefixes : List[str] = _get_all_layer_prefixes_to_update(hf_state_dict) + def _create_hidden_states_buffer(self, *args): + return self.hidden_states_buffer + + def _create_inference_params(self, *args): + return self.inference_params + + @torch.no_grad() + def record_graph(self, function): + # function is invoked on argument (self.hidden_states,) and all kernels are recorded. + # record_graph() returns captured function, which can be run later with minimal use of th CPU. + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + graphed_function = te.pytorch.make_graphed_callables( + function, + (self.hidden_states,), + fp8_enabled=True, + fp8_recipe=fp8_recipe, + allow_unused_input=True, + num_warmup_iters=3 + ) + return graphed_function - for layer_prefix in all_layer_prefixes: - def copy_from_ht_to_te(te_name, hf_name, start=None, end=None): - te_state_dict[layer_prefix + te_name].data[start:end].copy_(hf_state_dict[layer_prefix + hf_name]) - - copy_from_ht_to_te('self_attention.layernorm_qkv.layer_norm_weight', 'input_layernorm.weight') - copy_from_ht_to_te('self_attention.proj.weight', 'self_attn.o_proj.weight') - copy_from_ht_to_te('layernorm_mlp.layer_norm_weight', 'post_attention_layernorm.weight') - copy_from_ht_to_te('layernorm_mlp.fc2_weight', 'mlp.down_proj.weight') - copy_from_ht_to_te('layernorm_mlp.fc1_weight', 'mlp.gate_proj.weight', end=config.intermediate_size) - copy_from_ht_to_te('layernorm_mlp.fc1_weight', 'mlp.up_proj.weight', start=config.intermediate_size) - - if qkv_fused_and_interleaved: - """ - When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor - in TE TransformerLayer. Moreover they are interleaved within each head. - Let q_i, k_i and v_i be query, key and value layers for i-th head respectively. - Then TE stores weight tensor in the form: - [q1 k1 v1 q2 k2 v2 ...] - This is done to maximally optimize performance time. 
- """ - te_qkv_layer = te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.weight'] - def copy_interleave(hf_name, idx): - src = hf_state_dict[layer_prefix + hf_name] - for head_nr in range(config.num_attention_heads): - dst_offset = head_nr * config.head_dim * 3 - te_qkv_layer[(dst_offset + idx * config.head_dim):(dst_offset + (idx + 1) * config.head_dim), :] = \ - src[(head_nr * config.head_dim):(head_nr * config.head_dim + config.head_dim), :] - copy_interleave('self_attn.q_proj.weight', 0) - copy_interleave('self_attn.k_proj.weight', 1) - copy_interleave('self_attn.v_proj.weight', 2) - else: - copy_from_ht_to_te('self_attention.layernorm_qkv.query_weight', 'self_attn.q_proj.weight') - copy_from_ht_to_te('self_attention.layernorm_qkv.key_weight', 'self_attn.k_proj.weight') - copy_from_ht_to_te('self_attention.layernorm_qkv.value_weight', 'self_attn.v_proj.weight') - - return all_layer_prefixes \ No newline at end of file + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + *args, + **kwargs, + ): + assert self.batch_size == input_ids.shape[0], \ + f"Input_ids shape {input_ids.shape} does not match batch_size={self.batch_size} of recorded graphs" + assert self.max_seq_len == input_ids.shape[1], \ + f"Input_ids shape {input_ids.shape} does not match max_seq_len={self.max_seq_len} of recorded graphs" + + super().generate(input_ids, *args, **kwargs) \ No newline at end of file diff --git a/docs/examples/te_gemma/te_gemma_loading_weights.py b/docs/examples/te_gemma/te_gemma_loading_weights.py new file mode 100644 index 0000000000..582e0136e7 --- /dev/null +++ b/docs/examples/te_gemma/te_gemma_loading_weights.py @@ -0,0 +1,106 @@ +import os +import re +import gc +from contextlib import contextmanager + +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +from transformers.generation import * +from transformers.generation.utils import * + +from transformer_engine.pytorch.fp8 import fp8_model_init + +from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model +from transformers.utils.hub import get_checkpoint_shard_files + + +def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, fp8_init=False, qkv_format="bshd", **kwargs): + """ + Custom method adapted from `from_pretrained` method in HuggingFace + Transformers repo: + https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 + """ + config.qkv_format = qkv_format + with fp8_model_init(fp8_init): + vanilla_model = cls(config) + variant = None + if os.path.isfile( + os.path.join(pretrained_model_name_or_path, _add_variant("model.safetensors.index.json", variant)) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, _add_variant("model.safetensors.index.json", variant) + ) + + resolved_archive_file, _ = get_checkpoint_shard_files( + pretrained_model_name_or_path, + archive_file, + ) + total_dict = {} + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + total_dict = total_dict | state_dict + replace_params(total_dict, vanilla_model.state_dict(), config, qkv_fused_and_interleaved=config.fuse_qkv_params) + _load_state_dict_into_model(vanilla_model, total_dict, start_prefix="") # Copy parameters like embedding. + + # Force mem release. 
Taken from huggingface code + del total_dict + gc.collect() + return vanilla_model + +def _get_all_layer_prefixes_to_update(hf_state_dict): + """ + There are many parameters in hf_state_dict, whose name start with "model.layers.[number]." + This function extracts all strings like "model.layers.[number]." that are starting strings of keys in hf_state_dict. + """ + all_layer_prefixes = set() + for param_key in hf_state_dict.keys(): + layer_prefix_pat = 'model.layers.\d+.' + m = re.match(layer_prefix_pat, param_key) + if m is not None: + all_layer_prefixes.add(m.group()) + return all_layer_prefixes + +def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False): + """ + Replaces params from TE TransformerLayer state_dict with corresponding parameters + from HuggingFace GemmaModel state_dict. + """ + all_layer_prefixes : List[str] = _get_all_layer_prefixes_to_update(hf_state_dict) + + for layer_prefix in all_layer_prefixes: + def copy_from_ht_to_te(te_name, hf_name, start=None, end=None): + te_state_dict[layer_prefix + te_name].data[start:end].copy_(hf_state_dict[layer_prefix + hf_name]) + + copy_from_ht_to_te('self_attention.layernorm_qkv.layer_norm_weight', 'input_layernorm.weight') + copy_from_ht_to_te('self_attention.proj.weight', 'self_attn.o_proj.weight') + copy_from_ht_to_te('layernorm_mlp.layer_norm_weight', 'post_attention_layernorm.weight') + copy_from_ht_to_te('layernorm_mlp.fc2_weight', 'mlp.down_proj.weight') + copy_from_ht_to_te('layernorm_mlp.fc1_weight', 'mlp.gate_proj.weight', end=config.intermediate_size) + copy_from_ht_to_te('layernorm_mlp.fc1_weight', 'mlp.up_proj.weight', start=config.intermediate_size) + + if qkv_fused_and_interleaved: + """ + When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor + in TE TransformerLayer. Moreover they are interleaved within each head. + Let q_i, k_i and v_i be query, key and value layers for i-th head respectively. + Then TE stores weight tensor in the form: + [q1 k1 v1 q2 k2 v2 ...] + This is done to maximally optimize performance time. 
+ """ + te_qkv_layer = te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.weight'] + def copy_interleave(hf_name, idx): + src = hf_state_dict[layer_prefix + hf_name] + for head_nr in range(config.num_attention_heads): + dst_offset = head_nr * config.head_dim * 3 + te_qkv_layer[(dst_offset + idx * config.head_dim):(dst_offset + (idx + 1) * config.head_dim), :] = \ + src[(head_nr * config.head_dim):(head_nr * config.head_dim + config.head_dim), :] + copy_interleave('self_attn.q_proj.weight', 0) + copy_interleave('self_attn.k_proj.weight', 1) + copy_interleave('self_attn.v_proj.weight', 2) + else: + copy_from_ht_to_te('self_attention.layernorm_qkv.query_weight', 'self_attn.q_proj.weight') + copy_from_ht_to_te('self_attention.layernorm_qkv.key_weight', 'self_attn.k_proj.weight') + copy_from_ht_to_te('self_attention.layernorm_qkv.value_weight', 'self_attn.v_proj.weight') + + return all_layer_prefixes \ No newline at end of file diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 74b2707485..438017e06d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -148,6 +148,15 @@ def swap_key_value_dict(self, batch_indices): new_inference_key_memory, new_inference_value_memory, ) + + def set_before_new_input(self, new_input, pad_token_id=None, offsets_change): + assert offsets_change in ["all_zero", "+1", None] + + lengths = torch.sum(new_input.ne(pad_token_id), dim=-1).squeeze() + self.seq_len = torch.zeros_like(lengths).to(torch.int32).clone().cuda() + self.incoming_seq_len = lengths.to(torch.int32).clone().cuda() + self.max_incoming_seq_len = new_input.shape[1] + @torch.no_grad() def get_alibi( @@ -2321,7 +2330,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql seq_offsets_q, seq_offsets_k, seq_offsets_v, q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, - use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group): + use_FAv2_bwd, fp8, fp8_meta): if fp8: if _NVTE_DEBUG: @@ -3196,6 +3205,7 @@ def forward( q_size = query_layer.shape[1] key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() + @@ -3268,7 +3278,6 @@ def forward( """ batch_size = query_layer.shape[0] - tex.attention_copy( inference_key_memory, inference_params.seq_len, @@ -3287,7 +3296,6 @@ def forward( inference_params.max_sequence_length, batch_size, self.channels) - max_seqlen_q = inference_params.max_incoming_seq_len max_seqlen_kv = inference_params.max_sequence_length @@ -3304,7 +3312,6 @@ def forward( seq_offsets_k.copy_(torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv) seq_offsets_v.copy_(seq_offsets_k) - query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16) key_layer = inference_key_memory.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]).to(torch.bfloat16) value_layer = inference_value_memory.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]).to(torch.bfloat16) @@ -4370,6 +4377,7 @@ def forward( hidden_dim ) + for i in range(batch_size): key_layer[i,].copy_(apply_rotary_pos_emb(key_layer[i,:].unsqueeze(0), k_pos_emb[i,:].unsqueeze(1), "bshd", fused=True)[0,:]) query_layer[i,:].copy_(apply_rotary_pos_emb(query_layer[i,:].unsqueeze(0), q_pos_emb[i,:].unsqueeze(1), "bshd", fused=True)[0,:]) From 942a2dbc38e44f27f254ceb3fae5c44bda570a85 Mon Sep 17 00:00:00 
2001 From: Pawel Gadzinski Date: Mon, 13 May 2024 18:19:04 -0700 Subject: [PATCH 113/244] Minor chenges Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 2 +- docs/examples/te_gemma/te_gemma_loading_weights.py | 11 ++++++++--- .../te_gemma/tutorial_generation_gemma_with_te.ipynb | 10 +++++----- docs/examples/te_gemma/utils.py | 12 ++++++++---- transformer_engine/pytorch/transformer.py | 1 + 5 files changed, 23 insertions(+), 13 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 44a8a8eb1c..e8442057a2 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -4,7 +4,7 @@ from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Optional from transformers.generation import * from transformers.generation.utils import * diff --git a/docs/examples/te_gemma/te_gemma_loading_weights.py b/docs/examples/te_gemma/te_gemma_loading_weights.py index 582e0136e7..bf93e9fe3f 100644 --- a/docs/examples/te_gemma/te_gemma_loading_weights.py +++ b/docs/examples/te_gemma/te_gemma_loading_weights.py @@ -1,9 +1,8 @@ import os import re import gc -from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import List from transformers.generation import * from transformers.generation.utils import * @@ -13,8 +12,14 @@ from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model from transformers.utils.hub import get_checkpoint_shard_files +""" + This file contains logic of mapping the HuggingFace GemmaModel parameters + with TransformerEngine TransformerLayer. When we have initialized Transformer models + both with HF and with TE, we can copy parameters from the first to the second. +""" -def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, fp8_init=False, qkv_format="bshd", **kwargs): + +def from_pretrained_local(cls, pretrained_model_name_or_path, config, fp8_init=False, qkv_format="bshd"): """ Custom method adapted from `from_pretrained` method in HuggingFace Transformers repo: diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index fc3b840b61..542c00e3e0 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -20,11 +20,11 @@ "Addressing the challenge of computing attention for sequences with varying lengths, a common method is to pad these sequences and apply an attention mask. The Transformer Engine, however, offers a more optimized approach—by specifying the lengths and offsets of the sequences, attention can be computed directly. Instead of passing the matrix and mask with the shape `[b, s, h, d]`, one can pass a matrix of the shape `[t, h, d]` along with tensors detailing sequence lengths and offsets to run the attention optimized for this case. This specific attention layout is referred to as the **THD layout**.\n", "\n", "
\n", - "\"\"
\n", + "\"\"
\n", "Fig. 1. The sequences and the mask for standard attention layout - padding from the end.

\n", - "\"\"
\n", + "\"\"
\n", "Fig. 2. The sequences and the mask for standard attention layout - padding from the beginning.

\n", - "\"\"
\n", + "\"\"
\n", "Fig. 3. An attention with thd layer.

\n", "
\n", "\n", @@ -151,11 +151,11 @@ "\n", "\n", "Query layer \n", - "\"\"\n", + "\"\"\n", "\n", "\n", "Key layer and value layer \n", - "\"\"\n", + "\"\"\n", "\n", "\n", "cu_seqlens_q = [0, 1, 3, 7, 9, 12]
\n", diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index a52e8daaa9..e052357187 100644 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -6,6 +6,8 @@ import sys import IPython +from te_gemma_loading_weights import from_pretrained_local + import torch from torch.optim import AdamW from torch.utils.data import DataLoader @@ -19,7 +21,7 @@ class HyperParameters: def __init__(self): self.mixed_precision = "bf16" - #self.model_name = "" # <== Add model weight location here + self.model_name = "" # <== Add model weight location here self.dataset_name = "timdettmers/openassistant-guanaco" self.dataset_text_field = "text" self.learning_rate = 1.41e-5 @@ -87,13 +89,15 @@ def init_baseline_model(hyperparams): return model -def init_te_gemma_model(hyperparams, fp8_model_init=False, qkv_format="thd"): +def init_te_gemma_model(hyperparams, fp8_model_init=False, qkv_format="thd", cuda_graphs=False): # Init the model - from te_gemma import TEGemmaForCausalLM + from te_gemma import TEGemmaForCausalLM, TEGemmaForCausalLMCudaGraphs + cls = TEGemmaForCausalLMCudaGraphs if cuda_graphs else TEGemmaForCausalLM config = AutoConfig.from_pretrained(hyperparams.model_name) config._attn_implementation = "flash_attention_2" config.fuse_qkv_params = hyperparams.fuse_qkv_params - model = TEGemmaForCausalLM.from_pretrained_local( + model = from_pretrained_local( + cls, hyperparams.model_name, config=config, torch_dtype=torch.bfloat16, diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 2219154903..1c07ac725e 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -618,6 +618,7 @@ def forward( hidden_states = cast_if_needed( hidden_states, torch.get_autocast_gpu_dtype() ) + # Self attention. self_attention_outputs = self.self_attention( From 2048c6e4693775953c2bc6308dac47a46f8e7806 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 14 May 2024 19:03:36 -0700 Subject: [PATCH 114/244] Te gemma Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 146 +++++++++++++++++------------ 1 file changed, 84 insertions(+), 62 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index e8442057a2..3f52a1149d 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -13,10 +13,13 @@ import transformer_engine as te from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding from transformer_engine.common.recipe import Format, DelayedScaling +from torch.cuda.amp import autocast import transformers from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel +import torch.nn.functional as F + class TEGemmaDecoderLayer(te.pytorch.TransformerLayer): """ Wrapper class over TE's `TransformerLayer`. This makes the wrapper very @@ -50,7 +53,6 @@ def __init__(self, config : GemmaConfig, layer_idx : int, *args, **kwargs): def forward(self, *args, **kwargs): # We need to pass positional encoding. return super().forward(*args, rotary_pos_emb=self.te_rope_emb, **kwargs) - class StaticGemmaModel(torch.nn.Module): """ StaticGemma is based of HF GemmaModel class. 
@@ -62,24 +64,25 @@ def __init__( dtype : torch.dtype, mask : torch.Tensor, lm_head : torch.nn.Module, - inference_params : InferenceParams ): super().__init__() self.model = model self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) self.mask = mask self.lm_head = lm_head + + def set_inference_params(self, inference_params): self.inference_params = inference_params def forward(self, hidden_states : torch.Tensor): hidden_states.data[:] = hidden_states.data[:] * self.normalizer # static operation - for CUDA graphs for decoder_layer in self.model.layers: - hidden_states.copy_(decoder_layer( + hidden_states.data[:] = decoder_layer( hidden_states, attention_mask=None, self_attn_mask_type=self.mask, inference_params=self.inference_params - )[0]) # static copy - for CUDA graphs + ) # static copy - for CUDA graphs hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs logits = self.lm_head(hidden_states) @@ -92,18 +95,21 @@ class GemmaGenerator(torch.nn.Module): GemmaGenerator gets one layer of embeddins, makes forward pass and returns next tokens. """ - def __init__(self, model : GemmaModel, lm_head: torch.nn.Module, inference_params : InferenceParams, dtype : torch.dtype): + def __init__(self, model : GemmaModel, lm_head: torch.nn.Module, dtype : torch.dtype): super().__init__() self.model = model - self.gemma_layers = StaticGemmaModel(model, dtype, 'padding', lm_head, inference_params) + self.gemma_layers = StaticGemmaModel(model, dtype, 'padding', lm_head) + + def set_inference_params(self, inference_params): self.inference_params = inference_params + self.gemma_layers.set_inference_params(inference_params) def forward(self, hidden_states : torch.Tensor): logits = self.gemma_layers(hidden_states) assert logits.shape[0] == hidden_states.shape[0] # b - # logits.shape[1] = number of tokens - assert logits.shape[2] == hidden_states.shape[2] # hidden_dim + assert logits.shape[1] == hidden_states.shape[1] # seq_len + # logits.shape[2] = number of tokens logits = logits[:, -1, :] next_tokens = torch.argmax(logits, dim=1) @@ -114,7 +120,7 @@ def forward(self, hidden_states : torch.Tensor): # to update the information of sequence lengths. # Here we increase sequence offsets by one, # because we generated one token for every sequence. - self.inference_params.set_before_new_input(hidden_states, offsets_change="+1") + self.inference_params.set_before_new_input(next_tokens.unsqueeze(1)) return next_tokens @@ -142,27 +148,17 @@ class is monkey-patched with `TEGemmaDecoderLayer` class before """ def __init__(self, config: GemmaConfig): + assert config.qkv_format == "thd", "Generation using other qkv_layouts than thd is not provided in this tutorial" with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): super().__init__(config) - self.hidden_states = None - - # This function is overriden in TeGEmmaForCausalLMCudaGraphs. - @torch.no_grad() - def _model_generation_phase(self, hidden_states : torch.Tensor, inference_params : InferenceParams=None): - generator = GemmaGenerator( + self.hidden_size = config.hidden_size + self._model_generation_phase = GemmaGenerator( lm_head=self.lm_head, model=self.model, - inference_params=inference_params, - dtype=hidden_states.dtype, + dtype=torch.float32, ) - return generator(hidden_states,) - - # This function is overriden in TeGEmmaForCausalLMCudaGraphs. 
- @torch.no_grad() - def _model_context_phase(self, hidden_states : torch.Tensor, inference_params : InferenceParams=None): - layers = StaticGemmaModel(self.model, torch.float32, 'padding_causal', self.lm_head, inference_params) - return layers(hidden_states) - + self._model_context_phase = StaticGemmaModel(self.model, torch.float32, 'padding_causal', self.lm_head) + @staticmethod def _padding_to_end(inputs, lengths): """ @@ -186,11 +182,11 @@ def _padding_to_end(inputs, lengths): # This function is overriden in TeGEmmaForCausalLMCudaGraphs. def _create_hidden_states_buffer(self, input_ids : torch.Tensor): - return torch.empty_like(input_ids, device="cuda", dtype=torch.float32) + return torch.empty((input_ids.shape[0], input_ids.shape[1], self.hidden_size), device="cuda", dtype=torch.float32) # This function is overriden in TeGEmmaForCausalLMCudaGraphs. def _create_inference_params(self, max_batch_size : int, max_sequence_length : int): - return InferenceParams(max_batch_size, max_sequence_length) + return InferenceParams(max_batch_size, max_sequence_length, qkv_format="thd") def _generate_context_phase( self, @@ -200,7 +196,11 @@ def _generate_context_phase( hidden_states = self._create_hidden_states_buffer(input_ids) hidden_states.data[:] = self.model.embed_tokens(input_ids) - logits = self._model_context_phase(self.hidden_states, inference_params) + # We need to update offsets before every forward pass to make cache work properly. + inference_params.set_before_new_input(input_ids, pad_token_id=0, offsets_change="all_zero") + self._model_context_phase = self.record_graph(self._model_context_phase, hidden_states) + hidden_states.data[:] = self.model.embed_tokens(input_ids) + logits = self._model_context_phase(hidden_states) # We choose logits coresponding with last token in each sequence, # which have various lengths - they are stored in (inference_params.incoming_seq_len - 1) Tensor. @@ -209,18 +209,26 @@ def _generate_context_phase( # self.hidden_states have shape [b, s, hd]. # We return hidden state for the last token - output has shape [b, 1, hd] - self.hidden_states.data[:, 0, :] = self.model.embed_tokens(next_tokens) - return self.hidden_states[:, 0, :].unsqueeze(1), [next_tokens] + hidden_states.data[:, 0, :] = self.model.embed_tokens(next_tokens) + output = hidden_states.view(-1)[:hidden_states.shape[0] * hidden_states.shape[2]] + output.copy_(hidden_states.data[:, 0, :].reshape(-1)) + output = output.view((hidden_states.shape[0], 1, hidden_states.shape[2])) + return output, next_tokens + + def _get_max_input_seq_len(self, input_ids): + return input_ids.shape[1] @torch.no_grad() def generate( self, input_ids: Optional[torch.Tensor] = None, pad_token_id: int = 0, - max_new_tokens: int = 0 + max_new_tokens: int = 0, + *args, **kwargs ): - batch_size, max_input_sequence_len = input_ids.shape + batch_size, max_input_sequence_len = input_ids.shape[0], self._get_max_input_seq_len(input_ids) lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # [s] + input_ids = F.pad(input_ids, (max_input_sequence_len - input_ids.shape[1], 0), "constant", 0) # InferenceParams is a cache, where keys and values of previous tokens are stored. # Moreover it stores length of both already generated and input sequences. @@ -229,21 +237,22 @@ def generate( max_sequence_length=max_input_sequence_len + max_new_tokens ) - # We need to update offsets before every forward pass to make cache work properly. 
- inference_params.set_before_new_input(input_ids, padding_token=pad_token_id, offsets_change="all_zero") + self._model_context_phase.set_inference_params(inference_params) + self._model_generation_phase.set_inference_params(inference_params) # Context phase TEGemmaForCausalLM._padding_to_end(input_ids, lengths) - hidden_states, output_tokens = TEGemmaForCausalLM._generate_context_phase( + hidden_states, next_tokens = TEGemmaForCausalLM._generate_context_phase( self, input_ids, - self.inference_params + inference_params ) # Generation phase. - self.inference_params.set_before_new_input(hidden_states, offsets_change=None) + inference_params.set_before_new_input(next_tokens.unsqueeze(1), offsets_change=None) + output_tokens = [next_tokens] for _ in range(max_new_tokens): - next_tokens = self._model_generation_phase(hidden_states, self.inference_params) + next_tokens = self._model_generation_phase(hidden_states) output_tokens.append(next_tokens.clone()) result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) @@ -255,47 +264,60 @@ class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM): We need to make one trade-off. Namely, batch_size, max_seq_len and max_context_seq_len need to be static. It is necessary to run generation with the same value of these variables that we recorded graph on. """ - def __init__(self, config : GemmaConfig, batch_size : int, max_seq_len : int, max_context_seq_len : int): - super.__init(config) - + def __init__(self, config : GemmaConfig): + super().__init__(config) + self.to("cuda") # Preparation of the static buffers. - self.batch_size = batch_size - self.max_seq_len = max_seq_len - self.hidden_states_buffer = torch.empty((batch_size, max_context_seq_len, self.config.hidden_dim)).cuda() - self.inference_params = InferenceParams(max_batch_size=batch_size, max_sequence_length=max_seq_len) + self.config = config + self.hidden_states_buffer = torch.empty( + (config.cuda_graphs_static_batch_size, config.cuda_graphs_static_max_context_len, config.hidden_size)).cuda() + self.generation_buffer = self.hidden_states_buffer.view(-1)[:config.cuda_graphs_static_batch_size*config.hidden_size].view((config.cuda_graphs_static_batch_size, 1, config.hidden_size)) + self.inference_params = InferenceParams( + max_batch_size=config.cuda_graphs_static_batch_size, max_sequence_length=config.cuda_graphs_static_max_seq_len, qkv_format="thd") + + + self._model_generation_phase.set_inference_params(self.inference_params) + self._model_context_phase.set_inference_params(self.inference_params) # Here "the trick" happens. We override methods from TEGemmaForCausalLM # with their recorded version. After invocation of each of them, # captured graph will be replayed with minimal usage of CPU, # what will lead to huge speedup. 
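For reference, a minimal sketch of the capture-and-replay pattern described in the comment above, mirroring the `record_graph` helper used in this class (the `capture` wrapper name and the static buffer are illustrative assumptions, not part of the patch):

from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common.recipe import Format, DelayedScaling

def capture(phase_module, static_input):
    # Warm up and record the kernels launched by `phase_module` on `static_input`.
    # The returned callable replays the captured graph, so `static_input` must be
    # a persistent buffer that is copied into (never reallocated) between calls.
    fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max")
    return make_graphed_callables(
        phase_module,
        (static_input,),
        fp8_enabled=True,
        fp8_recipe=fp8_recipe,
        allow_unused_input=True,
        num_warmup_iters=3,
    )

# Usage sketch:
# static_buffer = torch.empty((batch_size, seq_len, hidden_size), device="cuda")
# graphed_phase = capture(context_phase_module, static_buffer)
# static_buffer.copy_(new_embeddings)    # reuse the same memory between replays
# logits = graphed_phase(static_buffer)  # replays the recorded kernels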
- self._model_generation_phase = self.record_graph(super()._model_generation_phase) - self._model_context_phase = self.record_graph(super()._model_context_phase) + #self.inference_params.set_before_new_input(torch.ones((config.cuda_graphs_static_batch_size, config.cuda_graphs_static_max_context_len)), pad_token_id=0, offsets_change="all_zero") + #self._model_context_phase = self.record_graph(self._model_context_phase, self.hidden_states_buffer) + + #self.inference_params.set_before_new_input(torch.ones((config.cuda_graphs_static_batch_size, 1)), offsets_change="all_zero") + #self._model_generation_phase = self.record_graph(self._model_generation_phase, self.generation_buffer) """ Functions _create_hidden_states_buffer and _create_inference_params from base class are overriden to make hidden_states and inference_params static - not changing their position in memory between every invocation. """ - def _create_hidden_states_buffer(self, *args): + def _create_hidden_states_buffer(self, *args, **kwargs): return self.hidden_states_buffer - def _create_inference_params(self, *args): + def _create_inference_params(self, *args, **kwargs): return self.inference_params + + def _get_max_input_seq_len(self, _): + return self.config.cuda_graphs_static_max_context_len @torch.no_grad() - def record_graph(self, function): + def record_graph(self, function, input_tensor): # function is invoked on argument (self.hidden_states,) and all kernels are recorded. # record_graph() returns captured function, which can be run later with minimal use of th CPU. fp8_format = Format.HYBRID fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") - graphed_function = te.pytorch.make_graphed_callables( - function, - (self.hidden_states,), - fp8_enabled=True, - fp8_recipe=fp8_recipe, - allow_unused_input=True, - num_warmup_iters=3 - ) + with autocast(dtype=torch.bfloat16, cache_enabled=False): + graphed_function = te.pytorch.make_graphed_callables( + function, + (input_tensor,), + fp8_enabled=True, + fp8_recipe=fp8_recipe, + allow_unused_input=True, + num_warmup_iters=3 + ) return graphed_function @torch.no_grad() @@ -305,9 +327,9 @@ def generate( *args, **kwargs, ): - assert self.batch_size == input_ids.shape[0], \ + assert self.config.cuda_graphs_static_batch_size == input_ids.shape[0], \ f"Input_ids shape {input_ids.shape} does not match batch_size={self.batch_size} of recorded graphs" - assert self.max_seq_len == input_ids.shape[1], \ - f"Input_ids shape {input_ids.shape} does not match max_seq_len={self.max_seq_len} of recorded graphs" + assert self.config.cuda_graphs_static_max_context_len >= input_ids.shape[1], \ + f"Input_ids shape {input_ids.shape} is greater than max_seq_len={self.max_seq_len} of recorded graphs" - super().generate(input_ids, *args, **kwargs) \ No newline at end of file + return super().generate(input_ids, *args, **kwargs) \ No newline at end of file From 167631dc32dcc463805d8231ed597c1ee907412c Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 14 May 2024 19:04:00 -0700 Subject: [PATCH 115/244] Attention Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 33 ++++++++++++++++--------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 438017e06d..11ed01e023 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -116,13 +116,17 @@ class InferenceParams: # pylint: 
disable=too-few-public-methods maximum sequence length during inference. """ - def __init__(self, max_batch_size, max_sequence_length): + def __init__(self, max_batch_size, max_sequence_length, qkv_format="bshd"): self.max_sequence_length = max_sequence_length self.max_batch_size = max_batch_size self.sequence_len_offset = 0 self.batch_size_offset = 0 self.key_value_memory_dict = {} - self.seq_len=torch.tensor((1000)) + + + if qkv_format == "thd": + self.seq_len = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) + self.incoming_seq_len = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) def swap_key_value_dict(self, batch_indices): """ @@ -149,14 +153,19 @@ def swap_key_value_dict(self, batch_indices): new_inference_value_memory, ) - def set_before_new_input(self, new_input, pad_token_id=None, offsets_change): - assert offsets_change in ["all_zero", "+1", None] + def set_before_new_input(self, new_input, offsets_change=None, pad_token_id=None): + assert offsets_change in ["all_zero", None] - lengths = torch.sum(new_input.ne(pad_token_id), dim=-1).squeeze() - self.seq_len = torch.zeros_like(lengths).to(torch.int32).clone().cuda() - self.incoming_seq_len = lengths.to(torch.int32).clone().cuda() + self.seq_len.copy_(self.seq_len + self.incoming_seq_len) + if pad_token_id is not None: + self.incoming_seq_len.copy_(torch.sum(new_input.ne(pad_token_id), dim=-1, dtype=torch.int32).squeeze()) + else: + self.incoming_seq_len.copy_(torch.ones(new_input.shape[0], device="cuda") * new_input.shape[1]) self.max_incoming_seq_len = new_input.shape[1] + if offsets_change == "all_zero": + self.seq_len.copy_(torch.zeros_like(self.seq_len)) + @torch.no_grad() def get_alibi( @@ -4350,13 +4359,16 @@ def forward( rotary_pos_emb = ((rotary_pos_emb,) * 2) if self.qkv_format == "thd" and inference_params is not None: - key_layer = key_layer.contiguous() - query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() # verify if needed + query_layer = query_layer.contiguous() # verify if needed batch_size, hidden_dim = query_layer.shape[0], query_layer.shape[-1] q_pos_emb = self.alloc((batch_size, inference_params.max_incoming_seq_len, 1, hidden_dim), torch.float32, "cuda") k_pos_emb = self.alloc((batch_size, inference_params.max_incoming_seq_len, 1, hidden_dim), torch.float32, "cuda") q_freq, k_freq = rotary_pos_emb + + # inference_params.pick_freqs(q_freq, q_pos_emb) + # inference_params.pick_freqs(k_freq, k_pos_emb) tex.get_values( q_freq, # [max_pos_emb, s, 1, d] @@ -4376,8 +4388,7 @@ def forward( batch_size, hidden_dim ) - - + for i in range(batch_size): key_layer[i,].copy_(apply_rotary_pos_emb(key_layer[i,:].unsqueeze(0), k_pos_emb[i,:].unsqueeze(1), "bshd", fused=True)[0,:]) query_layer[i,:].copy_(apply_rotary_pos_emb(query_layer[i,:].unsqueeze(0), q_pos_emb[i,:].unsqueeze(1), "bshd", fused=True)[0,:]) From 8db2699dde0516d49db1610a5ac8881c15e967cb Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 15 May 2024 11:03:24 -0700 Subject: [PATCH 116/244] attention.py refactor Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 41 ++++-- transformer_engine/pytorch/attention.py | 186 +++++++++++++++--------- 2 files changed, 145 insertions(+), 82 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 3f52a1149d..e1c041d585 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -120,7 +120,7 @@ def forward(self, hidden_states : torch.Tensor): # to update the 
information of sequence lengths. # Here we increase sequence offsets by one, # because we generated one token for every sequence. - self.inference_params.set_before_new_input(next_tokens.unsqueeze(1)) + self.inference_params.thd_setup_before_new_input(next_tokens.unsqueeze(1)) return next_tokens @@ -187,6 +187,20 @@ def _create_hidden_states_buffer(self, input_ids : torch.Tensor): # This function is overriden in TeGEmmaForCausalLMCudaGraphs. def _create_inference_params(self, max_batch_size : int, max_sequence_length : int): return InferenceParams(max_batch_size, max_sequence_length, qkv_format="thd") + + # The buffer for generation is some part (beginning) of hidden states buffer. + # This function returns pointer to it and also copies there data if provided. + def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None): + # hidden_states_buffer has shape [b, s, hd] + # generation_buffer will have shape [b, 1, hd] + # Notice that "generation_buffer = hidden_states_buffer[:, 0, :].unsqueeze(1)" + # will return uncontiguous buffer, which we want to avoid. + output = hidden_states_buffer.view(-1)[:hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2]] + if data_to_copy is not None: + output.copy_(data_to_copy.reshape(-1)) + generation_buffer = output.view((hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2])) + return generation_buffer + def _generate_context_phase( self, @@ -197,8 +211,8 @@ def _generate_context_phase( hidden_states.data[:] = self.model.embed_tokens(input_ids) # We need to update offsets before every forward pass to make cache work properly. - inference_params.set_before_new_input(input_ids, pad_token_id=0, offsets_change="all_zero") - self._model_context_phase = self.record_graph(self._model_context_phase, hidden_states) + inference_params.thd_setup_before_new_input(input_ids, pad_token_id=0, reset=True) + #self._model_context_phase = self.record_graph(self._model_context_phase, hidden_states) hidden_states.data[:] = self.model.embed_tokens(input_ids) logits = self._model_context_phase(hidden_states) @@ -209,11 +223,8 @@ def _generate_context_phase( # self.hidden_states have shape [b, s, hd]. # We return hidden state for the last token - output has shape [b, 1, hd] - hidden_states.data[:, 0, :] = self.model.embed_tokens(next_tokens) - output = hidden_states.view(-1)[:hidden_states.shape[0] * hidden_states.shape[2]] - output.copy_(hidden_states.data[:, 0, :].reshape(-1)) - output = output.view((hidden_states.shape[0], 1, hidden_states.shape[2])) - return output, next_tokens + hidden_states = self._get_generation_buffer(hidden_states, self.model.embed_tokens(next_tokens)) + return hidden_states, next_tokens def _get_max_input_seq_len(self, input_ids): return input_ids.shape[1] @@ -249,7 +260,7 @@ def generate( ) # Generation phase. 
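As a side note, a minimal sketch of how the THD cache bookkeeping added in this patch is meant to be driven around the loop above (the toy batch, the pad token value and the omitted forward passes are assumptions for illustration):

import torch
from transformer_engine.pytorch.attention import InferenceParams

pad_id, max_new_tokens = 0, 3
input_ids = torch.tensor([[11, 12, 13, 0, 0, 0, 0, 0],
                          [21, 22, 23, 24, 25, 0, 0, 0]], device="cuda")  # lengths 3 and 5
batch_size, max_ctx_len = input_ids.shape

inference_params = InferenceParams(batch_size, max_ctx_len + max_new_tokens, qkv_format="thd")

# Context phase: offsets are reset to zero and incoming lengths are read from the padding.
inference_params.thd_setup_before_new_input(input_ids, reset=True, pad_token_id=pad_id)
# ... run the context-phase forward pass here; seq_len == [0, 0], incoming_seq_len == [3, 5] ...

for _ in range(max_new_tokens):
    next_tokens = torch.zeros((batch_size, 1), dtype=torch.long, device="cuda")  # stand-in for argmax output
    # Before each single-token pass the offsets advance by the previous incoming lengths,
    # and the incoming length becomes 1 for every sequence.
    inference_params.thd_setup_before_new_input(next_tokens)
    # ... run the generation-phase forward pass here ...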
- inference_params.set_before_new_input(next_tokens.unsqueeze(1), offsets_change=None) + inference_params.thd_setup_before_new_input(next_tokens.unsqueeze(1)) output_tokens = [next_tokens] for _ in range(max_new_tokens): next_tokens = self._model_generation_phase(hidden_states) @@ -271,7 +282,7 @@ def __init__(self, config : GemmaConfig): self.config = config self.hidden_states_buffer = torch.empty( (config.cuda_graphs_static_batch_size, config.cuda_graphs_static_max_context_len, config.hidden_size)).cuda() - self.generation_buffer = self.hidden_states_buffer.view(-1)[:config.cuda_graphs_static_batch_size*config.hidden_size].view((config.cuda_graphs_static_batch_size, 1, config.hidden_size)) + self.generation_buffer = self._get_generation_buffer(self.hidden_states_buffer) # in fact part of the buffer for hidden_states self.inference_params = InferenceParams( max_batch_size=config.cuda_graphs_static_batch_size, max_sequence_length=config.cuda_graphs_static_max_seq_len, qkv_format="thd") @@ -283,11 +294,13 @@ def __init__(self, config : GemmaConfig): # with their recorded version. After invocation of each of them, # captured graph will be replayed with minimal usage of CPU, # what will lead to huge speedup. - #self.inference_params.set_before_new_input(torch.ones((config.cuda_graphs_static_batch_size, config.cuda_graphs_static_max_context_len)), pad_token_id=0, offsets_change="all_zero") - #self._model_context_phase = self.record_graph(self._model_context_phase, self.hidden_states_buffer) + input_shape = (config.cuda_graphs_static_batch_size, config.cuda_graphs_static_max_context_len) + self.inference_params.thd_setup_before_new_input(torch.ones(input_shape), pad_token_id=0, reset=True) + self._model_context_phase = self.record_graph(self._model_context_phase, self.hidden_states_buffer) # CUDA Graphs recording - #self.inference_params.set_before_new_input(torch.ones((config.cuda_graphs_static_batch_size, 1)), offsets_change="all_zero") - #self._model_generation_phase = self.record_graph(self._model_generation_phase, self.generation_buffer) + input_shape = torch.ones((config.cuda_graphs_static_batch_size, 1)) + self.inference_params.thd_setup_before_new_input(input_shape, reset=True) + self._model_generation_phase = self.record_graph(self._model_generation_phase, self.generation_buffer) # CUDA Graphs recording """ Functions _create_hidden_states_buffer and _create_inference_params from base class are overriden diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 11ed01e023..661970c893 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -119,14 +119,15 @@ class InferenceParams: # pylint: disable=too-few-public-methods def __init__(self, max_batch_size, max_sequence_length, qkv_format="bshd"): self.max_sequence_length = max_sequence_length self.max_batch_size = max_batch_size - self.sequence_len_offset = 0 - self.batch_size_offset = 0 self.key_value_memory_dict = {} + self.qkv_format = qkv_format - if qkv_format == "thd": self.seq_len = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) self.incoming_seq_len = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) + else: + self.sequence_len_offset = 0 + self.batch_size_offset = 0 def swap_key_value_dict(self, batch_indices): """ @@ -153,8 +154,28 @@ def swap_key_value_dict(self, batch_indices): new_inference_value_memory, ) - def set_before_new_input(self, new_input, offsets_change=None, pad_token_id=None): - assert 
offsets_change in ["all_zero", None] + + def thd_setup_before_new_input(self, new_input, reset=False, pad_token_id=None): + """ + After every context/generation phase, the parameters representing + for example sequence lengths and incmoing sequence lengths, + need to be updated. This function does exactly that. + + + Parameters + ---------- + new_input: torch.Tensor + Tensor with token_ids (not embeddings!) on which we want to do next forward pass. + reset: int + If reset=True, all previous sequence lengths will be set to 0. + It is supposed to be used after last generation phase to + allow inference_params to be reused. + pad_token_id: int + Value of padding token - used to compute sequence_lengths. If pad_token_id=None, + we assume that all new_input sequence lengths + are equal to the corresponding dimension of new_input. + """ + assert self.qkv_format == "thd" self.seq_len.copy_(self.seq_len + self.incoming_seq_len) if pad_token_id is not None: @@ -163,8 +184,79 @@ def set_before_new_input(self, new_input, offsets_change=None, pad_token_id=None self.incoming_seq_len.copy_(torch.ones(new_input.shape[0], device="cuda") * new_input.shape[1]) self.max_incoming_seq_len = new_input.shape[1] - if offsets_change == "all_zero": + if reset: self.seq_len.copy_(torch.zeros_like(self.seq_len)) + + def save_new_key_and_value_layer(self, layer_number, key_layer, value_layer): + """ + Saves key_layer and value_layer in the cache. + """ + (inference_key_memory, inference_value_memory, + ) = self.key_value_memory_dict[layer_number] + if self.qkv_format == "thd": + batch_size = key_layer.shape[0] + channels = inference_key_memory.shape[2] * inference_key_memory.shape[3] # h * d + tex.attention_copy( + inference_key_memory, + self.seq_len, + self.incoming_seq_len, + key_layer, + self.max_incoming_seq_len, + self.max_sequence_length, + batch_size, + channels) + + tex.attention_copy( + inference_value_memory, + self.seq_len, + self.incoming_seq_len, + value_layer, + self.max_incoming_seq_len, + self.max_sequence_length, + batch_size, + channels) + else: + assert self.qkv_format in ["bshd", "sbhd"], "Attention format not supported by the inference." + batch_start = self.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + + sequence_start = self.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + + # Copy keys and values into KV-cache + inference_key_memory[ + sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer + inference_value_memory[ + sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer + key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] + value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] + return key_layer, value_layer + + def pick_freqs(self, freq, pos_emb_buffer): + """ + Parameters + ---------- + freq: torch.Tensor [max_pos_emb, 1, 1, d] + Tensor with frequencies used in rotarty positional encoding application. + pos_emb_buffer: torch.Tensor [b, max_incoming_seq_len, 1, d] + Buffer for positional embedding frequencies for each sequence in batch. + + If self.incoming_seq_len contains numbers [s1, s2, ...], then + pos_emb_buffer[0, :] = freq[s1:(s1 + max_incoming_seq_len), 1, 1, d]. 
+ """ + batch_size, _, _ , hidden_dim = pos_emb_buffer.shape + tex.get_values( + freq, + self.seq_len, + self.incoming_seq_len, + pos_emb_buffer, + self.max_incoming_seq_len, + batch_size, + hidden_dim + ) + @torch.no_grad() @@ -3215,9 +3307,6 @@ def forward( key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() - - - assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), 'DotProductAttention only supports CUDA tensors.' @@ -3266,46 +3355,15 @@ def forward( ) = inference_params.key_value_memory_dict[self.layer_number] if qkv_format in ["bshd", "sbhd"]: - batch_start = inference_params.batch_size_offset - batch_end = batch_start + key_layer.size(1) - assert batch_end <= inference_key_memory.size(1) - - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + key_layer.size(0) - assert sequence_end <= inference_key_memory.size(0) - - # Copy keys and values into KV-cache - inference_key_memory[ - sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer - inference_value_memory[ - sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer - key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] - value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] + key_layer, value_layer = inference_params.save_new_key_and_value_layer(self.layer_number, key_layer, value_layer) elif qkv_format == "thd": + + inference_params.save_new_key_and_value_layer(self.layer_number, key_layer, value_layer) + """ - inference_params.seq_len - lengths of processed sequences + We compute parameters needed by the THD attention with offsets. """ batch_size = query_layer.shape[0] - - tex.attention_copy( - inference_key_memory, - inference_params.seq_len, - inference_params.incoming_seq_len, - key_layer, - inference_params.max_incoming_seq_len, - inference_params.max_sequence_length, - batch_size, - self.channels) - tex.attention_copy( - inference_value_memory, - inference_params.seq_len, - inference_params.incoming_seq_len, - value_layer, - inference_params.max_incoming_seq_len, - inference_params.max_sequence_length, - batch_size, - self.channels) - max_seqlen_q = inference_params.max_incoming_seq_len max_seqlen_kv = inference_params.max_sequence_length cu_seqlens_q = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") @@ -3321,6 +3379,7 @@ def forward( seq_offsets_k.copy_(torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv) seq_offsets_v.copy_(seq_offsets_k) + # qkv layers are reshaped to the format [t, h, d] query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16) key_layer = inference_key_memory.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]).to(torch.bfloat16) value_layer = inference_value_memory.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]).to(torch.bfloat16) @@ -4359,36 +4418,27 @@ def forward( rotary_pos_emb = ((rotary_pos_emb,) * 2) if self.qkv_format == "thd" and inference_params is not None: - key_layer = key_layer.contiguous() # verify if needed - query_layer = query_layer.contiguous() # verify if needed + # For thd attention incoming tokens can be on different positions, + # so we need to copy different positional encoding freqency + # for every sequence in a batch. + # + # For example if sequence lengths in context phase are: 2 and 5 (batch size=2), + # in first generation phase key_layer have shape [2, 1, d]. 
+ # key_layer[0, :] corresponds to the token with position 3 = 2 + 1, + # and key_layer [1, :] corresponds to the token with position 6 = 5 + 1. + key_layer = key_layer.contiguous() + query_layer = query_layer.contiguous() batch_size, hidden_dim = query_layer.shape[0], query_layer.shape[-1] q_pos_emb = self.alloc((batch_size, inference_params.max_incoming_seq_len, 1, hidden_dim), torch.float32, "cuda") k_pos_emb = self.alloc((batch_size, inference_params.max_incoming_seq_len, 1, hidden_dim), torch.float32, "cuda") q_freq, k_freq = rotary_pos_emb - # inference_params.pick_freqs(q_freq, q_pos_emb) - # inference_params.pick_freqs(k_freq, k_pos_emb) - - tex.get_values( - q_freq, # [max_pos_emb, s, 1, d] - inference_params.seq_len, # [b] - inference_params.incoming_seq_len, # [b] - q_pos_emb, # [b, 1, 1, d] - inference_params.max_incoming_seq_len, - batch_size, - hidden_dim - ) - tex.get_values( - k_freq, - inference_params.seq_len, - inference_params.incoming_seq_len, - k_pos_emb, - inference_params.max_incoming_seq_len, - batch_size, - hidden_dim - ) + # inference_params object is aware of the positions of incoming tokens. + inference_params.pick_freqs(q_freq, q_pos_emb) + inference_params.pick_freqs(k_freq, k_pos_emb) + # We need to apply different positional encoding for each element of the batch. for i in range(batch_size): key_layer[i,].copy_(apply_rotary_pos_emb(key_layer[i,:].unsqueeze(0), k_pos_emb[i,:].unsqueeze(1), "bshd", fused=True)[0,:]) query_layer[i,:].copy_(apply_rotary_pos_emb(query_layer[i,:].unsqueeze(0), q_pos_emb[i,:].unsqueeze(1), "bshd", fused=True)[0,:]) From a0e35dc3a5015b9dad49839e5de6648485a83d22 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 15 May 2024 11:04:36 -0700 Subject: [PATCH 117/244] attention.py refactor Signed-off-by: Pawel Gadzinski --- .../te_gemma/te_gemma_loading_weights.py | 49 ++++--- docs/examples/te_gemma/utils.py | 136 +++++++++++------- 2 files changed, 111 insertions(+), 74 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma_loading_weights.py b/docs/examples/te_gemma/te_gemma_loading_weights.py index bf93e9fe3f..772f58320d 100644 --- a/docs/examples/te_gemma/te_gemma_loading_weights.py +++ b/docs/examples/te_gemma/te_gemma_loading_weights.py @@ -1,6 +1,7 @@ import os import re import gc +import torch from typing import List @@ -9,7 +10,7 @@ from transformer_engine.pytorch.fp8 import fp8_model_init -from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model +from transformers.modeling_utils import load_state_dict, _load_state_dict_into_model from transformers.utils.hub import get_checkpoint_shard_files """ @@ -18,29 +19,14 @@ both with HF and with TE, we can copy parameters from the first to the second. 
""" - -def from_pretrained_local(cls, pretrained_model_name_or_path, config, fp8_init=False, qkv_format="bshd"): - """ - Custom method adapted from `from_pretrained` method in HuggingFace - Transformers repo: - https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 - """ - config.qkv_format = qkv_format - with fp8_model_init(fp8_init): - vanilla_model = cls(config) - variant = None - if os.path.isfile( - os.path.join(pretrained_model_name_or_path, _add_variant("model.safetensors.index.json", variant)) - ): - # Load from a sharded PyTorch checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, _add_variant("model.safetensors.index.json", variant) - ) - - resolved_archive_file, _ = get_checkpoint_shard_files( - pretrained_model_name_or_path, - archive_file, +def _load_fp8_weights(vanilla_model, hyperparams): + vanilla_model.load_state_dict( + torch.load(hyperparams.fp8_model_weights_filename) ) + +def _load_standard_weights(vanilla_model, config): + archive_file = os.path.join(config.model_name, "model.safetensors.index.json") + resolved_archive_file, _ = get_checkpoint_shard_files(config.model_name, archive_file) total_dict = {} for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file) @@ -51,6 +37,23 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, config, fp8_init=F # Force mem release. Taken from huggingface code del total_dict gc.collect() + + +def load_te_model(cls, config): + """ + Custom method adapted from `from_pretrained` method in HuggingFace + Transformers repo: + https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 + """ + with fp8_model_init(config.fp8_model_init): + # there we need only to create model + vanilla_model = cls(config) + if config.fp8_model_init: + if config.fp8_model_weights_filename is not None: + _load_fp8_weights(vanilla_model, config) + else: + _load_standard_weights(vanilla_model, config) + return vanilla_model def _get_all_layer_prefixes_to_update(hf_state_dict): diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index e052357187..b316247640 100644 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -6,7 +6,7 @@ import sys import IPython -from te_gemma_loading_weights import from_pretrained_local +from te_gemma_loading_weights import load_te_model import torch from torch.optim import AdamW @@ -18,10 +18,25 @@ from accelerate import Accelerator from accelerate.utils.dataclasses import FP8RecipeKwargs + +from te_gemma import TEGemmaForCausalLM, TEGemmaForCausalLMCudaGraphs + class HyperParameters: def __init__(self): self.mixed_precision = "bf16" - self.model_name = "" # <== Add model weight location here + self.model_name = None + + # Weights in fp8 + self.fp8_model_weights_filename = None + self.fp8_model_init = False + + # Cuda graphs + self.generation_cuda_graphs = False + self.cuda_graphs_static_batch_size = 16 + self.cuda_graphs_static_max_seq_len = 256 + self.cuda_graphs_static_max_context_len = 16 + + # Finetuning settings. self.dataset_name = "timdettmers/openassistant-guanaco" self.dataset_text_field = "text" self.learning_rate = 1.41e-5 @@ -30,16 +45,16 @@ def __init__(self): self.gradient_accumulation_steps = 1 self.num_warmup_steps=5 self.num_training_steps=10 + + # QKV format. 
self.fuse_qkv_params=False + self.qkv_format = "bshd" - hyperparams = HyperParameters() def get_dataloaders(accelerator:Accelerator, hyperparams): dataset = load_dataset(hyperparams.dataset_name, split="train") tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) - if getattr(tokenizer, "pad_token", None) is None: - tokenizer.pad_token = tokenizer.eos_token def tokenize(element): outputs = tokenizer( @@ -85,29 +100,19 @@ def init_baseline_model(hyperparams): config=config, torch_dtype=torch.bfloat16, ) - # Needed for the cases when using TEGemmaForCausalLM. So adding here for 1:1 comparison - return model -def init_te_gemma_model(hyperparams, fp8_model_init=False, qkv_format="thd", cuda_graphs=False): - # Init the model - from te_gemma import TEGemmaForCausalLM, TEGemmaForCausalLMCudaGraphs - cls = TEGemmaForCausalLMCudaGraphs if cuda_graphs else TEGemmaForCausalLM +def init_te_gemma_model(hyperparams): + cls = TEGemmaForCausalLMCudaGraphs if hyperparams.generation_cuda_graphs else TEGemmaForCausalLM config = AutoConfig.from_pretrained(hyperparams.model_name) config._attn_implementation = "flash_attention_2" - config.fuse_qkv_params = hyperparams.fuse_qkv_params - model = from_pretrained_local( - cls, - hyperparams.model_name, - config=config, - torch_dtype=torch.bfloat16, - fp8_init=fp8_model_init, - qkv_format=qkv_format - ) - # Needed for the cases when using TEGemmaForCausalLM - + # Adding all params from the hyperparams to the config to make the code simpler. + for key, value in hyperparams.__dict__.items(): + setattr(config, key, value) + model = load_te_model(cls, config) return model + def wrap_with_accelerator(model, hyperparams): # Create FP8 kwarg handler if required fp8_kwarg_handler = [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None @@ -137,21 +142,22 @@ def wrap_with_accelerator(model, hyperparams): def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler): model.train() - total_loss = 0 optimizer.zero_grad() train_dataloader = enumerate(train_dataloader) - # Warmup iters - for _ in range(hyperparams.num_warmup_steps): - step, batch = next(train_dataloader) - with accelerator.accumulate(model): - outputs = model(**batch) - loss = outputs.loss - total_loss += loss.detach().float() - accelerator.backward(loss) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() + def run_iters(num_iters): + for _ in range(num_iters): + _, batch = next(train_dataloader) + with accelerator.accumulate(model): + outputs = model(**batch) + loss = outputs.loss + total_loss += loss.detach().float() + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + run_iters(hyperparams.num_warmup_steps) # Warmup iters # Get the timers ready start = torch.cuda.Event(enable_timing=True) @@ -159,22 +165,15 @@ def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, torch.cuda.synchronize() start.record() - # Training iters - for _ in range(hyperparams.num_training_steps): - step, batch = next(train_dataloader) - with accelerator.accumulate(model): - outputs = model(**batch) - loss = outputs.loss - total_loss += loss.detach().float() - accelerator.backward(loss) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() + run_iters(hyperparams.num_training_steps) # Training iters torch.cuda.synchronize() end.record() accelerator.end_training() - print(f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step: 
{(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} milliseconds") + print(f"""{hyperparams.num_training_steps} finetuning steps complete!\n + Average time taken per step: + {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} + milliseconds""") def restart_jupyter_notebook(): # Try restarting the Jupyter kernel @@ -199,18 +198,53 @@ def restart_jupyter_notebook(): warnings.simplefilter("ignore") torch.set_warn_always(False) -def generate_sample_text(model): + +def run_forward_pass(model, hyperparams, num_iters): + """ + It runs num_iters forward passes with sample data. + """ + model.train() + train_dataloader = enumerate(train_dataloader) + + for _ in range(num_iters): + _, batch = next(train_dataloader) + batch["input_ids"] = batch["input_ids"].cuda() + model.generate( + **batch, + max_new_tokens=10 + ) + +""" + Benchmarking and example generation functions. +""" + +def print_sample_of_generated_texts(model): tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) - inputs = tokenizer(["Some random initial str ", "Another string ... "] * 32, return_tensors="pt", padding=True) + inputs = tokenizer(["Another string ... ", "I "] * 32, return_tensors="pt", padding=True) inputs['input_ids'] = inputs['input_ids'].cuda() inputs['attention_mask'] = inputs['attention_mask'].cuda() outputs = model.generate(**inputs, max_new_tokens=100) generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) - for text in generated_texts: + for text in generated_texts[:2]: print(text) print("=" * 100) -def benchmark_generation(model): - pass \ No newline at end of file +def benchmark_generation(model, tokenizer, context_length, max_new_tokens): + inputs = tokenizer(["a" * context_length] * context_length, return_tensors="pt", padding=True) + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + + model.generate( + inputs['input_ids'].cuda(), + max_new_tokens = 256 + ) + torch.cuda.synchronize() + end.record() + + print(f"Benchmark with context_length={context_length} and max_new_tokens={max_new_tokens} took {start.elapsed_time(end)} ms.") + print(f"Peak GPU memoty usage: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB") From 40ce474d373e943b59c1e49a82874865c090ac4d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 15 May 2024 14:19:01 -0700 Subject: [PATCH 118/244] fp8_model_init tutorial Signed-off-by: Pawel Gadzinski --- .../te_gemma/media/fp8_model_init.png | Bin 0 -> 43656 bytes .../tutorial_generation_gemma_with_te.ipynb | 264 ++++++++++++------ 2 files changed, 171 insertions(+), 93 deletions(-) create mode 100644 docs/examples/te_gemma/media/fp8_model_init.png diff --git a/docs/examples/te_gemma/media/fp8_model_init.png b/docs/examples/te_gemma/media/fp8_model_init.png new file mode 100644 index 0000000000000000000000000000000000000000..c8f9a0b416e80637688d418c7c9875532df7a43d GIT binary patch literal 43656 zcmdqJWmuJK)Gj(rR792vA|PPUskEeufRj#XQ0Z=v#zYB0Qo0-Fq&rlk1tcepbc4jC zV~=;O^?m!h&N=(+eeLV~*w^8YQ+Vea;~8<^_ZZLnQ9(|U=mNzB6beOze)3oeg*uat zLYs2NHxp%wAu;9$A9gg zEC?^oJ=6C-sjPs6@B`Jur)OAi-hAy4uKVbo!xgD>w`fziB7g2nj>g2IV~Zk#+(mKq zjf7Tfowcp=%i5r`Ik+05r5^f6l`T9ybcI2OrDAa>#F@B^sb!>;_kb^_(SZZg>r5GCs6iD{(dwa9QDb!j}$1{UA^Kfp*C ze}sgDVn$A*P(i2uS*V2X8=IRBPFG?5Fa7VVXhx9?eMn(6d_7hCpEb@XLe|}e`hV_0 zEA3V&h0D2`#ZrW1?4h*6f4-ig;57IUq6PPdLNOB8?5_wO?o0*-rvwH4B^xVH+&{Fp zsCS%|l0u8_QkEA^r51J#m7?^UX!?9dGylaqxj3Or_?wHXYnEYaxa)js#EsEwp~Mj! 
zdVO%T;c)zY`?Q6=T&*;$%T&{4*=XLjvRa$*N9E<^Ve{tUGtT8zRYosfB&Es4W_|f` zy}Y7gVLC#uuOWc4*k!{A-L+gghS^`K#H#3p*i>BMcU%ZflZjNCSXRkY5|8)Tifq$o zd7ATUqTX*=bF8;lRW6t86#MyE1MI)dHKu^4-*w*EzDqL@#SKtg z_NgxRnVPrmz0DAu^oP{Q|!cv z6YnsB`VAMC$7*YzEVs2DZjWK%My^$8KYsjHWN+?{PK~=pFgKj^5Rdg@>ocdPr#EX2 zXB#Xwk9hIoMQBjaO?i2FxS*uvU{Pnn!#5-(Bq}+Ur%*B3jExu1%nuf4$9ZgNL^o2q z&fQ8)Pai6-iHL}BcXubbaU&<3O)E>U)-y{rJA4Ei9~6`T)}Q~CDCXD{FgI#&N$CSF7*fP#l}CBkPzl!JpqKU!S)SZsc% z+&4zFLg}e-E3f5XhPdCg3Y+oQh=-7} zJk9Vt+*=r{^~$VfALqV7&&8#TMx&*qq?A-tLV2x*GroR(Ap7O+!&|p*ox!+`e?QX} z%_q+w7t@5-mXh*SQc`k&-&qb6h=2R`&GlfdvA4{2@?vYtn~#_!&}n;&&sEjBkc;*50IoKr+AwMFsxk)}np8PPE?u-w0|LB_8AYh|Re z_~qJ*5!c0B3N9o6;gQ(mUY_W-y>TC!&5aJx)))DqFWm&4S9)HbAqoyY_{Q7|S1MOUTwGj8 z%T`*#SnwmxW6cJf)2{6M{o_aw)wq|xC!$R2OqR62eKDbnxp`O5!J$axXh)w))Kds& zpP88{B`eF*e*F6NYi`pnS|0PhoSY*Txk9^XsnoQzEXhD>vtJ3X3Zp~P>7>K^hgW+t zLr$JLMOAP-X@7Ex?cO~__`>JV-tK{+p&<+&3Gh1|3!~MQR^`*z(7`@FK5_2r=;}Y; z3E}Q>v+-USSY9T6{xX%p#e=tmm&pG5OJS~AO*ac-ZjSTJ;jd^JdYx)l)sG)PHt65G zcMpOsD(>aPdHm7Wp)l?}`IyyQgPEI8pRxg4U;!MuE zW;YEYKFaz*HH4(c&bSY6q8So8H*T=<@$n%+LB^pQ8tO>rG~E*B8yJY2Y7QavAU-HPQ*Kvpl00D12fprbbl`G0Npl?J=O>+Lx%TG|qmuPq z^O~BPT@hz+A&)=b2n6hBWb#X5yTR<6%?$C^w{>)MdUG_mAbZHJt*ybI{eK!2a&mDM zJIpJUmz9}JHGNK9A4H*MGt{uU`F$)FB%f<*y=qzwjEtChcvQg%EkA#Th8D-`RXXL~ zt6-|tPr=sgweUA}SV_*d#|BV$sbRD`U!5W#xq9`EUwsmU1EYX|rq}WQ*z%z9sRtTs zczjrVyk34E9`5(gZUO&sB}C5X%Sj@gN+;RisfmfCt>q!auvwbLSxU;vj(eM??|gh( zEqE8dzavDJK-Yauyx6Qa>-w=YQUkx7y>x?->3Ex;jt=6GEjZM1RGqqRK3fgTEFdsg zChD}LPDIK3^A=WzT!ew~cxjhiaFXZF?b}5b1DxotHoh?nHU2|z2D5r!(qLLqQBmRb zraQbHIE;=Kl-h3Ncn|uJ<`iPl0|i)2qfCc5+8E4RYOe#^Vz^0}lEay3>uVK1Q>7y| z_g8Dpd$Uv#3qcz8q4CU+2%wl2P~?e}i4t}zxh&)6=4L)zX1jpH#$<;OY9ITTgruZ7{7+(Tadj07XHwUxaLCu#S{f*f;Q}7N=tQDqy)xiq}| zX~Ang-<4V(J_0LYOV?LlFWw<~Ty>U+B6P9EKO#a6V@GV<5!Y6Dv8lkY)&E(sR2Z`> zzE`cc*c_*Q%PCnRV6%k<+y6GD_Haup@wj#g)eEP9(!}Dh`&uPeQ}qYA2DzVHHlZmMj*El-7S-;%1rVibzHH7I%3Cq* zO-*^_dc)Lv;`Otq8PyfNL;(h+Z_IW`2fu#vW+Rtjd1B%P7x@HO!9t>uglwEUCJg`! 
zFW33ifx^afOoi~wtz{!%mPh6x$l}tMwGTw&z`Tv0sCZim&YqpA6xP$zOK$iS%JVZZ zVXp;pm#@EnC@l|wg|A)L#TC=#sG^@+4^)>^u~llLA3X2|@RS7JI3S#e*1be6WYR;B z4f%xHW3$se0#eT}F=F8q+sOt%ak@z>+^B>f09`HKvud(li`Eac=xe;XOC(lm^H!=Tmn%?*nkwSOLeD zmFYI4LeITDp^A(8oZY$Io}C|5}Dpnkg46&_Eg=qh^EAtRx7-LI)-lus{F&`E%mt=H`puOB|Vi4$Z#5J)e4|%6hD3qzrKu#Lv&LkCj+z z6cUN?b&Zt%?1VMBcPUQ%<%*@_m$#e?%QvU9UrMF_5FC>SAEi6z-2u}cHqbJy7#7#P_Q&aB&RM?yh6qRZbm~9)K87MTycS&^umb zT3NwhYmWpvaIMt6A#rgSxK~q1J{1o0!B?(anOa`X1#ir}Xe~ zrREfb4FD_v^uUVlBIKw_fJ4y1Na=^dtvo+JKUiSsS7JG2RFn^fxNz~J3gmjn#oW>v zH_ti$4zwbyVE`px``j?E*OB|~?(Y0TZ?-1XJEzsrToyIVlNN!Jx&HiQtIB2G=t+NA zKJY@cR;g7l6ajzMB$tgDglLS+HbA)?EHVwbL?dEh(b(1|>*?uv;nJn2l}^h&wO(G8 zt;%nGe2~x<@H{9(9%*_y6(ZMh?$^UNXNi;`t@!O&8OpL&+fGVA30ws5U*WVIH9Rt+ znybmfW7_o?Tumun#6v2z#@4iDwK8UTt}|I;HyIWaNsUTR=ezKWIP;D;VP8K#Wk8nz zgL^U*?;v(K+{3pzD|H^|+C$0GEVFrznFgY1JL%+eiOPR0#bT(`8V4RhhvqTs$t(Zx zCH|J!`|DR3Fpz=je262`(~T1sr_2Y*z+36m#&0Rn$g8$Gv+4t%yea9Tc#Kh?_Q(FtY(qcZB*upHq;>#nIxU$uZJp&YK zN~O}t@=i=UMlZ6hu1;)pbQB>mp*%E?p93JUT#4v=!i*~GBy~D69y#9>H>QPdSha7* zFZb4XxWU&hX|3Du{COU|FZt;Pla0&b^1HWhLtSF^ot>SdDO&r2LqZ&(E+F&>Xim2@ zM|H~Y-@iF6<#`wFaiO6k0I|eZtGDCsE%0jrU%PVkDsp(PMxjK9(1yQek;!LsxAvVF zk}y9%KOte^0K#E#sfw3tNr1tH_2qIc00(uK+s_tx9=gg#a*E4FaVyHm_@~LoV<71b zuu{^{U?K3$q1I9GS;c(+{(aq2%*ZJ18i#HU`0Nx#YS73y8!zwE{q0p;SE@9R)iA$) zQuRRzA9$nXXw|@sE#MKz5#kaONO5wV`Yh?o-7x^MlC@fUEP{Jd=@<_5Lf}!u~Wh1Y%PzrMa(A>oxZ$lcnKuoOm@?{winG^`p&RgYyjsyKQ9{V$t zzFLdGyWBlI!a_q+e*Ad6JWwcE?TTl1qCA&At`Zp;$-6DBbo){83w?b{&nm(w;X&O# zK}S8kF?TFIWYXl>qS8|5IKj@~yp`tODwhqVk&1({>LKWom_U&~%6ETfQ4v;9)GU3IESQ$Ay=Q;9{g#uPQ$yZ-B*$ zEI$r0uNmG=xg2bxOIiM5CgjNJ3~ z{8d^sQb{lC5>X5F^AoVn#q93k7ZLbHN=XT|!YmU2={g>-3O?|?u8x$9jDgAg@!D{? zJ-%IVl_Z;_Q9n>P4azGZv*x*POz1$N^$(qey?es>E6Y*A!AX%^#>_wkAmOV<^IEDD z8p-O^dg@%he7Rwa>(;G@t5qAHMqOrOwaIv5*6uT1TGaOitF*-lYcVk~nGKcXQt+73 z0_Yb#KImo~ij0cN3KTt3x2{@yKk5N6zQO_h@Zm$`IPgKowFz+`-eWt|Y`RtIkQ6v) zj0%kM)wv-BE{|3Zz07f19t5-yVur)yf^&!ix?CNtMmQ;OylbrL38A5(?hIx7dy81h z*Yl8r-o1aH1>^|mSntQH3?x84A*)T_|K^~3ja}PB`C8J`T+M~9RKL!Kj0{F_i^g+Z z>EPhAac-C%!BgR+i>ujy0^zINZd(?SZAM^o8SvE*;9wZ-@+3eu*QsvYzI}TW5a_}; z0-B=z;=xjD?D?zobC7J2;^5y1&g_{yY1H270rifTtaE9Zdv0bX4Vbt|cbaT)3asW| za+3|<4Epx9wByDS67NM0HB7PV=Np=E1&8Z<5{+5SziN1wecOVmk#F15*IZmw-RCJ0}sMP`<5u zFKMy*pY!tW-iaA4v5Y=TOpVA7GnKiJX@_zM6WkdWb+Z|@pU-+Bfe#r558Ohk(~71={NHkVz%&j9{Sr_kc=+NLCj=8M8Ct-0;55)7eZS={jhvJ3-0N$ zHDJ64g_ahrsHiAlH}xdHg?`kbTQ<1Bj8Ds_+hcFDnXA9lsu9c%KAsL~dI9i!NAZDm z?O_^6?SYcgQ*{W6u?66dMH&s{@gWm=M_yk4V9vc+(PL9Ca(k!!a1K3orff3&&qyxB z8w02WpFV&7oOkFxls1Ijfe#oO8JWSYf%qVU!5zUVEnd%mz9VV&fZC6`F0y$XA8vyZ77SX9D^0F$ zwWe~q9MAjf+wt^Vtsqr-j5L~YO12e1kJlLJMfjOUyHgqi*M{jK2; zw38X;lPwJ92rp=9MZ05_NdQKR6sP{-8j1MjymJ2}f!t`D&T z;b=5mJD(Q4e$j)j8MvUlFg{dqaS#-Q2tCi{;$g6<|HdD1cyl1KKVyS+qGnx zR`oCPP1Yvr8x8@1nQqQ?x|HNYDR7;MH1_9hXlxuhJv|+XALFPjY6V~!WNdl`XbFN9 z%e(+j-IldbztI;TKhqXX#w1pVxfT>2o|ci3@uKuiT(?aW2#=t$>NFg}LigjvKY&#L z?A`>Gq^O9yzP^6oZ5X8WPOuN*jXwQANUR=5yR!|4kZYpR=k1y0_^SH?csj1mPV*-rd_Y%vYTa^L0t<`|D z86FX|CMh|&KE2~THDewsl)^&;1B2_;g(*3l%DVx<6{|J-Do|As#{r%3*ZOp8fUeyA z-Mzg94zFG5{F!!?w{OenMlGUbM)S(*xv#z3S{ca(s%Sp*Gji9e6qeK#=H z?RrzfQ-s~N`cLoOk??2Cvjon8pz_PtR)@=#14Z^S6j%FS*aA)(>98^3{mRIb)ke&u}%3OwYphPPgzHvK~AoOZepc#qnvX=AMt3noM{?)g#g zM<$(m`+;YHpuQBdU1l#-j@}GKF<`0OZL*YU{A+xCtQdFbAx>nx_H zr&DqrYJxLcf=s96!0ka+PxXtr^Yw!ZE}dIiK4GO3 zQg9QL#BNYobASC(@X?uOGw-WE-d-KU0wn>=i$n|}iMJJime<=z{Xj4sqyY2ZUtiTzyfSB_rF{$y0+1sR4k$yfjEUPCcy4cF zlXq}%(32!i2unW)&*jn=TXJdU2B{nruIgVK;)_v*TXOX3QP8INY`ps19Exe_`D#14 zm@jReu;Pq4p&y{zVOt5tIpb7PsLMM``!m>b^%!br<^+yR|Kf!MenB;QdqP4%{rxzUd03$9WK=DC 
zf@Z?5dHiF6fB>$Kbyu;6_WMCI21G~A{65P2b_R*nPGg`8}81=K;aB7nH+0lApNE97$C{q964-_5fKJ-EyHi9w5L;^CvX-?S5oDHx#DL zDblWbsc1j*QxQ_715^R?xnGG+y+tCfn;8IsYF<>Q;s1&*2ok21Yq^j`$!&tjU$;6~ z@_>N5LTzGw#2;L6~9297~zO+vH%tq`WX1g1^3ZCp|ZT7U#(1~j(vs{gl1%%x2 zI384oS4%d7t~HEqgAAlc8bvENo>@0XL-LB&tp{vaKK^|kL_3lSf5-CZs-v1zIv$>0 zqsMOTa<8|}UA)NssZw8G-*4pFchgYj#YE{X6Yx=b`8W%l9Jdw}9{=Cml{t+7bvMCm zjJIz;LOOzwn;n<>?+I`9>x-LpAo4q+GCs5ig(uE!rQF?EuifeQ*AtUYiVL!24zFWD z-T;*rVRW!Oh_Z_$3W&_budhxaE4`7k=8&8hm#-0wv|vv9oOOBmat>UiulD%R8m%GN z3e*yjiVO$kP+tS4A!Ap7MmN&x28RJ%>G^Q^&iZm(`}?~8m<5;zIYhm>lIn3}%p1)EO+I}UGl*zVRGNh% zAV1)cr-i2h6-yW3b2r5IM(j@ZT_h}BU7?NILm%_$!-u^4lA#5zZbm*XK+N;crx&^@ zOKa3^5(>t%vTap;DwM?cQ+jLoeCj(9H8tFUwk4hBQh>O)QXP(;!z>B3y|Z<8R)b&o zJLbAc&F1V{cDSjjo5XmGKMd}JkjtUy>KiIk?B6JL>@nGm@xp<)L+V1=M@GJpMP$wK zeO@X9T&F(jVoC~EVntb52K;?}ds_{bu&3J18I~*mv7n^2we^`ZXP6inpFlzH26l|E zC)0Zpe6%x33}R!C5$C%DbTVO5*EtCwPjr9j({ZNFJ zo;^zdA*wf5i*LlPmCbc~#5oKC63XymHnxVD+Z_}M=N&nm?tSO>txf(_(3G@wbiyJd znV7O!FF*%<_=|71pcwP+TO+7m2p0vb50+SF1I{Aj5A8AFoDY!~x%YZraHTByYvNqB z+cpx&Gwrb~Xr$Q&d{%9u{4Oe#YVooMsbUzjs;?uvamU^L*CewC)*!k{{GOYfOa@SN ze{3TXa^M^|X>TZ#G70E);EsoY?mK}g`}U3^DZV+FF4ztSd3(H&p&iP07J#n|P{?OA zN6G1gX=e#Kw(|H=idQ~_g@+q$#Bhl+)T@Sm%}?=?TvYn&j3cVG*;QllveEr9FSk$i z7u-g=7}du83%Kc2<64SlhPO%6Hc4(lsz7liGt^e`kx>#?4&tGBh01=x2c9Yntd0uyR1n9w; zxTO-4&ZJ+7FWX-nij-GY4mk!*%+F^65Tp>ESkK-*k77IqngVov=TwBa zsk-_XK!KnbY_=f%m>j}LNM28oEWbR+1Kmz-qr!pfw?e4o2+o`tD9VJMy8sYnOhtG5 zg>>uFTW7HYWbaakU&VnD`Cd=FAL?5z7Y-z&bxV2ZnRA>5#xEexe1=Vu`%R#AUy_J) zc3(l$@b$TL#qrHWT*tZe$%zRPDymp(*%Z)7syEsNB@XX`<|Z2_RQ{ycpBVJ;VMHsQ zm>VcezbvQ4dJW{opkp#sbXLY)BeS7L5!>b){NM~_O#(U8|2PI!2#(#;)-fsAVB zz0G+}Q%aNyEwpTTi!A{Knc`ggPJ{vH2j#0L1s!s`)^m%oAE4Mgcp*ZlK{!is6{^i+O zW96}R=wTb~xf(LfK1THw{wS*deDTr<9CIw$Nrs)=pZI@R%} zMo^4cr}TQXxra!EPSzx%#ddVYT%|OM*aznoAyY04kTCdu*x6D!p zS0dyt*CM#++OZ|4u6?18$(N_o>MoqG}_TO%XHsOP7Qj>P;i5f$c&V`lAi$raAJ$l z5l}9ma@3+0VVI6}4EzK5tiM>Uy;d3`IbaCTYoW^usM_zWuz9%Rm1a)>rr#+)C>eWn zN-~eWec!h@RrNuU69#gw!hF=esT-aIbp70v3}#F$r5Oqf99Y^p%hg5 z$do$bMXhnKDuMmo?__SjcNV3>3k^ix=V-uhI*S?+dUm$9Oo{~Ia2JgH{OZBbC%p$) zCEB&yMfOVMh#SO1H*_KSzMGPfkyX-rU@h3S%M?tkN1;g>Er;%VXygJ|#ktSkAJWtu zXo&AHg}|~WOQ3o}7Bv~hLSs=D+y7l4969a%>h%}G7jA)m%OD@eS``)W40iuNu{1_yC*na*V;of0h%@ccQ z-`=hb#|zNS3!(V~?3OW=E1aGrvh*QvzBI_#Gt50!`RYY(+g8RUE zbKeN)6m$OSeS)qc%92VhqE~edIK1UZg(|jc-M8&eBvjWMQxr*Cm1B;o#{BPBssAD343rWi!Cwn(l<`qKzG`8f268qcDplRg}B_1LP18A$p3s`X+J9mf)ETzJp1acJ?0i@=pAgh*G^7O!qC=XYm*{0khm@S^R!z>+&9`9mf=*UFe2pgKCaSvwGRe|EX<^4 zLCO%@% zng*aMvjX1dfOB9Bq%}mn5tgD9efHw@1Q737n3$Tc7id9ODf)Ts*51KJ$HwwuG}Jk; z9|)#&I5}Ew?w;{+3+U5;r$o=%_Ys_Q+gfVg5YC0bJlZO)opyDH*%)MC2A~iFQ}$E~ zC;+_$hSK1pOfYAXK@_+Aihw$g+~_pwfd+-BXBDD>UDgJ4*##uL2f`QOg%}J5QB+_A zij{fPqVPv?o4$xSM$k3PSme!D z-*?~nqjWw8_~+F4_${El067PQ;v!&Cvp|>k4G1V^5>xN8sel0$3G5GG!Df(Hp|czZ zO-62zfFV6(Jsxrn2h$)e_PvE{EYht4T>uTt5}Iw@aD$do8fQ@aaIi;Gkhi zCE~6QrT#N5GVV20VyO&j7AeeRfTOX1OM}7gjnaT7AREK42m^uTCGdAaM;N=8srr;)iDiP;*MJd%Qf5E36B4!0oG6;D-xw}YZGo0K*`=_w>&8g4F{-%px#Um?RN|fY3Vn9>VtD0@7Q~;8nYM;?kgVwz#2L{bUk&E#``lBS1=+S?R~>#nmMqJF ze1?!qM6?hW7oQe<_yR}){yQPZR8$dLbRmL0(INEB(C9~UD9Fj|x>aAST?#-FHJ$kJ zare_-K}Tf##(_8GAu?g+!!q0!k_R&S2|7)4Z*T9;!ORV+*kx!as6qD!k4yx?)Dy_s z5(#l3Kvkgm=>WafL8^qOPoWRqoh0rDZ4d5_ID58z&yq`20!@%h0CmFz9-=3~?2{GF z+`##6Xn|cpuz||k?y_4B;IvR|YG5`h84@Nqx-yK>sXs;X*gb~YV+1!?6i43%QV zT6EOC%3G6Ce4?YHtE#J~S68hM#+c1w4Po|F_~1_+1-D5Muk|P*lre!aZE9+Yur6SN zz4`i*L=@Z^k8o#Cp+fZ_^*~StzbB@_fYQPOq#s;H0UGSk6HqNSQ$2I`Y!u;`;#96nbk+9kU`p%Y}?p2~eo84hT({ zZ7R|ZRQq>e`l{w>l}iW&_(Uzt%#@weAiO<1Id$zsQwcnO^5akN90=u`d*=(<7^R^W zl(tQvlq9bAGH=R!IQQ?~f-o2k_VB-?gGi|Dv4Vs;#^A^e%aA!E7{7%7=N^QFE?|8) 
zD(B(n!^dllG^z@p10YD}K`ag~kf0`_5-9p%dGkL;^u49#fOM$B68+tjgJCnVnpPwE zB1Xfod8Po)sDH2ipX2n?O`t47%V$?<%nI1(h1>zc#=FsN2qet3E`>(987RG();JA- z6PrzU0_VUbf&r!l(@FOfo2!B7)@D>LdbA_f7)TRA%LDaW#Z^}NpLu+5sjtAL2V_wS zpDaeJG+=6y9ls2XH4Sj_)EXTU9Zk(B7~f7-8e6d}9eJAg-4qRL8~}k^1XH<48v-K*v-~V~rP?&{)+#B=3wvnj^ z|8A4`gSaN6*MB-F>VL9HhM6a%d1?vaLrnkf$YCyqx-R&)HKhJ`_HKYd|G(MjGyLx^ zB04gdoPoU@{FUU4XF;b!w1|N=SS(H$3=gY+Ve7=D@EGRfFPuLQAEk)?-h!fVqDuJC zGS*LEcF8E=i>?PA(sOUi-(R}|?$ifVCT8R!NjF5QV~6lIoICaP-`kKfP+i?!sw{bF zgH9NJVN>QeJ7q7a`W>GDU2eC!fJ(T6hM6QMee) z7o#P8s&4OoB@^Ro<;M28uNbj*D~Z(KxxfJ?R)=Cn-YV6k{#SQGB`9%4%CuX-2B5Gj|8BWaLiS62z5icUes*Q)EUN6@3e7VmxHu}eWz8DZXKt25W!3&#XJF#Otb32mOc@(8(po4_ zfJ#_rHoInX{Ktped~cZ3igf8AKhf{^fA+D9z07&luw?ea+Y1}kCO?-a>Cn`1)O)VG zvu5O7-#cYbqfkzH&hev(_HG{Ko63@WkeNd7;I3dYk;`e;(&6k`BCtx?U$=my-6Sk~ zrD#v_;N8T+0!)oofYl%eyfxcVQ!`n$YPW){Qg$Nf%y8Mf7+I4)K% z4D+jOU!qJ8sI`6yr}Fg+;OhGrP@c4Qhepxo{fL9K$rs+7k*S=SQz-A_h{7C#eFxYs z6(6JVK-(gQKEO*!CmP?CEBt;waEx~$?b4g#g;|~k0bt8fdrON8T`4}U#q-#}?LVaiKi0CQ5v*WlT?1#=-xo<-W!kW=C_(bC;%GFCfCF(kK7tI-Mv}ku%?6 z{wkla;IwI5l!#ioe4MRrG7D3SE2=Ef!K&*HP9JV=VOV(kEqdRHMgg88?&b4>tUVRg z#gbeH+?UI)v%T9E;gmd1Zv`9|8POGPOFDZUI8nwm&7^~@-LN(z19peW3HbyPTl2(juTD)9xyoIV*MYfKs12WQS+V!{m!#YX7p z0*X23tH3KMG`hF9?B!&>7{Mq0d~rAte}8&1_ta7*F)HJWSuci3O-)UQ-}&21)`_Nj zh6V=w`KnFC(h|maV91%FH<_1Q#BrZ!khAY+w(Uhm#@an z+GA_C<^MK)o85@3V{hPME#G@vj8?I(NR7xkej_YJd&q*k>-HE9@&%posz~)%LXCi( zZLMv*hXABH`QnA%^^tgD%6lqx5*Nd05k;Ssj6_>Lp7rPnEHWMXBkoFARB%?sLq}RP zq1VA}+?LZpd3jMZvEZ*Z)Pr}1FdC?}D!AFlvdYXury7x@zd7%tXZ9=TM7tG$55|*Q z^Ni-O?E|U&p-B;OF;G`eEOU`p^{%xr z|3dzw2x)m~HON4Nh5Fc2nn)sT8THsp!)x`V=`s)$0lp1Y+oC28I>S%HlYT#^mSnw}~hJnA+%<~osBYu>_z&D5{6 z@(uKhfIfpiZd5J4r_ z%$tjo6It!Kw5>NiiO=~)`W89%p0C;IQ}&6uykOVqgOQMI5ley$#JjHxYJ4_MqD4c~ zAot{uIi91|^q@)7!G1fjqI_Q=SFiA!qV*jVN^bV2$6l&m#OBT^`7FuonuQ*ImE^hx zvoXx6U_0VT!}v=m?~TK?AqhwGQ7xA6!>sW)D^JIqPT_xz1kQgi_X|#uUF6M%FYZ0> zlcLqt7-*qU^t1+MWGzqyE5)7)k|aiOayU4n=_IhV48a)_L3$+IoG z_yLEEr$#A$&(F$C^a;(8xo%{QGbt5RaUaNR-AI_D3neu3P0KyCkKgd-4&}HMl069) zsun8p+CFjS46eK3ur|BG%Mm%%@koL47)CpT>T)#CR-zle{|nl;<}WsM)c&aCV4*$s z+TlT%=g3}27xe7DPes^OZ*|qT;v20gi>GkbMXQ^as(JN&)2fAZ6py`Ln=P-c$+yHj z@AGXFHCz4^mZib+V4K+M39cL43bWwBo6}WTp2d*u&3%k2g=YtqE54ksRD!$h zjEG0WBG>fHLfV*fZr3MKzQvimZ>~pcWDxFlY^@?xu&>SD-dS1cMX;0AP-?1qgh|~| z@jC1V*$6aVo)s1r4vUFNs;=BJ?=6Yv`90$~nC}>oCLfziMT+=I73n zY){=gOf23<+e7x74f|NL?pKZrTnVMrlw9cM<`ToNa9p)P)r{vt?!O7CgzFbi3 z&CyYY7}P9#_L)EMiK^{C>|L76h0MdVmf5&@ zAR6C^Kc=YsW9uCrD+#fUAsrMp!*xp6F4 zLYC5OdR?`FvzlO2Y3PzAEjqEfvbhl_nI3pmj*JYFCcpL@iEq3Nf%w-8%ri>xj97i7{h`-AMG)Uo92?}6>ws@jrPGxzI{)PjG^gg&zRB;PqY63;8;niV--eYI%RY3k-gBfwM{v|r z1oD`5{VMa=Qc3YVcC4;Bnxc%hODkfN$|xJ$jgx!#kJf5*PsXm%ymVY}=z zg-ff6ylgyMW- zMSIF6nyPCzgD^`AD;L+gN6@1{biA;)Ve=Sc36mj+?_=Y!auP8`E>@#CXLzx<&`ixt zAhbQ*m$St06kf%5(}*^~4*o^l!Sh1x$+Q3iPQ!j}@@j{*VA%ZKTV^T-uX$jFCv2=o zqCxfWEWWqZx+cL|jkRY+v}9pv{z>#S0tysr9>*Px#mLVlx3@yn6fgF)+Yd;{Qp^Zg zm61>-FV5rW&_vPPJqnf0#~v-s-4y-X`iEtE4EuK{qW7Hk zALtA^;+n6f8x(66Z*=HX*c>Ij7E{(W_%0=q)p%RVT>bcRh5LK~M!P2I7R^PqM9D`a z8X~t3Yj2@Y1_~0|-MLS*DO9YJ6vZ0nTEbo;LKiNcmq}~OB6D=uPL=@m)#Zx}y~oTPRcdE|-AxO*|8&aGy}3sY+$gSE5Xuf+({1dKz% z1GsHBNvH&D)cLGvnXSIdPzl?~;3-5xj>i6a@!kB20_9mOnJ7qtj3kq1O~^ zyd>usGfGpFOzi5Ju@^?kd?Z!1PB5)Jt4p={_#GqNFt2FOACC8PZ{lc}xAL>rf39RU z&7G5t6?}-n91NUB{ieTznFm~$oaX=Q_x5|UK{76bVs#bnoQwt;N+u5MQf7NI3>~9p zpA`lNp>#|mQ`Tf-D<1NVpX53iCYZ5{=QP~7{&<>+Y;;+#Hm)Uidj*2&_l3cu+oD z$EPo|efMc8j?`;9i=x+pUOHG@+>flUs62#3vrM7Flj9F@2<+D+r#k6vua)=gRn5V= zv?Jx8lB%j|>N$0>ePUdnY%R-5W%GO}`-Vp85(Ar7o)XB_A*FnT_eNS_P7k)8?Q?Gj zkaJ0t+fCcq+(zA%@F^8!yc=IjQ9Jnc*6d_doXFO)NP=m6hYD;DP)(B|Po6YzWz#HD 
zfS}J>hWw%evw95Z6i~3_G@Cr*1UlOh#x_h%F=-M?~^L0 zo>X?dDf?gF7*qfE`*VX7b(XmbSa9%*? zJ#S<@Bmgc}MpW#b^J;<{m9AD{!6X~W_3ioWzAavUaII0!CzcrHos+aCP1L(_h`A&< z#@9VN*HtflNstDGdLr6$vf!+5H!3DB(3)Lbn0!8OLzY4~jPMEyMf~vFqceDuj}w>f zt0^f3pWWJ3k6_nzYT5<5x|eh2p1SH%_8l{$DLHaKk=whWO35;lpXHj^=Qmip z?M-Vs-PP;Mx}I{x=lnS?x9bV1)YGTg5|$22O}6N+0&y0VtY?pXCj*ZRVsta`NmjJO z3M%;*d;@E0PBZeBE)a% z;)rR{75y9T<3GAU z?=)hvqq8TlKgR5C4mI(m1WDYH^ z4sZ0YP_INC9!3PG>}Jo*wOK9GESQ6f5t9{r|MZ*S_P&_;yl{4!>^B2COZ(2#bh$V= z+ixy4XpiZiVamZhkoM~JKH@mMyk5XNtYY5Ozfg=W)%~2EjWiFLRRcOlNA{QYEqS9g zi%em{hN88|FMe<>-e%C94S z$x9qAvzvX*A?Z|Pf1yyr5$+FNmIFtIbsRzcC)s~vU7j{(;kGwL-uM;*uS@;EKqhFbhx$)iVS{%XSyw-(Nj@MMy6<0)N|@>-3AoveS4GR=lhe%K7lQu z1H5uMIsLg!ES%QQqQvV$SK}vYzT3A!3%}BHl9L*Io-_46S8eZ6HG#vO@%Ih>*Ex&r+nl@_|;z(m;_jPDF>_tMkJg_=+`8VBn9D50^2G11vD0SvX5q8Ff@zlrYlb{MB4|4-;(}8~1dl4P@8Bf>4;m!d zZcZZRRW=(g_*_xDhDq{M`}N+jc5Z{|&h$#05RGbWpqqK~L8@zoHe0b*<{M<4&{BJD z{Wc-jho^srfYq}Us z@5>4D0P>@EnkAO((n`T8Q@fejJvljYmX-6}50m}fK~eqsFyoAY5Jm!2B)5;5N)Lwu zkBz))7JEua_Wjw%tti^mQ+q#OCsT_y@VtS|h z77R4qJ&EM6xc~m9vE*59Z#T}v>KiNLrzEE%*_KtUV(SCmQU<47m6Tuj8u#XG`mS|* ztOx^|GkW&`To8T?bV<@r40pGwS{loKv9t-^?}kpQrN(h(NzU=e93%0M-Z0TlNQEqE zQtwpioL>=4!^RSrn<1w+B39ltP-^A#7uqK(Sp8p7j zH=aUkJK)m>c4v2YK$HY7`C!k+-kwVUbknjkkh!;0U#?9^NbcZZ4-N~{{YrS-)3<-$ zuV%wS&kZ!;?;rjf|73aik9QrU_y5bUAB2a0`C@_;iUL~jQsgMs0LbmAhJ^p}cC~;0 zg8bFTo(E_f97SNv8|lb%lyzz8zKH|L0b%;5cwV@-n}o1AsX!xln9tzBQ@`H1bMwB) z*-$~}U&2iK@6o!b^?{4Y$ti--Z6LH6IQaXtKqGSQ99;GX5PdW{T*hTAh}0l$A1M^@ z(DZ*`kp1r=vc{K;;p{qDdR~V}50z7(MGU$siSTw9O2nq`VArb*l*h0(BQ61{3TUs# z8@0#${r-W(#6*ckv@?Vd3`9sYKH-C9c2HuduX$>!wMT2Tb`qS*wEWb)S6y8)NSQ)Z`j<3!@(Oh$6>=g{sm71dd9P zZUdy30FkaD0--4_5UK@Kklv*CmQbXGUQ|SSCq%kR4}=~NLb*G7zB~8-XYS0+jAuBM zl=pp}{p`K=T5H3~jE?~34bUP4ey;*9C0H|(9=|?t4U`dc^?RG%mmkQa(BATQ9WZ5b z=oAA+4A2YEHzCLQ-Udo=?q-wswIl81ccl*Me=ggN$fl-yO3g9}c4$XKaNN3c9d7aT zXEAb{y~BR5gZ*0e}0FMpkO@W!G7x0_>yY5+<&j00Iv?(jP!w{{-a=y^IZ?k z;>>j7#6=K^9nk*a1<4ET(>MV)c4-UKDn-i6pK>>A;7zG2Z#S&EdyQ+WO;*0f1ab6{QEglQ&Zms`MUql z^tF{v)+9l|H6D>A`BM)~EB^WDy_FO-?wX{f1&`p$5H%_s~19_wS1e~Ax;NQh1+t+dVzbk~IQBQ8(r2G9+G9%mnR>6Wf#QSOI z>7QUhH=bu&X96??aY=<(aOH^9Dq3lGy*|r?9z7L_kUm!^#VmIFzA^g z5MB&wF!=7P3p9Pr?_Kh91A|^1 z4*3%hn>2Iva4UcPA`uh_fk9zERDW%`FR=WVXgB=?eKuF~7}4;d?@{J>9y1B_k0P#S zC??|nVtfG+Z`+-HA%9qW`|yHYgqeNAi#N38Onba`yq~iR6my!4UdXg|J{xiA00ziV zI@PYv(#-z+4xMWgt6qF9IG(_7bT{tZ#b*vydZJ^gQQp_pgCj5J`z$H*mAX$aX}cC` zQWm}`;1lyN@NIg2{jP-&nQdKk`q_7W^2i!8;4BG?ZtRDZ!ALcSXF5EyJ*pWPzhoXu z-Ey(R)fx|GRF*wOqkB0&-NH~Bruz!p1J9fUWNfpRS z%XaiBCa7?EI=0n0Yv1QDS_Ep%E~O5ZngwIUz0`JVE7rDJg>S4uyzqIQrTX?~tjse@ zELm2fh8#RpW|PitUmC_`T$PyYTxX;8+Jr+TOcHOKPp;aS^(~bVah^Dg`lYIyAc4N6 zHc083lQ~SnJ$uQvdOPHI6sDu|Yl^)7nx3gGX{t7^VU|vggQ+P=F-+%zH#@^ZuU0@* zL2os;0PX#W*$<`6G^)%No^0S!%BTz#vg3$%dvH|9a@f-mKi3Ldod_!C*JfXJDF*k0 zSx#WkK3;RNIGUFmC^MPa4u){Kl~)0xr5r}m4b}}P+p?^mvh}~rmMP~#Cit}viFn;N zeTmqw-w1AkK$+u9GyAAS^JPHw-2>mV>yq?Emo+rd`nk;;426@NV10lEdyK3p2GUGh z#nIR@*GQ6HfaneLep<-%xgrE6L_q%z2M!f*$KC}U;57r9G%(v>w?WcyEE;OFS;sT8 z8*JWIu6uT2>9M+r+C-vZ_U#(Iq5&INMx>EKJV|d_jpzGi;2cGP++qLz{qKVI^O^x( z1JrmXL4$~&BCH)KJgPoQPQ-KllCx#6PFRFE+UH3iTl(2S7vxa>4!06PAd={lXn*d! 
zs0!^=yLkrf*!?pr5u*&B?B5#<@$C>|`qC5`)66b}Ghf4C{Q?73@H^l7GaRl~0ZE}` zf&Xl=zV^y+=squ5X7Thgeud~ruJ-t<`<%QUIijmPd__+0OYQbl0u9iJ2RIaTP zSh%c2vzCP9#3seeb|B`9U%Ph`=M8?sy1v=r*@{73E~Waq_xYW?-rP!+ zq3M)I3dz12~MPW3H)r2si5TA!w##Le-4TVZ}iki zxI1&1Ii=Kzwf;!~@Z(97CFRg-HT%x6nhe0@Xi1DoQrp%p5cAZV zto?13&IG7tI=zKe9)DL%BiMaZ{MW=De`L&(FMxo3LUt$pYQElbYNjBXSvJjt&5(r++O9WZ{JYIXcUBoG%h`M= zqOFgEt2wckDNQ6Bc{nZ2T>rnLZ8(#owtCJh$!JtrvT`=D5lF1MN$R$JS%|nbe?Qf> zKZ%*o-uLj0s?p;1ngu~SlulVJGa+w1IW$E5eLpQb0L+AT&puZCxpTDd{7=SxGwKWn zhyDMVTGM{|f6K!DBxfHl$|h|`Zr49zbE%b~GFp#@Mg9k!c+U@W5|jVlS)Ueq9WTts zrvxi?HpKA<$DdF0SlRK65X2(fyCAqk*L(sE(KEfv*WU$&PL?zOzW6u$Pb2o22$;$j zE@YWxQ?b32*5zdsRr?NQwof7Kv@x!Nq?4CTiyQrNXmAU_-kEzs>Z;7);!8C^)M(aK zVyXvPyi|j~Uv5v5yvG&&*W;%@o?sGYO_gsY&eDFGX&@i?D>)7Gnacp`ATNmykCy}dKc%_s$Q8Ga$)boeb4J&Z?HOVt}PLsu>AqE9{dgbgTPo3 z6*dEL!aCj&Zhi3DZdc(T9>ktf*t#Hw8*PQ z`umwot$TZ742pg#GHluO$OIc_+>Y?8B`STo<>u_th@$U3e)mtV@en@1+_E+hSH7*B zSn2qYVeEeV+Xga6Q$6Fezj9b2KT^S5NvHejpLzkE4l68A!OT$?kv8>`#o;hfx;e^)yIF~Wf9V7v|}%R?#TMI?@^|E zg>^S~t0i52mfR_#CBtE)+D|f*p7n56LnL0GKDaX7T^K$re&hEMLq~VaENnKPDtN}) ze;O=OR@Rov&D`VKql7BpWL>SmFU6k;pp;S%1;{6&V+LDM#iP$SSYQ4w?&z}Mj{Z2^ z$;d8|zkkzFC9KE|Go$b*%yTT%jMp5$!gC5LVboqq8mKwerw=o7Af9`;ag9{l-&5`0 zUFtq{y#00Gf2pK_wfAUh2F{B{WR2ew>`kg#$9}r5%dON9aJpwjR1N@g;F(!2Wj8yZ z(Zzd1^8?b(-XwM9%PT!5`#t2xT|%2VXfWCC$D8vAer6&AM0Qo37V z$g7bQ$)&b~d9I#R^I)%RpxBKx?5iT_v{wGJz~nlYx88qutryijb6#z0Me5)yq~69Z zrwFF2iXV|$&Xt2~>@PB~r4PiWWyiMb!d>6?bmW^|WJd_gni|4B#q~t8%r-3}6lx*+ z`c4;Wu6!Bj*(qAiIvKYNMX`ozWJ4a| z`-DeHnwq0GW%M!%cQ+0M2uu0nx@nP*?lFUX2`X3IK(tiORF)c`W2?SfDBa+=$T}cu zJ^WoHc;nu^eR(nBd<~`3C}g2rzEMiN`cB2qTbCeKy^QcEc@q|7S0uMMl##IIVild* zPBit1U*QxYn?*pe@HP1@LU!ZV7Wc~ zbHC_{{-Wh4H_rZ7uSN6=zeoLsoo{)pEf^{&e)Rs6a5w#gGK)*8YMp}+zer?(jJYE5 z>7B60*J6&T@2B%IGp_cCr26fB?}@~%wAwn+PyYe50c+jZO&hG)%kI9aT{vJ97GFg- zKRavX^p~^`bXWkizUfDk^aezaQ$`r&&a-jf`C0ovG0_q%hG^2A?yEf|;CD%L&HX+mI|~D!gs6-8tv(c6V>pCG^6Qno28xrSm7w)hgHGI%lbncwxQJ)eZxD?d)@#Yj==lcv^fIWoX-+)>-DUo7-h=>lajO~jz{^Su1xqW4=|Y0u18X$% zIe`B5H+Ua4$)nB z`Do%4dnt>!-8s0uK3AIgvM2zwod8)rO*{lTim)(uHAQ*(^YQdezi8+d=raINoaKV@ z7%Z1XBo93`h1>VLBjoei2XK4U#@1F9(By(!UV>+>BS!52qaC2;Ewkf-b#$g_Y2Wv; z$vJ34xc{?U&H-5Gqe=rHXsloxvAFG>9q<#fbK<&yxV-6FjwN~z&YW|-c}h{$tnd-- z=J}a*Skc93g9R1%1;+G1r%Ck+R0*_2+Xz4Y=zSMLM~of-22oe^Q=vpX&jvAi`zR;{ zyFpokuuYOe=Y#mp+B*?yc;MxChj#zIpun)b&g(t)C_u75950&0t+acJs^3xxKgTPY}gV+TR5&7i?0gbH$yq*Chdc@%#J85iL#4 z5SlC`yxFOdX1wkA0FK-5Z8QtU^73+>A*ix*pm50iwhn%}J)!MN5FDP;Oxi=ZCIvL_ zKSA!=NZUjU#}i<)oCb9kJkrL2Ir@fQ7n{C$2%)^YG|<}{PJ@;|($t=?=+MT$h8|zL z`gyQf(|rF^4e`AXMBDyb12ql+8mv;C$+z=N#SHaugqh29bmzbv|Lg9d)&9|5@ucGK zR{@aH1uqGj=mJzFRi$Re{L`P~;mIM@rb*ZM^>U%Kh=0O(IdcuZR~GR0ZD_j4BHR7q z&k|bkE+0~MUe8VX*)HJ7ev~?_>w-4W5im_TK5d*Ae^jOhxRJ^N!LZNM>8b|wg#2H} z!=Bha3wn7gsHvBhz6Wwc)R;8cI1J{6s!6ktFRN#TGrBYOAH|!fI-0ig&Y`}1obYXD zuItm38CxVZS>DY2u=RS=UuX(kB<6C~8A#O3L9x6`O2926&Xi~HYgebIal(#pIs+C} z;J8gKVOO1n0J1P;$5{X%7yx?=%6?KQH^>hT9%brF20yR~Rx zFk2@lEIp#$k+#|ZK!eV!HdCgp7z<}X{OJa(!&ROJ!4<`|2+E8kMq0$ccxVH9^v|$} zWqqm_K)iu+_IhM`3g!Mi3DNwW)Z+_skMBBNxKj>)#Wo~1J9$VWNBg`t-p)vA4k@Sw z*Cs&@J-BP%ktmx&H0;w<<>L^|X3nSJ&af~+OLKT?on5`U5a4t4sq^RgU?Wo<6-*?= z{d*5>t!|KM{Sf@zHwC#xgRB;Undo}&5dggKWmn_cVYLJMViZnmkYm>l_Tm&pHC2UiA(7(lRl>Z%eRBzNUE2ir)B3eVht|F7 zvzuTEivhApi%bJHQO%FFVPyvf!CQ(qg0IZ4#b~c3g95Ge`P-+wr?YyyX-#cu(-<97 z>IUn*;vN4*BwxwXi=viWxnMD?GhJc8e@aVIo&}_bWtDaL`#;@kF};w4l0zcgUC7QK zar@#Clh<}J*V+Uch-gza)nam~AdA!Puoa%%r8rRJo@066ieTGyO}DC&!BnsLTKCHG zQOzBNJg2R`tHu=`%Ma^L)2Uzjfq6me#;UQ1jUlPgiMHmZD*JuR?kXq@ju~A;aBUU4tN0u6crL#@J~Smi}sv_`Zy}>3fu*L5jGEmth^p;AKXQ;M3I6x%KBR 
zi{H|O_r%rg7VKiz$y`iF1yGvWbQ>o@3I&n+GDu>p3%?woW#;;K;9Ub);sid9aQ0A{ z&pefdn-FRt)+sXft+FCx$F@{I?^+zf#Y8+tEoU&M_uHYzx+(IXLE{_qC=RWjBofdo zX2@J;mf3eeLG{U@P&)mX0UNzF#FhL*H>5IGVdL*9fcNKVu)}y8itKpzv*fwMbbFp$ zshvh4xpS;+EJf7ETi^oP%eacUAYi(aKf0if%W68w@R<6+j{{sr8k)1;@3kBKt9zn5 zcy2hxkwmCS^3vbKZ6DZPQ{P)44N(r%NF`UH2)K(%0E3fyCp`t@tM0K6NE1@q15jf* zd7lctxPs=l7v$s6E7eU)dwh>(k4454!L6CR9$5 zTQpGfZXm)=iQLUO&kx47h?-2?u;=X=}hVj|Nx z5OCE^7oS-)swe0M2lE1Ad5-Pi`EUPqlhVwF2_PY)NejVn`RmMOv~Goixo)=RrNx_lg-t0=+Ylz7bS__#RS2w8I zZfYCF9{T(kcd3Dl?YD&WN69jB=ght5VN<0V4>&N1c#aA??6@z-5n9h&u8J+LGjiK4{)zV5g*f^^!C>;oFL<#~Jr+M`oga{ajU0 zJaFwG9FxHc9es($-q}w!U!a!fPto)LdQ-W^{scfeb^Z%oroi)R3}P*|AHsW01n#pwqQ&)`g zGs$D!UlnibNLFXRG*RmBLoc~T3Z8W&itYHRYTtEoh*LQ$Y}t*z-bbeRG0RzJC`Slo zU0g)>y)}Eyv@y~N$8?`Rw<7^(=BP%sd+9H0&~iffz|Au~3mJZ-b%j=)q18iMbC}Nz zN{U6Rf# zB|patAj|C3xsdMTFWbT*ZxPA+(bSiIN`vBzORlcag+2`QW{?_pBYGxIK%`#|-*Q8& z96lB2J(hzJ72pF51FEqZUko8qCN69DTSeoW=!lG8o6}mm`!D9y+D(qLBoDhCy^J2+ zPfJfF*Un|XxF}x&X>3?}g14QX4Zd=Pu}uItOS-t;dx--@){6^r?_8DWu2mRGIHtRe zELWTFoBFRU1b~`W&WBXy;jA=&zcYQ%7~+05vTqmH8ZC6W^8GX$9J>_8RfKK9=G|M) zo_WvL7e`>hjgEL@&E_yzJ)tCQWNTKLNq*jlR4Fs1r|1eson^k}Q_aNqoMWd;peZWb zBC(F-bGNAt+)~oOKY|Z@8R3BNhrA!kf$lnr%WJOZ>CCMDyLTI|8C1-8tVk}um16F^kOIjUC_jj1)+cj z{)j z11LSgbA{>DxQG2HG-aV$P|e(v?X=eX)I~*sb0N9jci&^%S})lxr=S0(nSH-l={j39 zx;b7@?F&mm~MwqIMkP+ zv9gG&3fd|0)-WR!=%{M1RQGWli%T7Ld-3iaJzZFs_2J&>D?N)+yFE+0BV2In+56f< zk?2$JdF!O5!`u7;=w6kz=gN2hfP z9c-N!R^P&?7E#naVK#|!q=csjeragdpxmX-!F9T4Mq>-LTpd&CMaJF^cr_cCe_+00 z0&37Lrw3?A{buRom!OhJ> z$~ig8e8VC0NhfR8DMd!5Fi8qC)?b$-U9!kVcfr=#usVVCvMsFPsC6%+>0r_{AEjjH z**3wNf79|rd&2fx4^Ic>uu)fw^iU5f^E0`qTaJ8QSY?qsLtKnCo<|;S8o?( zci1MWQC8j4Esv-r?FH<(K|e8MbCF&32(h5KQT*oO@pEa{LYt2Ucy5mxR=WPEt~IVj z>~mYOYA-rS0^O5Bqo6LF$7~yq=pG-rqo6_fPzab9$55q?(9}eoE`tF&GC;zOhhR-S zb^7~s+Os#z?((er!!BoEZG&NK*`rd}$&zJrnJ+FL#5>jB29fJ6vY}H5TFAbp z|LP=b4@ceJTcb81gY~&bAOy7@e083*j?vgzYE3;&hH4!b%~GOA?I4RR;|e9w>(!$k z6=|cMRl%|e|ElKO<@?v2EHDp1&dhONe+&(Etd?OQ^EHL~ug6`!ZJg{_>=%DQt7yQM zr{qaHGYo?PvnK=aF4TGO;6Y_#Zv#P>rGGcsKXPuS-Is$j9}|w*+{~H~UpK?MR}~Bo z%aoAI7*Jigd1rD5aPhKUvA;GI2iI2l1C#LDZmUdvJ-((m#2hAer=FuFq0EQJL+qxZ z_)fO)6CvdBMWe0kjX4sL%;YKN_R!x^`2BU=IX$7CR|R@}a7+!Zv>G_Ok=R1{k5SE(+oYizHyU3N!TA8&IsxhLDXU(Y5nM6TQjHV<2LaX%F(`8orf$n$He*6 z5BTg;@a#iq?0TBIglD~0^q-a!jORg%gIibgs?Su*tnC<@dlyS1T#R)x=XKRdofUzm z#VHG~EGtvtZ20>*Sw5J-o4?+EonbAD0IRnW}g!*9!OVIIgQ9bHOq*hcPmutw#&8Z z=GiYU){RX#k$bg5W%Acw~dq3&t*!xISOcu4l_$)LQdwJYB-M;TXZxN zI~50NifE!=+ffhsyfMp2QuL!y@jZmlm=S8VAC%&%^eu^Mwxbw%8tR0{5xKa%>HDVg zkQ1cXZO(ihQH9l7{iUv2i;ZiZTTW_8vO9@nnC+DHY=-e#5A@5^Z~_P=7Pf2!65R!6Q<+TqzeJv?P{PfWmY zy{$$z`g+QZ1i$Y?Q|Jzs2*oV*_#!#s(_5YKVu#n+BDmz4a3Lh1rL#o_-~xt+MoCNF*Bq!-k_+_3(=>jm^dCB4>O zc^cLxdw0xEEnw%YPJdzJ)<~3YiMc+_w!zEAH3ODS^z`)F`9`PL*`r_*35i6I5(Yu; zP0!W6DDXZ5SeQ1d1dOP%41=~0-a%~dwTBn@bi_;UcJA)?o%fwRI^6-m+$d~IdDS-1(Wguh^-8YfY-W`*6afNou zuyz}gw(0<}o|*Zvvp3PP*EvQH$)!^Boj`~s=ZAKLZqH9GWLk}>YZrbOAvSGIPSZxpw;GEUF-txN?mwQI@2wxdAYx5Z;u;$={!ql)2_`(l6Z(t|EK|pEjZnw&xsBqog2y2f@{f_44<`77kjH@Mf&WTx$ zv~W%;BXY%5-sfAxbq+LE*q@qz&p7Ad2>+|G!|)o zdsYio`tR4**1)+i0-$bk(tDx%A+MWv9 zJc1PswD_Ir)yIDU64?!eX>O?Kle~BJVU#Qq@@1$L{L8fttSu-e&Aycso?VH=VYs;N zg)6QhBI&gh6sbfzj(sp-Ck;juxw)B-vfo@}j7x=rO%Hd7$w_p#XyxoNg8NI}?BQ@g znh()gwWt!=9t)kHl>OX&MBpD8Ga40Wt z&-^gMl-ykw89GDH5K2H1Orlog113HG+Yd7? 
z^%WVTMRWCTTSp7)R|(wfFs^kqit9Pyiq2$IHND@k=P!1jHSzK>O}1h0 zdnvwm@R$K+T4obzt)soj=-yJ_d*_bqxqK*@`J(NGJU0l^c{`rI4ijxDKRAN=s_oCr z4C3dYQvFc5w(vgnJ*)V`OrGEv?Decp{@lmUD9l|5IF} z;$mDGAE*FhOFVAbJ9)Ne+mmHRDt#7USP*TjY!eLff^L}}xM$wJy-<65t9~lH)lf-M zd<3ty6C1fWpb+M8_?OjZ-o*^0Nd3t6!2h>sgPcEp>cFAJc-g@Z)CErJLRKIbY<9h* z=cegM5xVj;!?Qzbtvg0>>22tdTT>3jqfS)MFYWGc+mdw#cq3XHP19aew5rO+!eK7% z#MPIPBdtTV9S2$KZJg*V*5*3nN?ax}V5=*r|3Sx_Je&qtZq`7l>)S{CNZH#5sK>jZ zV2XSybIVNe9dg`2n6!nTqZek%>;b+_a_t=_XejL99fL7wjD{j3o60j*5ZN3e#98)r zU9EKBHUVqGElQSJOO#C9{?&9=1ua0mDxH3W@dPGfS;S8|PkcWS$RVoa%+eRyC^Mz~ zSNcS13AARQnQ4NV%h()$o}e@wPFQR{B&rY^MG!FzU$;cF%c`a)}=yL^&;c z^-T7g|KS4AvC}kKd4`*5w~xM_@wTodpaq^-x3A!tsa7lTT?_1XTejnRlfoeyIB&yD zquL4)moKh6)eC1mEB;kyamgbDeFD={zWb}q^2WW9no1kHP^mTg6t&v$2 zv+PFu!>ZE7D6Wdi3VyroV|W{)k&2f=^22f2{-$6#B2eZKY^?*stj1==12oHZ+FUf+ z#5Lf4=R;QY8+uawVryqIXTRg~5Tu;6kr=dbtd=kaKoW*Ked62K=!izwp*iWAe%iS_ zuG*1(1V37*z5#*(^ZDiHv;^YBscCDtSGslDTC{@`IxSsUr>2I8Q7_E893>e#Mf0-; zu1}H3$f~AB@o2#RkPo>zYRsq)%??|Cta_|v~=-Q4(Dbb0RCN?Vk)(}Q2-!Gwaw z*VxU+#%_~M`51^c#M#w6pm*MSj(!#I83S(-%C$#2|pZ+N;pgD&X#4TDx1z z%(fBZS1qT?DYt91yUz=&RTewCu>3^Qn>XY zSiPzgW-)7Y<9eRST#EHF@^?m5=b?4*Ha%`Jd6`}yrkYMa?zkr{U#J-VC*Q-| z^IsbjcHC(}M)l6OnMA^7@+1dVkUw_!`m*Md)5AfK~--R8e%LOs(SJV?FSx2KG*T3owM%!2MOWWZZBs zQS2d%H-a7k@s1QDd%o0&;<+DQ3c3w?IJ7Be0lT1EgxpK|^=G1`R`9Lt9%>d{m|-z6Hrd&^E+f z#-j2vA~f>`H$2ZuF@;ZVD5B1t;`(CoS(;pgXN0ofZrm5&dM)33MeeZlMs8+XLd2u0(H+eF5ypVa1p_K|8ilr& zu5VnD22jqq&ad@QBPqK68U7rPr2=+Ew5j4agCB2?2k4j6xB5tHf@ysJkVJ}vn7&z>q6f%u4mcX>lt+s8&9g5zNtwDm>cIUYON$878HGCPD0jq;ww5a`2`f8pO}IK%HO?nTx`U+kB<% zlQ9Sh$H^{`TI7A6HB$^~&C+~6A-b{Sf%op**J=jC9b^B+gptItW&|R|ni%;$BK>1B z-N8c#X95=ra`Y@^^6qcPnb`Dy?CH23oF!f44(h>|^g?^plfM9v)GgB@J8d4DpX~@4 z0apz$h)2SnBG+ELv_STXfaU?7OkCNbR88KIWjWB zbk?G7z}vO*AN8@x4s$WO^g5~N0Xm;Pf9i%G&%4C*y}3TU*Np7l!mD1oS?u=1#XACQ z2B@BI_PRp3xypxrS(96wN&P9&nH^MxjMl$qCXsPDb9vb>W%RXWW2AM%#G>9#uA4cEkZ`7`(C)_{=dvcu4mEzEowKL`7zg8e32UM*MH z5J zmHl~579nx{Khj@wa^jwm5(TMYtQ!_)xLi>JC z5RYIIC33&Y8I@?s4K``4TitGykY@%jxEs^l*kPa|ZJeO7_^n;*ab^nX+kO0P&(rMR zQ?^cYW}9}FN`64U&S;NpJX? 
zYXvV+>4xd5$_fM9oAab8?7sR`kq7s6{K<$x9k=>?(T&{B#FhxYG#DH>*^yvKvo6ln zFFxkE!#W^y+C=tV05Tte`&t%lpkK7Y==Ru=7^dr;tya9H*Lc)*QdiSw&y(ux;-dGs zVeRNQ3t^+o5$f^ThC$ulyiP>|xL6-^bfyYBW|8V^?Xz9#^|qFJn^*Y1;}E>Tq}P0V zy?dg#^tWRL+E@2_qSkj_t97kjjI`ZqPUOs*oMx?Xu%7I3tHewKGkoF$Yp{46{_&@c zmKcFxf@AL>yN}uLD;UXeCDP8>ZmtUq&P%|%wSEh@a zsfO+lxN5b?nI?mjki$_I+Q!`#>KXWUl=gjT z6$R|$FDeG*Hh~^5t>+pKst9`W@^F*d(SDaL-YTzn<@?DpM@kbgOHIqnznX8f|C&Wk z3NZm$B~HQP3S)#;t)@w>(=E}CeB%4G~^T*AZu9*||a5&O=8z3{>M>6-pj zf2?d0GV*{URw>YeGN9-Qo$G@taeW=tm7)mLC}Zha|9jeVev`0S)GXB5+_TqRonH02 z^rWgGYQ(eAl@xJnl8te#lf=l@f%(ZV-FKLoDVfX)XOONwuTj2R*(yVl2E~t3bgK*x z+*U3VvJEeADHpvj`tC|T5WM4vm7*hP%f2|+XLL)y`tsOQu1|MZul7&y^+p#M5w_Zg z`YJq+HjIcEW@?*OVAQ;~H-kxf!697u1=b83{o7nhF%i?&-hWcPJ(= zombZMZA~d@dV3;Dg;P3{Fj1!ZF;X3P2OUWw}@N310EtS)3?8i&?G z3WB2_q4Txzr=!jtNIb;h*`>sJ$OYl6EK1q@;ugEz3xggc?!$W~SPNcq9v_?2BU2gU zk}CvUhK6IyDrAbUMO&D(?=2g~5eKVQITEBCd!}3shQ_=9Q6ExDx#X*JjS1nS%?ae* z?j66ml?;K|9*zgX)D0UVF7UI{gGF`9_V-h231`znP205!!=7Wx939(}D0#ZnuY#r; z&;DUv%W^`7y^5tr?mTW(X$1*{S>h@JS+gxGBDSyr(!98ySskPFTqy`3q^4;EGHi%PrV zHLkny{d=`Kc$U(V%A)0mg9n}(FS*0GW0g|15hI+i7nEGM!nNQMQ(vf z(YboXheVl9p1uHzB~MmDd!`ypZamFKP;--Z3 zll8cvzR>=4SCLn}R$W<+rj$9+CI|2C?S_@kv)N zbTgm6At_X6&yKZ~lvKIv2TNS{*3Rp~@XwvajwX!uP88H4YLtcb+0hOTJ(-j^UGbmt zk$kMFXPiyDf4o$5y1tOv68)wN${NB4=N3nuNxSsOSb zLP8$s4~|D6EX>Ts8e2PwjSlmhJ=~Fe2p*hE!w$XvNaxR~OmF)y8H+5Q^yVElMd!y{LFS7)oRPdqXn%Y< zUWTqQvj5ri@aECVr^`jLPBXUW(buSdS%V%|46*|@nP-Ej;j7Rg^61Bek6L(m=6+4R zVwe9;>gl=Gv+a6tY|l6=sOS| zM`l8QZ_BJ(4J{7`OLjF{ci!XUblU2&#de-izkPL%J8hEj0{o`4XT0h~CyJjBdOWD{ z3*5YCdW-`nzY23-t9FvuNJDu^lC?XP{foQCy4Lcj;q139@tf;??2sp`C5a%$vw{h! z_*L`T^95k8)zUzamlIMS%|q|D43K)Vdd;@~@@bqNjtL)JnU=)6rA5al*m{eKqPlMd zns|{@w(?D?3iX5yp&IKn3=%Y1s-a!0c1|j1*Nnz%>k(VY`S;1c5FC+6pK%Q>(5%k5?n@)`n!?^6=qzxQaEA2ACG4-Y4JTJWkz zbNJbZ{;YJmsAr+hQO0_d|E%4S_^R4dwx>*{S68$`Q+LJ6B07#Ti`e2!r`t8F2XalP z)OBC-6pWT7mXsd;{+|8GaA^8uac$9g^jEL#TD3g@_QdLX7uYSdw-OHD@P~eWOM&H} z*absl!Dbs%K zIjFIzjP)&P%$iw#OQF2p-%+>#rQ)+f#L)39os35-JTlygd{-7d+}7&1xd+{s-f(kW zo|?wNk!(9;?im!(L$Kalb!{0K&{qr^F%e%%VsjtxkxY1{hs~j7d&5o7qs$ftq;*Q} z`uIovt-0V~65Qu~&ll^^8m3!p@eVBd6m@b8cAOxixtifT+?*hgCG9xX6~*SBw)QFMk6e9;7Yi}ESSF>Ng)PkDQ-7_KQvTML?>3YN55Mz5B`{p=J(Ox|?2%JT`uOm8bm1_K_ zj1&fnD(04+bH-lt!fvk~G_X6RpwUyvSz6_mDD9=}>wmgLQm<0gN3H5}=jqbzdFU(v zg*kmM?yq^S;$|)-J*3u-`%Wz3f=*#X)p3q*W^UED!b|y%-~N;mDz+nTxq>xU?ie*T zbS8Z{p^Noe6wajGiDnCjt5z+;O8i<_*UB|CxF!lx&E+fCCbJ>8&=E|nR@k;L;{QvI zH`rz&su87gyqDK3tEtZN1e8S8YhN8Z&mbp16p>^s(`f=}$l{;|y5djc5lniGZ&e!eyHNt!J1RzR|~V8e|~!yXE<+ct)kqFJkxL6XK^TaPDN>j2l#c;%Yk={qBJHzcN=CKAq#qS z-X1xk*QrEq6MzUs@SJ7Zk#;WL2eL#kN_y*G%;rm$gfWNaibVo=hK++fj_qABk zB+Vh)Wl`>N1%@28`1P1>2P-?K)Zvs{bWV=Bwb)Fwm01LLK%Cl`X)x>MY1J5=E4j~6 zB6Hat}BuJXM%$ht&Zio?i-nxRQ^ z^1yQ07GqrYZavR^JjeFJ97?BsY^uM^u*LMRIdQu*foP%hE&j~kaj{2oVU65>tDU(| z=cU+4WZGTs8u7KGYFjqRFOrHAP7G%C;;@D?w+V4IDJOpjx0EH9wa6m*2`cvp?~a8B zBTXW6`ftk#8}BUHpKMAtvFy67$Vl$}T!9d5beN_TwN*QGu3JveQsSzy6zeF|r}<|7 zvwHFstE1jLNZL(I-|hM5p7QamMlUfFar-ZPTvyx@V!nSU8vNCm1pj#O3tie_I^Q~x z&<-fYEX*P-A_S1?73Bt|IiZ~bNu{;}51nnOa{U|6vy6j`DwykqGy<_fYA!QvuMzC<%`Yy$HRj0H>VsLFB69676)xv!WSXS2BRc9wqaRX@B< z5Y=_MH}?pq_B@!v>Ml3hFEag{8_!6}dRO!p5df2Jino1Jqf`=?)n@Rx?beMN+lt9) z?1_t##R+mZhc$9~%thB$b2eJoocH_u==$_ysr8LsFRMMJY zu5f5o2PouggnkAS%%aPyUag=ocKd+VVV`+9wubjx-XT4L?@@GKqG8&*AI1^P)Ye>V ze)v*rvVz!F3Vi4cK}=6UG76nuMCO#XCB-3*;MH#VTZGk zlW;{EN$_3VIfv(t%I>X9QO<0Si=WgnO9-`_ z#A+s)KEnzz+^mffxVbwY(o&ENXNd+4HaI~_NyDYHB1Q8pD0+rZ?vJxratTfBSdfE#*y&9)-6^VJXToW|sCN}ie$(vl_4G{I$4_(=6DskH z{EKmX5?jw=19c|p15Ifh!`P|5IE~o&isth!8ZTZ99!l3zr)oScPMy1JTuBZzoyriK 
z1c#|*sryTV&M5CW?<~vVZ1yJ|X=UGx0?p}6w^{1Ns^eb?Y;7fCa@&eO1QS}`i~@G@ zCXcJ0=e}J!u1r{n!8)^8c&EGShju&C?yQ&R7J0ofOR)d(VuVLy|07aXA}wKAQ7I|K z0&_EXdN9VU?8N>W3NyVwH{oO0ul8_u-sVyN{_))*6{oQ$yw3@9uE<97@K~}{%*aRP zs7v%-^RqCA9!K&l#`wc2w3RY#Q~mmS`t_NK)8O}O)6R*>Eg&rv_9a^c?glb211sg` z@ToLUTq6%0@a*352k0u<9T%Z;u+P`zsGI|5VI5ygbD~}{;7m^T$N@%$j!S$ioy8l% zF)5RR<3?!c%$Alr2hz;4OOmPjyh>4Q&-hGI((Wmj2OV)1$I>(lO1o2PbMRxfJ^S=# z$^|Sr58!ABUK){${aQm=w}Je)jdsq;xkK46qaeP*PDAokBx&jyd3*wKPFEIpSdQFi zaVseo=#7;P3J!6cipw=0wjw%av`R)U3H0{%rXP56-iUxAo4%YNWm!~4uo-T>ew_wM z(i2q`BT|&a#KcLuz)+u8YVbYZk6cLF3XQIw9GGp$W~RibQ0_L>yCgW(zUu2J^vMyJ z_Lk>%`2^kiioh7}#l5_KdtHL&JxO)fEr!V&oAy3a#|8;>vz|`3SDApb^8sT*0P}(1 z+0AG$Q+53*W1zFzK1##F&~J%x&^ML`<^DB|Mw1xFiV6AThf(!;0t5rykMOFIK4#*B z7;SHC1DA~Bdvn0K zlm#Umh%GO?)UX3QKCr5n2)|S28Q+oy%uAQt3x$#xprfMdbr#ZBPdK z;^`0LY{n-^Byv(091{(vq?^@s5VWYR8;x!(WPQJ}@)cm1wy@WnaW2;!@)z=KD_uuT$ZAq~ zQY2PcxWE>LOR;VFqT*Tmt(+a!>~JX*8;7*TrNC{kT-QBGRs_y9Gbn)s3dxkUzG4%% zY7=V=cIo$9IimF9Q*~Kq*;>@oh@B3?GyKzpx?4{8 z=E4ZsEay|D+|MU7-Rm=Za63dziGi8CT9!{fdKZ6Dn<^$SESR4sCZe%P>ivUPgU zzAvW6AQEE#Ti|eGTOVB0y)b##KpK)e3g*$~@}}FomdIRHUev)ZNIx95S%W%yl}{Lr z8x?ZR{X!t}~Tay=y(vWym)w+EM=5_c@YqGjBd&p?m20>Ir z9W5&H_zZlFQv;@*l@{?nFqjU^;!O}&L<(SR0Fwd`vjB%+wO%aPP<(>&`7 zaz>(OFhbWi%s?uE2ckFVFmPFrtmahh4Zg9i25+c+cs>!ZQnSk%gd>{!>9~sHNt+1ehgI^f) z5F8+lz!F52@U945wH?4LenC{~&i|JyLhUdBk&cSOeWzCS7lI&`{o57j^zWsZcJ$1K?#51xl_tt;5drEs+ zjaQxk$Rd9*d}`mtBG|g(_u=b6)(M!O5whQ;oD(_^z*i2smtij3lChSToVeFOvkK@n zUU_`t27tx~3nctUWPSAX08|73!{kj2GX`GZ0vxQ4>^8#70Am6cErd z0HNkL8$PAXfB+Pr$hO0-Dh|wPKpQteOG5Syl(!ajadrqXRAaa78vbOzg&RLxC*gY1 zvNY(MfI|D&DA^SezUsz*Wd3PDK`md)(*8Uw=nWQsSIY*k_@%`Tm;W2fv1^I<$H!tN TP4<*LWGwWw4NvB2;cosDVUXfi literal 0 HcmV?d00001 diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index 542c00e3e0..b023d6ed04 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -28,13 +28,7 @@ "Fig. 3. An attention with thd layer.
\n", "\n", "\n", - "##### 2. FP8 Weight Calibration.\n", - "\n", - "Assuming that we have a model trained in FP32/BF16 precision and we wish to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, we can compute the FP8 saling parameters. This calibration allows the model to operate correctly in FP8 precision.\n", - "\n", - "We highly recommend familiarizing yourself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n", - "\n", - "##### 3. CUDA Graphs API.\n", + "##### 2. CUDA Graphs API.\n", "\n", "The speed of GPUs is increasing at a rapid pace. It turns out that sometimes the runtime of kernels is shorter than the time it takes for the CPU to submit them, which can lead to significant overhead. CUDA Graphs were developed to address this issue. When certain kernels are executed repeatedly, this tool allows us to record and replay them without CPU involvement. This becomes particularly useful in applications like text generation, where a `TransformerLayer` is run for every token that needs to be generated.\n", "\n", @@ -44,6 +38,18 @@ "\n", "Transformer Engine supports cuda graphs from version 1.5.\n", "\n", + "##### 3. FP8 Weight Calibration.\n", + "\n", + "Assuming that we have a model trained in FP32/BF16 precision and we wish to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, we can compute the FP8 saling parameters. This calibration allows the model to operate correctly in FP8 precision.\n", + "\n", + "We highly recommend familiarizing yourself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n", + "\n", + "### 4. FP8 Model Weights.\n", + "\n", + "The typical approach is to store weights in higher precision and then cast them to fp8 before operations. This is especially useful during training, as it allows us to store some values in high precision to avoid performance drops. However, for inference, this level of precision is not necessary.\n", + "\n", + "The TransformerEngine offers a feature called `fp8_model_init`, which enables the creation of models that store only the fp8 copy of the weights. This helps reduce memory consumption, which can then be utilized to increase the batch size, leading to a speedup in generation.\n", + "\n", "#### Benchmarking\n", "\n", "We'll evaluate the generation time across three benchmarks:\n", @@ -76,9 +82,11 @@ "\n", "1. `te_gemma.py`\n", " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. It does also contain code for generation with THD attention and weight calibration.\n", - "2. `utils.py`\n", + "2. `te_gemma_loading_weights.py`\n", + " - This file contains logic of mapping the parameters from `GemmaDecoderLayer` into the `TransformerLayer`.\n", + "3. `utils.py`\n", " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", - "3. `media/`\n", + "4. 
`media/`\n", " - This directory contains the images used in the following tutorial." ] }, @@ -120,7 +128,7 @@ "\n", "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", - "generate_sample_text(model)\n", + "print_sample_of_generated_texts(model)\n", "benchmark_generation(model)" ] }, @@ -200,7 +208,7 @@ "# Init the model and accelerator wrapper\n", "model = init_te_gemma_model(hyperparams).to(torch.bfloat16).cuda()\n", "\n", - "generate_sample_text(model)\n", + "print_sample_of_generated_texts(model)\n", "benchmark_generation(model)" ] }, @@ -217,12 +225,122 @@ "| THD attention with TE | - | - | " ] }, + { + "cell_type": "markdown", + "id": "21a89d9c", + "metadata": {}, + "source": [ + "## [Improvement 2] Speeding up generation with CUDA Graphs" + ] + }, + { + "cell_type": "markdown", + "id": "e2d53e7b", + "metadata": {}, + "source": [ + "TransformerEngine includes a function `transformer_engine.pytorch.make_graphed_callables`, which functions similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from `te_gemma.py`:\n", + "```\n", + " generator = GemmaGenerator(\n", + " lm_head=self.lm_head,\n", + " model=self.model, \n", + " inference_params=inference_params, \n", + " generation_config=generation_config, \n", + " dtype=hidden_states.dtype,\n", + " )\n", + "\n", + " (...)\n", + " if use_cuda_graphs:\n", + " fp8_format = Format.HYBRID\n", + " fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n", + " graphed_generator = te.pytorch.make_graphed_callables(\n", + " generator, \n", + " args, \n", + " fp8_enabled=True, \n", + " fp8_recipe=fp8_recipe, \n", + " allow_unused_input=True,\n", + " num_warmup_iters=10\n", + " )\n", + " \n", + " (...)\n", + "\n", + " for i in range(max_new_tokens):\n", + " next_tokens = graphed_generator(*args) if use_cuda_graphs else generator(*args)\n", + " output_tokens.append(next_tokens.clone())\n", + "```\n", + "\n", + "Let us now proceed to evaluate the performance improvement offered by CUDA Graphs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31a3a8a3", + "metadata": {}, + "outputs": [], + "source": [ + "#Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "#restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "hyperparams.model_name = \"../../../../gemma-weights\"\n", + "hyperparams.fuse_qkv_params = True\n", + "hyperparams.qkv_format = \"thd\"\n", + "\n", + "hyperparams.generation_cuda_graphs = True\n", + "hyperparams.cuda_graphs_static_batch_size = 64\n", + "hyperparams.cuda_graphs_static_max_context_len=6\n", + "hyperparams.cuda_graphs_static_max_context_len=100\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "\n", + "# Load weights of the model with the proper scaling factors.\n", + "model.load_state_dict(torch.load('model_fp8_state_dict.pth'))\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "53bb430f", + "metadata": {}, + "source": [ + "We finally obtained the **??%** speedup.\n", + "\n", + "| Models | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | - | - |\n", + "| THD attention with TE | - | - | \n", + "| THD attention + FP8 with TE | - | - | \n", + "| THD attention + FP8 + Cuda Graphs with TE | - | - | " + ] + }, + { + "cell_type": "markdown", + "id": "a2bd87e6", + "metadata": {}, + "source": [ + "We can also see how use of graphs reduced CPU overhead. Here are two screenshots from the profiler:\n", + "\n", + "
\n", + "\"Logo\n", + "
\n", + "Generation without CUDA Graphs\n", + "
\n", + "\n", + "\"Logo\n", + "
\n", + "Generation with CUDA Graphs\n", + "
" + ] + }, { "cell_type": "markdown", "id": "e6b171a0", "metadata": {}, "source": [ - "## [Improvement 2] Running generation in FP8 of the model trained in higher precision " + "## [Improvement 3] Running generation in FP8 of the model trained in higher precision " ] }, { @@ -326,15 +444,24 @@ "\n", "from utils import *\n", "\n", + "from utils import *\n", + "\n", "hyperparams.model_name = \"../../../../gemma-weights\"\n", "hyperparams.fuse_qkv_params = True\n", - "model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format=\"thd\").cuda()\n", + "hyperparams.qkv_format = \"thd\"\n", "\n", - "# Load weights of the model with the proper scaling factors.\n", - "model.load_state_dict(torch.load('model_fp8_state_dict.pth'))\n", + "hyperparams.generation_cuda_graphs = True\n", + "hyperparams.cuda_graphs_static_batch_size = 64\n", + "hyperparams.cuda_graphs_static_max_context_len=6\n", + "hyperparams.cuda_graphs_static_max_context_len=100\n", "\n", - "generate_sample_text(model, fp8=True)\n", - "benchmark_generation(model, fp8=True)" + "hyperparams.fp = True\n", + "hyperparams.fp8_model_weights_filename = \"model_fp8_state_dict.pth\"\n", + "hyperparams.fp8_model_init = False\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "benchmark_generation(model)" ] }, { @@ -353,106 +480,56 @@ }, { "cell_type": "markdown", - "id": "21a89d9c", + "id": "8d3945e3", "metadata": {}, "source": [ - "## [Improvement 3] Speeding up generation with CUDA Graphs" + "## [Improvement 4] Reducing memory usage with the fp_model_init()" ] }, { "cell_type": "markdown", - "id": "e2d53e7b", + "id": "2dd0cba9", "metadata": {}, - "source": [ - "TransformerEngine includes a function `transformer_engine.pytorch.make_graphed_callables`, which functions similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from `te_gemma.py`:\n", - "```\n", - " generator = GemmaGenerator(\n", - " lm_head=self.lm_head,\n", - " model=self.model, \n", - " inference_params=inference_params, \n", - " generation_config=generation_config, \n", - " dtype=hidden_states.dtype,\n", - " )\n", - "\n", - " (...)\n", - " if use_cuda_graphs:\n", - " fp8_format = Format.HYBRID\n", - " fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n", - " graphed_generator = te.pytorch.make_graphed_callables(\n", - " generator, \n", - " args, \n", - " fp8_enabled=True, \n", - " fp8_recipe=fp8_recipe, \n", - " allow_unused_input=True,\n", - " num_warmup_iters=10\n", - " )\n", - " \n", - " (...)\n", - "\n", - " for i in range(max_new_tokens):\n", - " next_tokens = graphed_generator(*args) if use_cuda_graphs else generator(*args)\n", - " output_tokens.append(next_tokens.clone())\n", - "```\n", - "\n", - "Let us now proceed to evaluate the performance improvement offered by CUDA Graphs." 
- ] + "source": [] }, { "cell_type": "code", "execution_count": null, - "id": "31a3a8a3", + "id": "96264b9c", "metadata": {}, "outputs": [], "source": [ "#Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", - "#restart_jupyter_notebook()\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", "\n", "from utils import *\n", "\n", "hyperparams.model_name = \"../../../../gemma-weights\"\n", "hyperparams.fuse_qkv_params = True\n", - "model = init_te_gemma_model(hyperparams, fp8_model_init=True, qkv_format=\"thd\").cuda()\n", + "hyperparams.qkv_format = \"thd\"\n", "\n", - "# Load weights of the model with the proper scaling factors.\n", - "model.load_state_dict(torch.load('model_fp8_state_dict.pth'))\n", + "hyperparams.generation_cuda_graphs = True\n", + "hyperparams.cuda_graphs_static_batch_size = 64\n", + "hyperparams.cuda_graphs_static_max_context_len=128\n", + "hyperparams.cuda_graphs_static_max_context_len=1024\n", "\n", - "generate_sample_text(model, fp8=True, use_cuda_graphs=True)\n", - "benchmark_generation(model, fp8=True, use_cuda_graphs=True)" - ] - }, - { - "cell_type": "markdown", - "id": "53bb430f", - "metadata": {}, - "source": [ - "We finally obtained the **??%** speedup.\n", + "hyperparams.fp = True\n", + "hyperparams.fp8_model_weights_filename = \"model_fp8_state_dict.pth\"\n", + "hyperparams.fp8_model_init = True\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", - "| Models | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 | \n", - "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", - "| HF (baseline) | - | - |\n", - "| THD attention with TE | - | - | \n", - "| THD attention + FP8 with TE | - | - | \n", - "| THD attention + FP8 + Cuda Graphs with TE | - | - | " - ] - }, - { - "cell_type": "markdown", - "id": "a2bd87e6", - "metadata": {}, - "source": [ - "We can also see how use of graphs reduced CPU overhead. Here are two screenshots from the profiler:\n", + "print_sample_of_generated_texts(model)\n", + "benchmark_generation(model, 64, 128, 1024)\n", "\n", - "
\n", - "\"Logo\n", - "
\n", - "Generation without CUDA Graphs\n", - "
\n", + "hyperparams.generation_cuda_graphs = True\n", + "hyperparams.cuda_graphs_static_batch_size = 64\n", + "hyperparams.cuda_graphs_static_max_context_len=256\n", + "hyperparams.cuda_graphs_static_max_context_len=128\n", "\n", - "\"Logo\n", - "
\n", - "Generation with CUDA Graphs\n", - "
" + "benchmark_generation(model, 64, 256, 128)" ] }, { @@ -470,8 +547,9 @@ "source": [ "In this tutorial, we've explored three features of the Transformer Engine:\n", "1. Support for the THD attention layout,\n", - "2. FP8 weights calibration,\n", - "3. Integration with CUDA Graphs.\n", + "2. Integration with CUDA Graphs,\n", + "3. FP8 weights calibration,\n", + "4. Models containing only FP8 version of their parameters.\n", "\n", "Each of these features can be applied in various contexts, and here we demonstrated their use for achieving fast inference. It's important to note that the fastest possible inference speeds can be achieved using NVIDIA's inference-optimized [TensorRT](https://developer.nvidia.com/tensorrt) library." ] From ae64bdfd1b818c5ae58752e534ff69b6795a9bdd Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 15 May 2024 14:20:45 -0700 Subject: [PATCH 119/244] images Signed-off-by: Pawel Gadzinski --- .../te_gemma/media/bshd_attention_1.png | Bin 0 -> 4602 bytes .../te_gemma/media/bshd_attention_2.png | Bin 0 -> 4561 bytes docs/examples/te_gemma/media/thd_attention.png | Bin 0 -> 2487 bytes .../te_gemma/media/thd_dimensions_1.png | Bin 0 -> 19382 bytes .../te_gemma/media/thd_dimensions_2.png | Bin 0 -> 25116 bytes 5 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 docs/examples/te_gemma/media/bshd_attention_1.png create mode 100644 docs/examples/te_gemma/media/bshd_attention_2.png create mode 100644 docs/examples/te_gemma/media/thd_attention.png create mode 100644 docs/examples/te_gemma/media/thd_dimensions_1.png create mode 100644 docs/examples/te_gemma/media/thd_dimensions_2.png diff --git a/docs/examples/te_gemma/media/bshd_attention_1.png b/docs/examples/te_gemma/media/bshd_attention_1.png new file mode 100644 index 0000000000000000000000000000000000000000..4c3f5e2fa5a2d56dfed8137e407cfd671405f9e1 GIT binary patch literal 4602 zcmeHL`CC(08m6TktbmL<9g3jTg3!v=2@*=!T0s$k#3>50t85}mG-1mE(Q%McL?9F} zgoII=WE_MD0b*F|0vNVbq9P^)*@*!HNf1JS%(*(nA>FB&;zwLM6AKvU`F1WmhJ?wjO8~72o#e4&v zO;V4XP7OrIrDmMP#OlPSpp#?mu+f;<*c9w}bgEpp(?dsR=XGE2uTExGE{FoIN9bjB zkFVn6Hk`ITvYxzj?64X6%ZuK&?|^m33%#aFa%03S}$!cfeA!(OyDbaD9y`I1N0yeApH|eYw93^!@Gkz)_tO-^7CU z)5kE-{^6VnXt&oJfHvo6Gyl`eXZoFM6ZtNeCcK$URz-!*(m!P1mmpa})%Bn~`}PgC zj%_X)7#JLK&k5#_Kgz{!ZKzfFb|$M4$xJk0cQ`|$fkEIR&R7&j#uzk^LC zEb^>7w^rDW2t>t$2i=Qu?hp?<3v+fZr>3-g7S^1UfJfHY+SyI7El={M%fetV*pRC7 zg^9Vj;N{EHb*+mCXJ?ISLn(;IU|?X-mDW0xjg5_oscB!lhmt~6Rhvg}C*kW<7N@#o zP!Ay;mEE`$HE69qN_OFMzN{CrvV0Q`@Aj*gC$ zPPZyaaU^Wy*|%k($1zs?3*=mA-5lSb5-w?SL7IvPU?-GPq!c&OxLUk6S774e21p<% zBqCyfmxslrEb-h_Mf%(*eSLlD3O*zxBmuvjeLIp}USf@k!{ZUA@XU|u%pKa64l`LS z`P%hMlO3*HyDh(H`;YbtwiF@=bW6XX@dvU+_nwFu`#DMMgctlK@#%I zJC+y*JK*UbG!8jrFXt0CWl@42?fN86B$B)UK2RzudLJ2iR^GaGi@1obN(5{N=-F*m zjK#gaq8H9#e6!>nq&lnz0$?*RbE^;_J)K%N~GYrogHq7h7a? 
zZL{ujh@&e#Ngg0>$GM_`u(3A}7kt1CI9w1kBs@ix;mSThLt49LEcSO~XE9-;{M>5C0KW2&!{Jb<6@Kzg z0v{7-6MEqwzur>cx8uc98NA}ykY9%a6~+JI+7;$-WA9_-ZVK@dOuJmAEe>QMp}d%x znZc;$o~E{r5*HHf#o2djYh%xy`=mY6p{AxLKrntPeeTCw0A7IBoZT6x=$R9Fl^$6L zBX`w_?6;gMXb2{g33GL&s?};U58{0&{~8;=%gf74dUCF6r~9 z2kGu22t!nApF)ALk7#LxFl3z>7iXn(RPysIu>rN`?=afTh&8_Z5;Ofw2M)BW6gkyV zOG#Ze^<~#R9qD{yD{AD>L1;k2+a!E@P(y-{*)I^VWtU9zrz2X^V|$zcjd758(n#WkqpmcRLhfsSfY%ABqf9cZ zM55dNSaX6jL#{+Fs%r=eASy>x6+dwow{vp=#2XSm0fK~5z`0SIZ-Q}45A?D!GYbH9 z663oV~d;>|lCC1Itpt%#-?QkzIFGy(u9s-IFp{uKF zVrJIAG8`qh3a&!Z-6;PeDObYVd1&*143CM8jUCm-^mzRpNUdMKdiXyv{kK)y)q;-o zP_p(oB9L}8zH!ihNOQy?#2=x3EnZomd_OE>C=2&;^VglqRdo4uuIJiWB>f8rDt*NDpe{ydtI6&zF!X+jX%t}d-v|t%lc>+?2JEz ztFveKgA|ozi@VR^K#7O$kPeJoO7Kyf477PAeKYaAlImU_5C)%KXVsyBeDg>8$4R>x zVui|1nkr73>v_(o2{8-v2S+(5${@A@Jj&ZzmkFh*m94>ZUgqm3gu63K7sfh0J2$2s zl=B<`+?|mIR2HYDr2@#r>o;ydzLu4jbFj06X0GYBKvY5S9v>fX17)8enLBD#ra1(p zludA5Ols;ONacY3;^JcHc#kL?0gA3Y&3(druAkyasjN(zV1%J@I2U7w$o}pE(^0$Z z_alkRUE%F`C+E@BKoGxp&>PH-958ei7lASsV)aP8lwoUY8v|k>;-Z18HikQP%mDSG zsTRA;poGhlJ%A;llDPutXb{$B^cWowBlK=ia4>dJ%meH>oyI0afUw2`0h^eZ$XjR@ zwMDHaq@>tGo5#blSP7B0fMaP3)!jdR4A?rU&JO@U|0O>V{Lc{IG8W`qLkfIu83Ipq Oe2*Z#>kgm!%Rd3Ym~jyR literal 0 HcmV?d00001 diff --git a/docs/examples/te_gemma/media/bshd_attention_2.png b/docs/examples/te_gemma/media/bshd_attention_2.png new file mode 100644 index 0000000000000000000000000000000000000000..7e9c471511285a917b8d0a8014ac84e7f7c5f7be GIT binary patch literal 4561 zcmdUzdsvcp8^>+!tu|M#9de~fE7H2GY)YQdbuhyeYcwLkX9dtL9JujAr+o*(?4=XbcD`~KYD z`?-Ds>A(KnE$>=bSgc1JJMx8v#cHmF#VV(DYrsE8JMuSx)7ykYh|}x9Cu3doHSl|L zvfr8HQ;F#0v~&0v3rs>HE@m&`0zM`tfe@RRtX$3Xv9Q>90deHu>C0sbLGU$H8l!iP zJwFs^RNF7tG^}4&^}=xt^u%`^FMbkydvL#H(*1+)KWMoY`SJPP$9CszTcv#6a`H{; zmydASm0_4i_h;Cfz5us z{j>F;!|(hSaJ=#VyCJqc*RQl1!G+=Wj1z;n(XORC5$O!nAUB!aBU0%d~opa&;G@V+tp9bdPfF1ub4HOYrNIIvp;ZwO8wx_xftiUY8N1!!b!o zpK)akxe_mhf-W9bMlEKh&I*gY5}1{WBu{B~A!_Kd_i&9#!QFS-ua0ROi`mLjt8PlgK1~dwR zz^>+$QB}yC?8vcf*3Y&Q^a?s1j+8E>GlZGi@n(}zPfs1M&gpH>HcUVWGcTWdYlVd3 zio`0Z7uMTO${DD7JX8l2@XM>OP#suj$)emtE%9YA$Dni*=eWUi#^rc@K8$A12)8B0 zOiWTKX_6dk)cah>wkSP+Of?u~4(q2N?CXXzY$E-37@uqOr@d) zO}MZ#XWnprigmZOM>HjhXWRCcu+Y4x7yR27s+=Sv{*ZA_V|Z)8jIi0o%GYpd!&cAV zM%Pi2mlkB+3w?)mVJ1vmdGeK$nj`sZ$y1H?VX>s5+a7l(1&H}erR6gMx)U(oMbbSaQ70*+;RH?Is=gh%Ne5pZFs6!+t&_ZZ zc7-_ev={&C-`_O#%F0MehTB>zyP?%q-gm}Z(||E+Y(4PME6=uFwSo*Am-S|oBg@?H zR{G~P)i?Fpt>^$*fib&cmo$=Lb8{F?8$#0pYq+<1CtQ#I@wh)aA3gfm)IEwTXVjK5 zw1h4C&;vo7E)lM0k6d5yAg(0D{~Q*VA+ph08kUWzN+IS&pgkEXC&enKU2ik`qtBae zplNEGy_u)oC`iCsN9{XLx)xL_d_h6MpZ>?gRNyXrH+(iCVyNnP;izvjMgW4bfM$2V zM2V(K!dj*jP%Hw4sAqsSx+r>ge?S?%^l4Ebh8xk0-w83kxhg3Cey?+JptJ3U`O)yi zn1tbfwZ2wK2Ad@X;VP-G$-~+S6;CQd5&>15DM*)3-V;_PbX;;D8W;#D&o=7f1x-%? 
z(IhN4v6*PKb?qgCg-&MAFN2d@L+em1FJ&7|g4LF6`oOq2H$gPa_0HG$F7F1R#h;dV z;Q&Ym27b8ZGW>)tEV;xc;nzSg+o$*T*1@m#_mP8r%v_l7`Y?F%*;>dD zkBf2NY z@$z>o{iJiu{TDDQ2DU{4lnfE% z?)pr<2IS*=-QC@Psh}&*el$-!RO}LlLJhyD4W(x&Le4KOspZX==gtF=5f7e$f#EcQ zM2sh8%1f+6nnNqc(z-r9^MH6UFA>UUs*x385h8@D0%5E*^E7`_0S(HZ-^=Fsbv_tV zzVpbc==V!Sr)rRr1sPYST@E?h4VK(9kyKDvm}yxX>sneBSJF1uV`FzDz9ZgsFgSaO zYgig%riwO2wj`UGQxjYqp&H$^FXQ=wk5=5D$Z6z*;~7 z8D?rxX~WB_u+<0@aw%#S5C)D5?%c$uVS;_;i`mftKD5#(t&qCp zpR}YYK}3cp1r6KM=y6C2WDm4I-j$YSx|Wn*t@zXvtr4*iA6!MT780}gttDYb<;#iH zajvB?o)n|WbVtiV%N>HafI+brm?hOo6^-u%(Xt2!?D7;*BRx-R@}}*{2rQPqi3=#^ zkurU{JZnkTbo7r0RaI4Ox~Wo--vC~QDPDzWA(Ra*#*#{I6L<7`A&}AVhdO=a;F2-& z`0IuzB${GX>>yB0!(|@ioSd8{f7!j8)=pFWMr;@mPN2i=>(kR0k5-;UQamLGyUeu* zsHj|Y-A-9|w|b=lPybWS)YjJO#+6xxV^P_Rzi`rwmjpNhb(8CVEc_2)GeHNcup+Th zLl0*x_~zv1Jsy)cdLy*)hTfM-{ranqc<}|*aU7*~E6LA>G$vQul1dxWlp@2cwZ45o zLkTsz;_voCvb{knPpofh8rAHgaFtqsVKl8+o^rBoY_Ox_(D$A9~SiE+zWpWE+YSS8eeT?|r*s5?K;W6jr1Ap(Ec**}M)PZrf5vNp{}O&C a%agg84kr%7KLWe3K=>h#P!E0e?SBAShj~N* literal 0 HcmV?d00001 diff --git a/docs/examples/te_gemma/media/thd_attention.png b/docs/examples/te_gemma/media/thd_attention.png new file mode 100644 index 0000000000000000000000000000000000000000..16d707719e9664245fb8772d419bbdb8cbba77be GIT binary patch literal 2487 zcmd^>e^k<07QpGaN89NKJ6oG$l+HG&ZJM+iDv45!rdZM>PNt|?`J=R86af*Fnd$Uv z8ZsqxQNbF+hEf@&R3v2DQe)J#Fwr4oRfG)DK~xkF_RH-4xo6ML-}^n^^X~h8-+lMJ zd(XX}_pX12iQe$R)(=1+&<1o&WIPDu(hKx)Yh8ghlJ1@dh~E(pq7&8vPx0D}YrwZh zLDY!?EFr6aaw8t+t;sWp>QKARevG!WnbSBA75WR z7b}uTB;OvYbDJo<`fMDgR+6pzrAulL8ja?$@#?uYCBwnl&Nwt}wgP>2Rn-F;W@?tx zBbUqBTy7H$!#h-W40r)s5cvg!uzr2*5pRFiOwA|0dm+S zl}ERfqWXVUbEuvvWeM zUk~cd@#}AFOAY<`Uy3seY&<#OD%i{E zRgomk_0m%pWBNbRkeIkQxZQFasuU}e)G0^Bs+S81LbWnYTG?o5ET|7-hX>-!HL;7< z)$@3KR!0y=G%5l-$l2?eNua%%;Il!Eb;_SHe!dy>5N2_e2AngN zk81ESx$8_5SeUkt9YhlFhqT(P2>bFR((=VJNB!kRG3>^+5kN@5THmo?Ll!hM8J zV^&(>_~i~o!u!8pUYBG~N8sopmaRmUyQ`I{ftB0vOGL8>mSsVH)U>1|=oY5s2)g}O=b6$~bVSGUv3Ui0%lkBJ zGd(uhF&JyqMa|PlrTmd8D4qET(;9DU9|z~Sq3sBn+IK}<)%F(b&OAdW`3oFIJ$3O- zqCO4F*DC}%+lq-g&8wX6h7g>EAxVRIl`NkcE7CVQOr(B0_##m0Ux5CXg8$8g_&=Ud zzI1^81F4Rc5*xS(NipqCXxM`GJ-3?em&B*Vbv5}$Nn`XmejbYO`l7C^vf+5Ym87k7 z&$g0RKP`2Woi!qp$((rVg-QF;DYu5{ZRIOK$~Cmn89A!+-tKx&tYjFLJoCAZ^kJlg&=#~shfPK)F(^EGx;OA2D-a?>8M&ir+ z?~x>|0>D?5Ol12pOF?PW+0?S{d!X*BCs?Ve!Q zOK;orCKEfGKHnT=)Z3>}%XMCwZY^VT<4J3@lo(97kTvAN6A$)3j}%I;sYgf9^tM`2 zm*@xv9`=2ChF*KzN9z+P=bW?Alyl}FaZq(nDT@0^^i5}m`uX`KO$+XYrtGyp-T&8s z7GVpM{|w7cs{w#H<$CErmYWbE`C=cDNOZE@Vkuf$S_(v=YT0Zy5U~deedd`npKvg~ zD!ly0E)mIkJi=-b9)V%9Pq&QMhKH@roS(fXH)sa$e(Ic`PPk=Y7>@K~re2I^(O+(f zDzhI#U3z=~K&fR@Z3fcQmnP#+W2u_5XlFnHnhQEDyHxl#W35nT7d>fj{~KkDs)L*) z-Z@d?STzg9Vyk1TIn^lht#z`7oy|RG+tB;^b$oDe@ci>lM>W5UKLZwiyQo(L!QrgQ zp3>~1j?uf#9UWD`5s9?U2$ND$8l0p|lT<0&2)cK#>KIZ4WaHJ9i5sxMkPyYP89oYr zVK5kK5JVqeOwIRz^6>C*OKWRIS67$6zrQnzI2Oik_JSueH;w8=sllcGr4p*?;3lPgpO4((AoCoZ@pan c2XxroMyRD1rR~U7z#<48g^6rBnEv%Y0soto*8l(j literal 0 HcmV?d00001 diff --git a/docs/examples/te_gemma/media/thd_dimensions_1.png b/docs/examples/te_gemma/media/thd_dimensions_1.png new file mode 100644 index 0000000000000000000000000000000000000000..7c639fab31e8d71c619f8c5cf776d8964a5eb514 GIT binary patch literal 19382 zcmeI4cT`i^+wY@{;*0{LC`iqWsEATT>D59RMMMOogA$Q$=nzN<%81ejmEIJjH$$(9 z2m~QO4ANT=2!t9)Ac2G=_u%|~?|s)@cdh&0Kki-kt~Gx^NW$4??{oHk_Orj=&puK2 z?i%v_BKivm1me5(hu(b<=qG*v+7ATcZ{Pnp&;>1e3<4P$-O{^m9_+9-h4gmVMo2RgY&jpru6U^4 zGVD29p0x7u_%rL6v-UWzI+SbbnMlvXu164oC8tEs*gu}c{$}g+##o|l{m{{>c!|`L ziCG`7G+ut7@kmzX_Xh*FqzD;yIHih{^2eUPs2Dt2Js^L#?5-^8kNj>UjgdL1%vknhFX(+qVDZUuV_! 
zKYRTXFYx=?>zVyGf1D8Cf4TL1|DfM){d3U&y$fMfWoOC@x=Wbv(7UrhOmS$?N_M!n z!~;^xQJVvSgxt$ipEk}?8}qx-}Z1zJ^c5L@Z^zw6yHWjV=o|RJiVxSrE>a)000pWh5 zddIl$6qrbj*CS`JhD^xJ*!?wEB?>u#u1t!eK3lOmrq!fC*Cq%?v`Z=ae~_22D36IW zsnJ^*>NDlioti|>JBq%_e%VaR5sCkvNGK{VjU$z&L)g-)`ndFTE1Ns@mjxB6!9o7h z_0n|q+1Xhm51O-R+kKEFy2;GPe*Tp!aS^s+(cn92C_b=Q!l@Pc`0-Q}pDt0- zYb28y@|HoXXXq0l5sK91KVtR5?3B#uG%Q>fue7-)=G6Ip_{&qMY{(PU4bnLjJtirn z>zdF+L45meEMLd8E+u-3xElS;HMXG1*~kBb!r}NqLz=P|fzw&%su7kLFC!s_;Mvt^ zd96y_CL?_t_B4IA@+1hzl`U$#cB3kNdrKex{{3l@s3IIO@98$yb*x1vofw2oc>V$= zr)!WiDm<02os2{p;xJGT945zWM9FZIWVJ{qii`^F9W$Rj<-%WcGn+ACo zFoIs4`4X(6qm$0J9pECd>k%Lph#$q{odxXFxmC%%lP1!Xgt1r`c2Mo(08)S}zKORh zh7-LB>2k0iX3pH(J*{I(9Zj{FwtQK|TTzKm*QwQ<(A7$hOL^k!uO+s;_Jv{LJdVY$ z<434*Z!z3ydRcKWx2c%exCz&d74N05x%Z_%kl?alHBei0rES%^a11a^jiuB;LRooI ztZMquH=crl@HO%^!{Hiv@M*9v|0VH0{z)6r*yI9eZyXADyUZz)UK#>YRiChi9BhEryy>&j48$-c#IIC`8=We%By zu(nlHyOcx)dCfF>!ry=ID*jSW!0xq*QZ%gmKS;EcT2a&w&*dYmtnra$<dRV4VNL(u=z|m6trDgx|KTb_3 z=boQv2#WvCF2y^upk8Wp4v8cu@p2`-DjwVyCN9Z#V4;33++Bg*^S{S`C3@z(?)n>6 z-N}~nVw??Frd21K8m;di~EuNl*EL^6vDyb^QG~=BlgC$BbXJ-50DE zM6gID<)VJvGW@IcR*t9A%JjYTr|}2IMSPSry=~ooom8UK$!Mkv{9+XGeemJl=tK3Y z$O+css6wO2#K(nUo^H1;R-G}l8g?`C_HTsX#b=&KSnGn-u!xYPc<-?|YTL#igsjU= z_4%oR*%0W(`3p%Byj2lFD$Z-f438DkM8#;oX|IrG8;t>fc$q1Ni#$p(sv(bpb zw*Y-Dw{u=_B99c4lwaIfE@$%~x%8H@Wu1X-)M!ig+MkI_1R^qvCo=SJNX1G>RkGgD@JKiu=wsIvAj8& zAPT({3GX`)?)PE_WVIE-s}Qs8$Di!0T3gZ* zATVQ7`|wBsW#QUa*)E3BV zyq{)+?B`m3|FkKZTl{@cCyoLOt1x+=V(=}Fm4~o*yI?Ea7qk{t*Lb)@SoYGz#+;s(hW<>wsa-^VCk7cqS!uyL_=+Qd{4M;HDs>($ooD7z#Pqf61D(B0GrFj!d^~#$jS?PYs!$_Zuio_|o-~dt z{HX6v-M(1j(laB_1Yu^ripdLn=2F<(BEC5D<=||7`YF`x$so@2g;0(2th^xXIk54F z63MXl#*jhZ!G4VcplP+|JP!%B8xLv5m(>-8+D+*C#QNr3+Il~EZ;#JscerA9=v%da zUmbb68Aa{hdjkv|NJ>{hhLi?4o1($FTDl;p2_bG@jf{Y_6GDj+J~weBAM2@TtGW*= z@|;oQU0=Qqt7XqWU3ejmrgGMsJ@>ZQzY!g1$q`*I4t>KqZ5Sf-wHX9vk!Dic)`0U- z9nD9NuuhM?vH2%4NF{?O_Bn5ixEnVrOj5tM2Yh?fld-mO&0l!vb)*?n!d?_NZDS4$ zxaTvVnLDB9z}Jx51r5tUCR>-QQZe478scapv<r-8fSW|c^dtM6>kTWW1Z+j7J zpFoP}-T6!(ZFpxkga6Rv3fp+NMI3s)qXd(U1WyRod)kTBZm&S_CN3gJFb~q*$H^NH zk8%8ecfPVr85hwsa<-g{Wq{+pOGmqRQ=Jv(KL4fK>BBg#cICv)>I!_f^L%+<(l!pq zWL6{a(N{e@ACnu{O|#|DUf5;OC+quK_6HuCriIzrKb8x9BcPl4UU%g}fg@RJ`U$%0 zhGd^*hy#K#Tri7D$nUzM^Vc>;4M$r)tWmJo{9cny#_qPXD0pP0nR~b7kS#}u%NqBJ{V1y`Vd>AHCD3g2oV2W#6g)q;~7|EaI2ci)Pyj^G(>#X zTfCuAS(F%1r{k6t`-~Nc<&@gyHo|Ayu6Yv2tOz|Y0p#WG!fpXDd7v_l#CjFE`~N|5*5bVn`MW!!2D5db}QUUISC zA7(sQ=@wO2xt6f#u%3`*DZKgV{ThMY5o;tgZGoKZJL1UgW4h-(U6^9!h{O?cnF~cx zO-vEOm!DOAwH)QjHN!CRXm~Wk1E__}el6TK@Dq?wbtNDg9G4Kg%ecA~b$Z9Z@>uUx zY&x78Kle1UVQ)E$P@PB3e1UqY(pkV&3my>U>u*Usy<-=;B`Kfo2MdMNP}ZWnI>W>8 z@4w&9|4{IIdWGzABWKZ(avAAnH#}-H==YKzI~L;o>Bdl{Tb9p*&U>-wZX7MrD!^I^ zf5e+EQ4!hj6R6*DzediNnP}Lm+ctq3Wd>WFj5-gQ-<&?3w&ELM88VmbLGQ94uBZE4 z?Z_tthX(mWi)s#|zY1lUWV6Zd;KOY*2YLS#RxwzKtJWrelpjXAhrH8GhTGiuJP%xL zl(wQ&^eYnhY5bDpLY$t7skT{ig)O*VTuhG=$H~bWM~lE*%_gZkPOJ{} z@98QkC%54kodMOk#J4n&n8ch@7ND8T-_$_C+UE@*^Y={g_VnppNJL^zU$EVbj7U_g zYCV>0KZWw@ixTc1>BP1mA-j2N&ql+HJysFuz0cM|jN0r&0!LEpts3EZ-#^c#Aa<3# zP=3aQU?=Dnr=>{YQ^%Y=ioI6br5m}iCPxS~Asvf-tu3%E8T69dc#Pw^(EdCAZKUgz zsomYzEATSP$^IiEvDu-Kd3Aw4+7|CWd}qeHX4ov2uxKmzxqO z+Q1shZ2=)B6jQI-MYk7=fgmo|QOFLT#?B^Xgaw|G1IU3{_BL~ zv~YRp?KrpyMh!109e>PlM7Z7%UXxF@deyDQbT{pb!)T100QyMfwvcL+_^Y*?wO44% ziuPwp@DskAfG6?sZ3|1Ko`c3n%6hyO4bwT=)ONSw_ux#>h+^yw7 z+}wS8%Ft2EYH4MyKhS}wrj2k|Rsoiotx3BQr?Z(+B9zUCp~G46$0$1wIo zm5`d9#MhPrPjZD7I z5>l0sty}KS>U8Y}IgjQh9UQNH*Mhx6ynS`AZqQaIr;Js%AXt!4_Qc+!&X2=oT*gSpiojexJZMk~bkm{8~Nr6nr z6BDG>HPR(y!91+$b>KQ-fkhn++84+$kIxaTrzv7WneI)IQc^J+Xf0825xN zQ0E+{ELP1MX_AfG@!piwZmu3K4!D#8&p@~b)m65@d8^*1#DG&h(rC`E$c=eo)@xn% 
zH{GlNF5dEPppsa~N9EE>$v$%<3CNu2y;OW*vExYMHbx^o0GeupO)g_~I-LMkvDF{8 za`uq@pv8t~4Uq2f(d?1sVI}=3+e@uoEh?fnbx+5m$gen_wxl;JFGmd5u-Wjk`(Wb~ z$)Hp4F+w8xbg)Y0Dmh`5K42odrEQbT>O`P68?N%4ciNc<&7F&BhETqvUReO^ji`bm zWY<6{+~+b~Tjzua=A&jEG8e4|d3Yrf<6~x5+gQ5uwn5%zeL23uR*W`s35R06E@;^K zv$H~$wx!XpL}bQu$`m4XCnfR{?qummZdL?k1&6fnadr~8YJrq#a|bMYw{8VjKQ(W{ z(3{xKQnVTomjhSCjOXu$n8ahhm@so01XK~GRw^*;5oAzd{ytkuYCr8FY;IS;?LBhC zrYs0>XT%^{OT0CKlJMR&+BPu2fYnr7n~Lv6RLiENT|GhFwdi;W zAj(;VX_dRjCMT$Lj>8|f{yrA{Q=w^&5>aSr zBg5~TT|-{=Q$hmTTlWYy4Ji&G#YI^U%M|GIE0ZussCesRh$nO2ISB{9P2qY&W!YOi zC>Ia2LA%RBp}G!Y89ga$z4Ye8pcVokF}sY;cl>coAig7*5PvEFMStC76*3!M{@C}- z`P(iH3igsX7xg*Lq_T+AO9M#fa?_lARw_K~GSe`=941@ZswVnis29CXOOK6Od^tB0 zyIWplE5>)p1bNdKkm_bOT+Q!(brh=I<}-D?3JIef;&mBMM4i^rI+Da7M^Dg zqY0_<=eq>x99=^RYSIBa@V*?t-)Nci;3@o)`VCVQ8Iork#l;- z6b;TKheT2-OZ{p|sS4Inp{3vm&Z zSdVK<5zS?1Xy!E*kmq}uOW6&3N3~IK3rJbHt*L+i#_}Iqy6WPBmXuhY9Ip!+2YwW_ z-&)^<|LSp$BXWO<5AL=q1r_zyAlqnVUF;Z*vVWD#@SAvK+;{154J!c?-Bz8 zgV`(!++{M1wRZ5QnHDLvqWf7{S(};|nkMPfObTYtnQ#>JX@x)_9H)dansMO%Bfo~z zI$z6L4yxm|xqGxta$W+q9=%F#!$L{B%Wp9F#K;^A#=Z<8)rCbvgyuZ(nI_ENZ2IH{vQ<+5*6>=dL{t>t)gKz+IRPhq0arSgPEZ}@y3kq z2lGxx?zUsf&FA+YZp?wX4Z)Lyq^B5GZ;nqm~tJ65m0j*gzjr6?#k z90;cU%rBHU@J-VDeje=L;{6ZL@C4Q;bOnygWmFj;+cl~BXFkg`@*YFyLY_3 zv#G$L;i~aniU$|3wp~9(UR#snJ(5q2aoo!#Yp2~j0MxWsSX@a}HH9bQJ`A|iK;vZj zpGRDz?$Pim?GaGKimlJaOj~YH`rrR7B;G>HqQn9>D5?EmLhE56pxC74)zw)0=D2Ni!CvgBZR{ujCA2fDs0jZEHJ&a>gY$a5=0P zArNTjF>RrbFf?RQ(;qLL1HUp{W^GmjDgb&&6r`&^I*ux|{Aydz%(0rX^Iuy4}L2_9n3h3&-`}!`rujTjIYP`a!m>AitJ?>&6r-rxUqze&caSXNI)&HIfWptS>y{?OdqT+~PM$6DYTJ+#Y`6u$P@<|5IJj<=U+#I-e(fwNOzNdUGX$Q~U)fKLN!P>N3}LwfKaW!HrPt-xwM(DfG7rzOzyGE$NqgS^s$3|Cs#$&(jWVXA+&Q zO>yiu_XC_GM;t<#3jA-d>cw+DC8l4NFPm9!^~gAnnPU#k1;v`y}Rhi(~l3 z@xA^Gn}S3h>95) z1b=@)efM4|_+67-=vwXO7pVk2OG``7>6S)vzKz!QpcyJ;dzp$>00hzqASNY*{zIVb zhu?otoPU_4|7|#{W}mp}f3%-(KwMofIC&AS2kgFV+0%ugxTPa8Nz~zbbnQSviMYD& z+1uRau>A!QhVF5N>wZ;TT?E4%s&pDEvDmc5Y}*2P(gzGiSH6jy8_ZlPjLE*ba>;S` zw&zq+t*gGc4!mIxY7#EY$UH^76KXc@|Oj_EIlT30@bEI|`GD6A3FEylw&eO+i&3PJNRTTN$=pnxs==AS1g2 zBy&?cZ*+Ix>?m+an3Jt8$@|d($oodCoQ)=tT=wxS*xbTxu*C;m1x!jxa?9$afO`S* zUvLvtDuX06^u*br;CVz0aMC1gSfk7=bEPtY7YzS6U^o=KTwLe`tUXXL+CtCEKU|*e zmwfab0LA@-#{E|r!I*D}QK_k4aEZBnD_5+?lk=}w`(L}hpN`2Mgb+e$Bp=vZs8jdr zb3ms%g>uq&SeLj^@aX$t2>#SAy`_deO1FD3o(ywi)2$Tyc0KNE_}<=6);^v2sPw>Y zT)aY1Y4H2f&KT_0r_A=#>2}zI;yhl6vEXklD-l z_`=nk?%U0mG-$J>OmtXg$^Zu{Gmb8ytc=j8wp}*plC_l}>V2lzhChLBVei;Xdrjf4 zRASIOmIsFBzrMdoVq<_%le^g4%$N?ngoH%*q65uODFz_{Vs zC(WV!Pw;WzK532gxCGslZOYmgn+K$CLxV2m0{qT%N$R~T93VzJI$cptK;@Bloe>LP13X= z4%(?}EluZ38=IswH)eacGgQV$5|u~pcs0k%>Yjc@`1B-TW0eRG+18fRpTmb0nz z#05^*`V9}D--lth`o?(hCuwQ=ZkXE2((b}b5{|r8gt@>8*T8q9SQ^+CdJLFCepf4^ zDzrUQXuBSO{F%RmW#oNt=qFykRP`d3r;c}Xll;Waw)%}{n>{ntN{GGyfyXy1h1fDu zQ)AflWKRbE3ApIN56CRdWze&?$uN2yw(TNehLe=DX2j?-WcCTQGL%+AI<=`VWy@8C znME(t0h5D{ojp7){S&VtZEXVXtOLpkh>+X|>&h5fJo6)!qgao?C%i}t|NLX?vUOeE zZEcE%UDZ6Aq6y_Ic5Q70O08vq^-#L@Mm21iY$uP_KXvp+mgunUI3zv-x?XeK;|x}^>G$-flVVbx zM$#Fv=*s%8X;gN@(m)J^jfjX^>0fc0%gQ3$FNp^nP1V3YMf_C{YPV)?W5fLWh2eE_ z&j`YelrMNkkFp35s0jgBZ=G#==z^LuvPaL+-%}!iZdj*3S0?7~rc^t5a<`y&8%yV9 z^^XVazR@K|(_dn9-TaJ{?C$v&ls%1!pSWn0RO1g*LgXr^M=#i&caT}@!;bEfc?co5 zs9$>GI;ItBfEMRlOoMZZN|o+z?M=J3J_&_klqb>Bfevp{<%!t1`apxOuH`l7d3f9;q_)2>MzKB0Nz0=*Gn~G)N;~|ldk*Y`gt!N5pU_qf)OTEV zq6H}%t|aGjcIcz5$ELa&=XrDejGQ+_2q&^~JZ}GaxG{=DW-rI8%@Gn|?QtB535c(u zp_09ofA14#HmZ#=Iv24kq^TJxu$>$cPwGCE>E>2uUte5=f3#HP|H(g>PqnjPUmuf?1?oabT*HBE5c-EgMCLma?fM4exo>udCI$x{Sw-!Ob_!N(pVJF^$3s@K82Atjot+1%f1- z?|dkDw0>4uhDWA#TklIxppq}tA_QgYqpvrVcq?K3l66y)99{d3zKX_4 z0ngN+A~sfiP(W5uV`_&B2((GRJYw|H5Jl9`K#)-KQ66Bae12JC(S`{ 
zd*NkrR_5ubHt?4IzhB^Gd$Im#TA4Zg=R2p$1O)tBDt zvskAcl(tXT!nO~rJ=CeQ>YQoll7-K=sZhvFY?wGkDLqj9Rh@uPpJ1$(KW+=b_W&A` zUFns{x)cAO_$^WwVUIMEAtdxeDFl^RA$x;tpLeGmcA}&}U)56VPv9Pu#Thf`dfdle zO7M>|Hk8%qsT^C|J1}~j4j2342RwT0tvX~5?aJ({6EVK{k;c|CzJT4T%3qvZzX=M<>cP1S(Eio~i z$$W9#Lw8Jatv=gn*Yn%w^%!rVAA#p$vH*+G}(~d6i49HB7A0C@)BD$B=g$YD5#~a5?)R0 z(M*gNVRwG)6+bq25t}^fpygY=(6m0vE*`@lNL8B@Zodkh55BI0Ftv~UyPK$oOg?P+ zefcgt^E)jxZ@b4fw|?(?7{p&|_*6M1lwCiesl5B!_P0y9v_k#rd(FNJw1Wmt{^-JE zvc3@AQNd7wb2q||rWp3IMDs}1K!UpwyhZ$GcdVZm!ct}R@0>d-B~WiT0bC3B@!XYc z;1!!2)I}m(rJ3dcM4)0Kv;(4v|DlfEv%qGjHR)*A=gY`z$_u8xnrf!+a zWGV76;9Un%L|PlC1|n+-HiJey^NGTE`YATFA`#OhGY$sJ+xRs+geg6p64tI4ZtFor zpu%%(8g6Ex+~%fdbq>W5HrKKUnOdyJN(MP@>Kts6G9SU;*%ZI6t7|hT#Jit&-qSa5 zDpDJ2@u%+8@fo99T$7Bq3@M0jIfv_2R9oW>S#g~+i@Y&y3nEl{f$P{6?K4xfyof@7 zZSI$wsvYz$Vaw6Wbu-dOW#h0WbrOtjj)>3W-_}Qc0drqeX}VE*e6gvrt;03;tz;VP zdcL)(@QSXo#I7tn>z6mXSKyxbG`^d6ft4VT+fN#F35VG*-)E1~^h_~TQZ^%5mqG9_ zwSpYmuP1~nWlL_oQB+c!I~IA!TGYr=c<4i(kN@nSCgr$fLs}s;3Y&W>h!Qt_$g$$` zX>Eph?5Wx={d0?GX}X??+i&P!jZ!S0zZ=?Bkq>Hdctfb4+hK|N6NfJ|T96JQTeHs6 z_6Mu?((6VW3tY7lDfKzR+lNwvTQ1dMT2-dYc%6Uwd)7o(kP+x2yxW-&4EicqXVG27 zfH7iYf;P|wXBsp5xJ6jC0dOm^B8`m=Qi$tl%MEE^y6(7|%C&YG=?4^aC!P+=m3M!V zvAJ9~>bGd_|Miq67w&!CMdFTDvWe}4kQ^;iXfeHQw6Pxg$mK!H6&uc3Oo6jkZ;M#R zk|?Y}S!GH)V5$d8o&>^S+uOa+?V8spHHo{jMzSqncu*FV|Kl6;Dd2s(_J! zy<$JS5?LoX6fJw)mAjXHTIat1xmc}1p>h)L*V`jSL7mvK{`0|>W@W^CgS1Ce=&r?( z8sg()lS3j#7=mE^9X_m~K_Ncu0-uxWVBeLmcA_4YlMnvn+YQ|-5_cANq4H`I*kghe z%w=acrO9ClRz+nmZ04vJVY5SP(M}7aI{s8L4#0H&Jl9A-EeS4$MhbhMV0Q$I2PK6q zCnmdm8J()v>vHT9mY8Sd6>1))L2_w8n0cJ(NhNRT^rWy{O|@^K`<|)?Iz~i}(p_g$ z)4q=?s;q?J8I$Q=1a?-PZ}8=0lZkzV>ZdaR4314Edn~~v7~^U@SA<^Rfn|iXTadns z1Kff*l0dsZ&Xv=ju-rzlLma|ZtZtxOoOh=7uycTD-gI(99-@nYl)3eY)iUK*C*NnR1FZYtHR;>k-&NuZSy zJh^Me);jm_PCs_e5N2QenaBn<+Z21E>nvQmeGMSUtbWeW-+rr&Ef0 zF$>GdWsL4O;%GR`D-YSH2nFrn7 z4Sc84kv+^F^P8L74a+LmN|ETTztt7lw6yC`uRv^fo6{Z3sx2#NL4F%v_ayHaKTzTH za>}F}hp`jz~Tm z9DgYQj%Q%FZ2c8+Izar)r*Eol^!B?peH>8Uw9F#ZSpV$dv&T)&ZQGC~DBZC|CRg7EbkH14Ar8s#Y?%Mnw9 zwLb|G7@gawB^)Bjq{t&Zmz_OD*o%5^{m+z@@+BH>FjbY42Ol=8|1FMPjqF5 z1g)Y&*5Aw)-VtAq(hho>FeM<+Wo3*|LNv&T>Y59`X>TQP^WLZKe3e77vgxiG>s2rR z=fKyOMPUg-j#nxC1_#&%F*W+8(v~Bck`1qA@8ez&JNQTa2tB)W0ic^Y%x{+MLg9jz!=?2u@W_srm+aU#wYC=VJ#sN zHO*Ro6<--dnkYZNJYs*U&vMvs4cAU!?Ql*<7c@9qF2+Ju`=kI@1+e$;dWX-vYXVlQ z88h!ne&)Z6pQl9-bgwwnB-N>p@*8Clxt!CYqR~6>eO;XhYt&Vc zgIQZFzgzZSAbG)d>jBw$f!hNp+ne0*_4i=E&Rphv;q99O!lI*fBl$dFUP(P{_Z~U_eAMqEe;v`AhB^a(G!tg(wrm zvY9^b&^(4S!WG5rkrd{}Ju2gnbUeqJU{t%qVPL>`ZH66uZw{OwfFR!lH2(QvTWFjv zGfy|V3yf?bt5*X%gWv!X)R8P_JIi5FDl;qKe)$O(2Xaz~IHRTeFA%`_`k{t<<3zU2 z-kdH60<_J#(bu_U^89^#R>x8|%{l3DUo8rDXXlBy6JXDHfq5Y3# z@&Q{}be_f^0eas*8`A&TWdG082L9__Z}Q{D+ySA9pgJvm!QW9KZCB^0axMJoV(4sW--x`-Hrv@D zOiH2ttww5qX3|@x5yGr8B~z{M?V#3^bg%ScM(P!!l8cih7wRAx|HC-6s`IqYuk4Yv zzEYc5f+up;h*+EgMjW}pJH$!|Ph)XAkWF`6lbO2s^XI>b2ggS9@BiI$9LWFWN6L4UO08?zLp18wj-{r<Kf%9QKr@ zo#|4Z`@fNgw+a@cEigYI3jf=CBq1C%4)-5!J8&5c?A`@J;Xlve_~(Hn|L@w*|H^>= zl@%nAd&N{OnMG&Vu-(>{nt5l(;)oKz>kj{3`H0b|-W@?0))YV*U%Ch^c04LuRU~2a ziR9(wXn@NWZZ(hj*ZRkPnvoBg@&6m!mJS6NVjR-3U@~|vZZkfpCC^oYBEa1cm}4Nh z3;-CE15%SeXWsSv%>S#nxoX>W3qf)7`m3kJ-^$uIDRI7oIqv~V8Q`c#F-Rl;GDaDl z0*K3BR>x~gi?XsMGy_j6zZ>6||K9xqXMW3xXS6J8>#q#Lm(6mb1>iG00K};BH^j*0 z5V@?}KzGg*h83jGS7h)nwy~Q3?so69Z@VHhdw+A0z9p(-E#D)>qDAkO@5_4k^i@1Z z|AJ4>M%^~Kzw`B_j(JVEI}jKFQuF>Vd;n;9L)KqoX76vtvHGuT4@o^q&w1Y6*>}z@ z7No$te>C>ypA>B&$w6got|{+3z1L2qQ#=p4h=I(0oZ!Xnwom(Mq)_>sJ&cs}H2R1* z@OTKIeDkJTPBOox8kc~J?*6FmeIU{L;f>$RzkQ%BD&^8!=7Vw@iyF|6a%ykCZ;Nsf zR+|8+CC&h6FG>9YSnU60C3B$x$mej|`J5And0zgf>qS#BF>MytX;0t}$Ewy#Oi8jBzQ6E?>20&+?!t!2-_!G=m* 
z{E<_VfD~Fp9#}W4l1SfQ?GyK#Z31EYyv7B#*i53KWE3!0?Eop~+4^)GD--* z9Frh3o@uG6uO^;0%K`HmI+yp~d@yi30{R2J?@`O4h)bM4%&zLc*x^)W&9mh#AUp0X ztwf6&CCf&Ua4A5qe6zW^NoAb6^k5%{ju%%ykzXIbFcv7UXQu5=E3vV8U|Gc1kLhi@$Pm5T=AICVciG9G~ z!o+`W|Co;d*=g~2IPpAn-?u+maJYJQD*2DeVv;6ej8s=>Cxa@=>C=aNZ?&tzb^~u7T#j|7MI_|B!j8t9jm=OM%w>N7gR*CHcO&Joj%&cXDn3*;^XSpi*gV&nch2SNkK zV^&UT6c!_5n@K_>2b-p7VH!$O;M@9%#{KtiTU8j^em52nxSw~jaX+vK_Zxwc0zZEQ zUNZb~yqbAf^vCh2!vDT~|F)VO91hR(HY`DPqt`d|le5aZSnB+HuFm(DTl_YpqPW82 zlNdBPlirYF^PAlKb2hWTMc9o?V=lEUy&HRce53kz0p)b}=o_vFTP3KeUHe;$m?bY!_g&Elw}3rsfFen_GMVI!>ALkh_a)rlfk=&7}C(?%}phlLsTwT3tm9`$ouC?f>rQ$Erom85X4 zOT7u6i~YkX22_^jYa{~Ew~nxsA-^>eam8X#2}1!01a=J)rTaKW{$3fgfP5!d6IO8| zeYA4Mw7xx?6DQ0WJW6g_XjvoJii(mB*=2<{Qa7#&2^;sQG!eEB zX!CdOtYKI02pPc=^-KNM0>W+_t0NR@)o}bF#3X9=Yyz3LW(*6<6f#nmQR*m3N*@@k zv?~kyK2By-6TF=0?4Y%GmPrTASr^oe;!X#RD6ORw! zFE^N;9IZL+gSK$SPiX6JgoG{NdvJ1bFR^h{X}b2NJ3d2HF774hV_f>eD+sq5 zb=j%h71FR>7w&rus}mmd_Sja^vC7p|hwE@j1>$FC z4TaB&ibyGFh1u zdja(r^j=OL(W!ZwJ>^Zjtt~s$?9LwgTox^l+U72NUWBuU;;4}*Y5irnG~JMF6||_% zF-J?U%bvmI0UJsx(SyCTXdI;>TE66aZ}rm$S6|J%S9ilx>uH^2F79QHNSY+e55QHJvj-<6-_ofUx&eCC#s3H9_I@?9kFreE;0~y9PBh=J#)rs ztVtVBU!8>_Ob(7n=c9ZfG>(R#z}bYuqDMnQ!eKe=zs0amjr2QRy+~)~|GF`R(09}g zh$K9}iNRdvwnY&9-+to3g*A|v$H_>c!jb_Uv|!mx-ot2aMW}-y66l{6l}0xZ9tbKI zZO`ovC|9mZ&n$F0g!Y^R@7_V=Xo=p7zwL(-j3T%Xc?EMFg)3!Sl{~^`h#0cxq2`ro zB}lO%x+9I)QD&5}dNV9I2d!I@dGY;f3e;2aN!Iv=E`~RC$!~c>(LcW)cWKwTE#1jJ zmgk#;_KA`y#-Eab5M>JwWuQHNzu_7|{KVPXf|tM?M<)Wqt>8Jk8QtI*Q}8_l@dE9U z$aFW-=A&@qhFZ&o#1aUhnE2#kk$8VRZ`mtuemj z7{u?h=J(%<#qJ~-3(@;hHZJ=Losd>?<*F0cKB3hl$a zfaHoC*T0UhOaM;hUN}^Bx|Uk2d1Y6NlDp#j4w~?v)ui@HWF3jVnV@fxQ~h_ zRvzt$YBy2sDvyh_hZUWB?3S@R&@#3!{&pmwC**@=5=;-tM-CBw3gXN;H2-QPI;YO zh|({~u5Y{=JpoUU|H`GwSj!mMd*Oxt0CjrnuFzUKEVNyogwCvNJvZl1QtqZ$Fx+$< zb(aR*JU<0?_}vcDoHrCVNZ0rco_AOiR+khUI@M;PP&8)5f+w$~A~7EGJIBpKj!F9) zJqaycu~>ZbIDB%Se_wT`JjQTTGMm+(AtVE+9>Ie#{(j2kMw(8kYC|0;ucN)dHV}ul z=nh~mli+RI>V;lg;dh_ze7l*meMmMMH)JtZUgr7kwf7O30Xr|1p=QA{?nhML6C&o3 zWmuw-mG5O85Vu47Y|(+(<)y+PUuV`K7rExy*N%%U7 zlc9Dm(`j=Km59yhFyhDouDCV~vg&z-@jG?QZVI)N6%SreDgR!gv&zSx#n!ogip0^? z1un!2#_5lPIP0|FUqNY47g9G{V4133U-cEre!DA{O|D(;_wl7*i6Sjm?{fNfs?Bci z7RlSR?*F=p?+7gqnCW9->k(C`#}<}~YTjpjv0m5YpR|9UNvt+_B^IB1qKs-Oo%k~O zMbSB54jvBkz+paya;d4Iho=pK3YNMaTpA+aTizK9Ew`F}rb*sE_A>f>8H8f^htrs~ zSGLB5Gf-Pks0D~-c$wtM8E`Ie{p#J7|E=>Qb^q&b9UIUfxD9lNA!`_QcK4Kd^gF%3 z6W*-&b$%k=<^EALe*(sGmoDXzRmq)v0>nxGvq!6Er?UCM9DE_v_1;-!Cn2uPS zE6~$XZT{v{qc=4nolsAQNrAAT=4%k5f#awSj}Nt3$IIg6O>vo1-nvBWAF%s*YXLB8 zKa%FX)q%90BFEI!QttQah28QfoObTL&E<^CiE>%H9XSpU!;8pSZSghsNR+Fb(Xoc6 zjHB&TJsM;E!ggblxCz>=?PX=-=NAS<9CT39FH-sqFG0l$Y3o{jr5Prdb_Rn&tkI`C z2-7_sOOoXq?56&Q17TIj(u=-HvQE?89X7_SDZk(-3yL^PloKzzy_%NfG||Dr&A6+^c=2Gw}McFbl1@=HQK3s*G|HWGrVw-N1&3`r@~Z z4hlRvEIRrxqPFs70q5<_J;A(eqGJhjTXQ zyrs}|R>fvP5)HPH#O;#3X`1V~TESSqF)N+6EV0p`!XOkH)D%1S*O0Y)tum0LZAf6C zZ&mOg&8}Im@YiU@egsbGb6(H4{k5h34vWuf^crHG6|q9cziacm?g&jbJ@e!?7pt24 z_s)3@^y_kJ)BAK}MD0N|7Mb8v+8~u=Re4sm^fY+Fy|B+ec``K?SnNt zc(`|I5$YQbrRB+ar|hkoxLuvp-gUwWRS^*(+I(=?^0UXAM!eaui^1i^^tNPV zba#wAyG`vhw zs6Zx+x*B$aji{IEi;PKLBfmi83v;>8NLBXDA|`uwt3e(kh3O>CnBz)iw$MsX-R+(< zr}g+!=qno45YOxBZGN0w`F)a6H91OA%x!5)! 
zVscV*XYMXz?}A;(=%{3tjcT=6SNKj8k#KjG?)_#c3PZz!5amR@>h{ zSr*#;AX=1y;RGd6j!=kMiT zR5mt&{cw@yE^>K!xUBSOuy#=IKoxgadTqan>qb`w)5}(L@mzs^`yfDx813J zNOd@j9Z#m3^y0l1#z(6rFRvO*N(daD*jLh&aG8Oe1}wm)76IdK?!1+jJ%f%y^+eur z(R5)RK_k+u(udv`kb~?(Tr?Q_C20D(fh=6N&cydugI@gB9NqAtM(7*P%WD}7# z|K^?l<%O+2&8+sK+c1b5VrR2KedA+2>P7=$Z7h-YzQyD`Hs-QLt=4__Z=iJ0a0kh+riXB0mP0u#Cdqc|1aIhLK`!;wzwfxHVh(WHOo7b`r()P0N-Yru=!2j3+ z4;(gtgw~jx0^9BEZ=;yjz$dy6c87cXQR82yf6l1#+&N0I&g1+|R~AL8?f@Wg4Np1@1qtVmIY? zNE^3-PMjO9 zykKpWknI#z`kUdbu(zG$=C>g_d;#??;*226~o zP9v!|-q}SB(>Wp~hy2G}FH1*!&ybCt>fNLl`JW3}lRYPcI<+^Yu0S#MsBK$y78$Wc zE}#l4!|14((Y)YtwewY#pSHJBF0lp&;cc3s`Y1_I9r+Blj*hI8wWAm&gdhi3=yksO zFL#ZyHb>QWUWq{-iU%OSH>veo72q-KL+`5L{O;OA=0Chhv9`Ccn>KQlnz*rtT{HT6Z& z;NjW>Tw|x%!B`tcc=vbjCj*Q?7oXvknI0v@m1a%X8dktaOuSC`*U^_wNa}b!nL(tS zvz4ihYdVQ8R?uIxKSGmMXl5mi4v8CB)I8+P-a6h3(ll-Jt7U!uTxO*k@h46Q93N@K zCTx;e@In*L$kCxigXNkR2_vr7cB)((h@};AFX?O>E%4zXopOQu(Ljb^R2c^ z>$*ha|%W?A)KDbvVb*Tl6wp(HhU&Y}v{IUf<6#T-%AsS3D!NFdz&QTXS~19ijV z2G{qwdPJhj5z6zARKD5^VJegkc+^ax{C5~m7FAE1;@kA!@TOWA3QlfqVBzc#&;BMI znf-YG1d77yzyOJSVlW~Cj+QHCM_|0(+`I)61}hD@S|;BJ*@2Z6R(-l#diw=x=O$@1 zvazN@#a-S~gnIQ9&mrGx{Q;X1iR^B?p&m=A#v}}h&sXcB@K@f%#3b#T7NyajXmo4i$gl~pMYFgkR zZ8Dg}^<9m&dB1nk&M*AE2$WOWRXM*2D90WaebsPhZA{}ys(gBhuVM5d|77YCpCXa% zS8maFaWrg`QnmJ=xK!G~+PGo5CKt_W9R8g2k%4kF!6%yReKVxSW%eFc!DlX(7Z-XB zm#fT&O)nZr#DL0Y90d85mBug!-O`E6=I&X2*)v$*eJw4k4U7yB$gI~`+HHuycwiK` z6c^3T7@tr?ZjTMmTC%A~IfE!u1ig<}pXV|*mo%b&>&&qdG?L>I7@&x@+;`5{E@HvEb{H* zLz*9`=v-D@dVwNmmPFQb8s}{kL zh))-@U7~qSY-CAfXiC*W7!i&*nE2R|hT!Tg4_12hUV(a)t_ZpLb%E~L>h@5fSyF@qj$-L4%TORT;XZ1FB;+ST z27=tDN&AV+g=O4KRCtt2`%}h5c*0zNr`oy{RO!v^;||07iAAg|6&@)&NH@Zt9a|$^ zDFbdv{ZQw@BuJlb%qGsfqMuJWmsqFi!{-(bTc#TEAs=zXG)8nX?|Bk*R8sof z*byK9LED<5&n@o~Z&Kc)T1>;mLnBd~9o!H>UTnyNRZ|sk-K1`Y-Spc>!h9dxZajgg z+O8ncKFBoCPnpqp3vq-(Q7Y|xueu6k6xIuc`a|CzBE!aGo+sUOBA;+PZEEI$fB(^G zr1KX6SEUnjf8INJIK9(AGx=QiUoqZYA2yi|GLh9q4ZQ38^*`4c;j7F4dhcnuyjZRw zO+h-Xxr~Bpi(8ah5u}^W&2xux!{!u`xx0aFrp%^1q{x+?zV<=X`>oy-FS0To72;Gr z+Fy77cD`!@V}_IyxtMX}udEU>@2`>iq$~pJyt9&>&>uYHjbU#dLzWS->fB}|w;x+| zYOqj0zwW?Yn@Ha}1fu*!vBC=#mU-~7+W2J7OpUi`2i>NPrB9tIUqAJ8Bac6=s)4BDVAM(OA_mwAJx-JR3&_XY{A! 
zR&rjko*iPTuT9DlE@5d`<~g_QKCeTX}xTB!7!3yUt=T%TX&KfU24v^XF+M2XlL z46TqJ%H4$vTe&P8>H*R#93}ph0Utvq$VDeCML$kV>dhq4CwFsyuLIJ{y~w)&zcKIz z2>*hEZ~LMHChCQCz=;&PR;C1ztS`PF=^vM(;OI7Hm!5liKH#HI|B;6dNT1b_NJcf` zw8+?iOjhp*fx}WVYg1oy#;!p-rO{xSvxg3pO>SKB48h!MUG;3ngJVZkiKau~tne+b zbz>C*n?x>FdQ_U<5^XYraLy)>GKRnnsHj>9qe792j++y7s;agj%r4N>*aji@j_?MhCE!a&+d~HKpgm{Qb-6U04Ln-IAXB~LGvx6(uU!l_ zvDhDnt3?8tnA_qry+vY31QCodJgSl$V%)ZEfk`1Cc+{M(i~7fmfHDJC&1tle_T=q~Zoi>KuvfNX2jhdo++BZ)u z*Q#JkQtFc9+xkH0y8a`~6tZOk4qR8EpbX0ca`Xo%HS%E#dUD1uDiKO_7tnehZ&H~( z^t<28^6q&jX~#$hC$5yam6Lu5Pzm_GLDW^+2~y)Tq5EM}W@P_}J98;hjempdy4P4- zP#p^Mg&Bw&W!!T!3;Z)^IWl65XsQ4N8fg>v5dtU0aY&*(`_J5iJxe=h>haB?|5907Fk~|G5g~Q}<$bw|SQB%Ad24 zqmp~m$vP&qsZrm}j9fhtlU|!h!aRJPFwwqq(fu({>b6`RwMUtTwPWA$zlXVXoh%0_ zNG0$yv>Q{P%6!866GYy+$@bQBdem+)Y^Hf*bHQNrv4|;*J38AAtu2f080_?81Z_4hV|U16dGwZv>B&oe zgXN0+^*TOtb6B~#j9xgHz)G2TgYiYkv_x6H@<;CzC=;8wc>cgj=b8DxKl&vOR$BDP zt>w{Hmg~g)!uW7w5b&MqeK#By;@q=~HMDqJI70^bGI$R6bXx~3{eHOvNzpJdi6^IR*3K1v zc=KU?uIv0_8_@+zWH*=L?B}Wa@J~d=bf~qomd!v!a&qGme3wz_sX14SNXAK9E@qk| zln>Mm${7GjX5Rr~5lz3zJhiBe*a_eks#GIjMG3?VgI40Cb1RRw;#K1jtl4MboOJK6 zPf8J{8Ei7gMr znq5U>$I26mVF2A4)+3$;^G&94KD4I7)uhh7PZGmzVA7bfPSl{_B1F_M7>)lcYLAD7 z`cpQi;?jN#pL}UI(PU(Ig%)Zf_Fz)4fq`aK5ZDjo8R46c3($5p#8Chp^Dt*-s7_VH z6z;oy$uslVSK2r0!GzgFqXFq!|B)jG9O?MUc_;X)afKLjdb>VnCgP4}x_|agkv>H@ zUvzn@GAVy=;Jm7+=~$|N9Je-MpC2KsMl-#SX-vZqnV6x*urY>fur|e|~ zVUbj{s!xkeT&P#KHH;~^XIvC68|4WNnYsCXf6ZY|IMranh=cMJf#y1q6K5oP6GR@c=GTKF@65&$RcZripS2DytUK*m-I-hW80>AL!K z*!je$8ycngINMN0o@%Fdj`YX8p+YAXJsu63*qok5|F*-5noW&TOi^)4p<6Ax$|f^B z53rV&9~$24ppPOEc>W(?uyn%mf-GesG+gXhGH)WV?q(jGzZYe20QNCqYb8_BXDmhz z8nUx1`vcN*?jCHuMkNfbKDlY(Uow4gS$rrw(Mac5Oe^HFZC`Q1mQ4xv;YKca_sju7 zP*n71n-8{41D^OSI&FM%Tq4}pa4u2zzOO?wsgXSlCY+F(2r!=L&o0H1)=FN#X^@Mi6rp7_!?wUY$zBJ_xSjxj=<=nW;u3z zF#^~0^(IVa#{)=GKF8*Ig!WXcykg3K+H@(udu>i zbLshBF&UpVP(MAN%g)Am6?Q1?8+~K~9u{|6lt~;s6ptl+RchP4b7o5pjN7sdHKn>s zT@m<0<-bs2KL7uRF6IyDP?nQO0G9JJ=J@t4d3pIwnG2IgCbCX9iDo98t0~X~(3_tE zIGAYz)?;b^!&Cc?UyH45j2|DFOoh1~h|nkd4|X>d7NxSC-HaadxmlmGdW#r{qW?;idFb?V2sM`Ky^C^y!q91VvIk8b=ry; zMNLmu2q11;BK>yeKp1Rqi)jYKV^&&7VdMOnWd0ddU;t`hI%L}yv)Kmgm=<4TMzNXK z@45zVeG`PfV*Wx3{MWLF8U!MT-Au)zzN~#Yr22gfK<|D_nv3M1VVH0G!Lb!S%#?Sq zohDd&-To)%PK+K?t=kDw{DfX(n3{Ee{(NC)Y;PO8%K1e=V98@S7|ILIJ6w{Jle4y- z3cPfcBk|7P;Pkl?2`Gh9<>;qjdK~fX#^b^T)X?cz4FsjM=xH2iwZqh=y%M z?-&ToicP3=vXY^4qL7U{7`6{A0TDUxgkpc$p^fdBy8@lvfyD1u(ngGC(p7FOPl-1z z73*dNcj%#1K1(MY<{UFz`6{2M#;Mz_+!Lo8@tj=^bTVXsR(HiCkb@I6CH_d=HZt$q z;?63a?z3;S>QQ*Yw;1To-Gza&e9iUs^~nm2f3F|pgzIw1K$5gh$~)kPB>s(dF}Cun z9D0EOe=vF@clF=PsM4r}tH}`J*nV_;2~2GDzm`+328OsIk;vgl-pH~YcHIue{`H?R zygRfUwW@>Fh?+)5DMaluu|_5P!dLa>_gYUXiLm&M3J|NB&8EeXS{4Yq36ezrk`DzoOL{ie z?WFFVFUfnhQ@6+4yAEUTwXSuUsWvHfh0Q>iR4e;hoKA2pZ% zrm*mAN!~5e*4&3>9(IqnECyW`3+!|7oUE{ZHo*~SOHoO9sbn{$W&K70k2&7~*k(Ex6y?`DcVxNa>Jn2;5|U3VT3UTFN&r$@Hy{@!Je@%>2KBb<)u zf-7UYi%57#NXWIzB=lmnLezF8)&RfXUIN%a)Jvoe->I57{j*<@W!^T824kBY;ebUw zJy&qXGaKch*tiWS1U&QRH9(kbwtk~**sbs=JTgk<?@0T{M;_~5~Vjkda-HjK!fyFcaRQ)7JY?%r0(o@1;a$>XOXv3Gkuo^@C>{NukL z;7{-E;PXnF4wr=Qy4;UG`oqY4K8l{niCMYT$fR&9NWZFK|7m$?cXxN0jR3wqA$+M` zaP{ctouA0%JHNk7{-U%srBf%F4B7I62&zH<-Qb5BTkQ(8c>kOD^`8KIv%l`#`9lu7 zvy52;kolj0d>Alr`;*;CKzrEUJk}Tng8}Vzexml3pF%gELtD%ghRQrGU4MC%@l)7S zE~vj%!c}ElTi*Vm41lK%{(1UAwaMV_QVdj6O7zvgUr^c*BT83#4b%fxPyGN>-!~51 z+vQE?4H$9z;r;>p+5xm2FaYrZU!C$-tzL$AApgutEM+NLwLRi;!h8NImQVk=<2PkK zul@TjAZV9H08N1U{Z7O`U+^b(Nf3J&3n+r;bQ(;(QCd!}V_LxC^go@hfcIF~cC$mg zO*$5$x-S#(Xni+)bFh2cfDHZBTWGU9!imxW9^o{i$F@gFDc1v=j@vcv^2Wd}E-rz+Hg>?S zi#I}Z8{e5{T;HaI3*AI@=&}<3Z%G8_=H>=+2Rx@M5_~pSCeFisQ 
zcPIcD9{9}v&P@X@|DW*oAE*Bh==cBJ{%=VE|CyHmZ%#|w!8Fp1sa*~hD+9skPUx&0 zJWI#^G>k`~s~XVfb~mq?n3zl`U{@62yIbqPuKk`9BNva|h{48=Z3u&Zc{xL%$|ZnY z!|VAaH=811h6fh*TGxNLZppnL>{7S(TW16DIyMnVC{e{M4UhbE9crsmRAF)~7AtS1 z@1qkd8@;Iba^I9$tWFc*hYl=|0F3M*7x_ow7g`^M-}deQbRq(cmWQO$ob|0tf0Jx^ z*H-pk_h0jSc15ozcCI%{s%9P$6qRcStZOPP3gxXi^}};7C@lP8#fB;o$@>LG8rhX& za^S7bm=(vf#(`s?Y5IyToD=F!=}k{Nz(pRCq|cX zNI*r!?*Trk&(1n*FLXGn7WT_otDj2jLvhX7xw$z&t6U?nVSs;*nGnN&^(puhaE}St zJ%0E^oWcZYwWG}k#9WzdgJUox>(QD}pW$G-QQrB_nBSM*>5Bhwvya0EO5JmO06go} zw8Mz_RXP-!4MhC0W33*xfbVwEq|_Usy^GQQsA_vNkn+mh)4kgzPwy=4;XnwCdt}L} z4q#Aj94^qrZslX8=!UB|ypMHQ?lfBd82SW7Rc=3+zZ8>NwBK?tj6HbV6CWA2Glk%L zxL!WBnNYWp0Dw%W(cu@m0CoQ}mSeg&aL5`N)&6gA2ZPiyqE2PR=XZ-Y7@no44XhD; zx1zjp*UR)W785qMn00bLau|2E=uO}M4E zlSel+V>mlY{Rwo+eBQmD2B}^046)08Jiy@$E17Oobw% zfV6xs+HJteu6ZRr0Y3Dw0cDi`PaJpy;EVb(D~-27yrv(BuxECwgGKg3w+6FiVR{mD zk+3U~u6rKZ+|Me6@eil~ok4IGekMNZ@``x$&xlRe0RZ$kKT93^d_BK>RPPm?j5}u! zGDj`bd^f5uqZD=YW#DjB+UnLfxu<}z^?}K|B38KdPdO0k)qS?;Sc|rDa7>#QDn?>Z&Vo|^uP-c z5LK@6bW-G|5NY{mw0Hk60F>KgAE^IQx|L#h%!OIeSj@mg5t63cUVU{zq&5;lWfGR{ zR%@;SP*~|88mxmkJBi^7oBSNJ@@IP?4uG!54<80$|8gnP+nZRJe_X|VS3I-LwdcQz zmPEf(4U1gqvcG*>45W>S6$VO$HG}Sih2hY^buhQ_Gp3~``* zjQ8YsyKh_`7;NvJxv;h~-b+{<_N`7SNqni4(7j?$TreUm7Wy&Pzxr4TQSpGcl~J4| z46fGg%d9;;lntUjwhS|&=jP_?C9?8%g;Es?!yNqz!&_X%JO!xKcZkQ>vVTuTXY!lhB1KV zi`w0~C9S5fpUzaK6c!f7QB1V`fMB=8eB^2l00%p7stKy8vpXY$azp00hx+AgPAz_U zkn{A*)7*guSN&TOl7n{ro6F6YLb?5e>+I-Ac`>?LrF`^$ypM8^YmaTek=m^@kdZ*s zgLJOHs-`UnR`-)Ad7A=ly0A^~Gu(Sm4p7@$PLoxXAyvFD_qQ6c?jJ>axsesf0Bdr% zvaB=6Dl})S6!#@372s^iD3{~cYQ*MYl>~C9(UR}@)n!kZo@02}9b`@@yAke*?PywE zU^*TW1|YKRmOjo)Qt9y1hUi60t?zB%=&5W)15`5`+NRA1sF&3{gb&gHMbk$qWVlbA z-IFZ~r{XJXK9R9R`IAG4N-fa6GODd7)sFY=DJcq|v@K$!>8$?ij}9Xao^Xxc6Pwsn zX#B8%y+n#X-T1-xlHsH9m*%lL&&4j2*vN8msXl+a)@wAeQsUR+;IW(i{LZ+XmKL12 z$msCRg6$s47vE$<{myKt=P@aNh%V=o``dgy7}vKoYEcXIgFf0f7dQR2Z}&4mTAh{A zv+hSZ90M@6plsR)3r(#?H1P+~F*+rC7VHupfy`0FV%FCMQ#Aid?{Dsn?i;P{Vm|#c z>T{Y%y>?>stsT(qSQ+Z%njx`R<7*#*u<>fH`3sk#=ej+Ua|tcyE1o})9PN2%tRV$T@G~m1sVNlHTp?@3D^`xVqOz%rs4lo}CFr8DWJ!DfZIn>7m(U#H)Ndf4>%`L<;H$+^lT@fY-*IHS_+ zG72|bXn0i0;B@8u)(B&_`f)+4r{xPHu4rXrN>`K*Zc&wJyz3w~rq@zPW1vifZ6%Vj zX-pI5Ge6e9_2X2})=QjMx?b(iwauIRs-s5Sh;mvA!>lD^Lf@T zVYOxy^A19E*6uDS`MrYU1fo)(+L*emFXULO&p1^scS$}hxTf`?Z!$0TX=>^j*9erd z$@S>seuEM#oR*c-L2t{EnWdcAw)xz1czj>55o0UP1r^{VQ&Fj;O_yl@$r(#F;tnQj zchrns|34XHW-nfzJMc3b`l;aFy!7hR;HEooJqH=#HmT*!6uM^0my*GCS;9Q)IaHlI zxyg*ZtZvxwzB0kx7i&3uUhh?L#O0z4`Eo}B#b^sXlsow>#f!sEmIcv{%W9c)3r_2Y zu50S$Yl|{~#i~?I@)H%Jt4n=6d(L+e*C(pdwZbe)Op~XN&uo29nQ7q}H_*ro^`QbX z6%DuZw!Uix@1dz5L|h6DBU#9yYu%QHQ0ftA#CDBM9s<9pWwcfObwQ+6W2RBF$ZoaVDSQHrYKSx&d5RH~|`z-lVwLa&YXVz0Ua z-v`-}Ux~XFggkE-!10o6OCc-txm2jsa7C$Z*?x{Z!s1Mo@0PGSAH-(pqhv`F!eW z3#yA$BDl$oE30$x5%-}tEWRv34?FWrS~ji$MQ{T!%R_a2;E~N)nu0VAD1tZlnsXFK zu2u1>NfGZ2nPMu^h!)dc8|qXx&Y}1rayWeRb$x0KYpL=8JV35{$tmLvi-_wg&+#vw zm!w4f<5qU*pKwvsk-N+y_r}x7OXCE zeS{G@Ijo?G|2}NS8k`?`d+KKa%=q`moKHSaL(J`RZS)m@1LlZVUUZ}d^27q z=vGiC@7QNcXQc-veD7VBkj^+dKqg1XF+ZbM0W2pZ9ol!K7OXEp7soX}m8Uo7Pyp&@ z>Egy2G9$Y`=rVmFR|)ofSz;$wg)lj~lLi3zIp*DCcf)q7A4p*MjgfXQ{rSW*|IEW1 zgLJIQJ;Tbd>xVT0yZ<)hWL8Z%toM-_Xdjj`Qqkq)*j(aE{EU%$VqsxYjoC^ND;2X1KMveAfUp4>+HdIx8wrpyUyWZR~|4r)( zHPwS6GWv}?4^W~e(eHi`5-bjkYRH&{Dm`KTQ{u39rVCS-og8{hy9M(T0ak-uxooZX z#oOg$Jv-?LId!g=U5HW5D|IJ7=kRM$m3?m>*#7lEx)k zMP5&OPL)*`CTKSj;ncxj`V@6AoRoedbR;nKA6+BTpuXbfn%Ql?Ho{Cx{gBJnaC+wq z_g6N_&aUEkU5jK7Ghd^p0@kLRHVgGXoK~gz?&QfI;e0Sc9sBm13W1g0GRW^=zc8@U zan~1Xx1j6e-o88MJxeMQijNKd*D|yIk(lhaZK#vK65xBwN zl}-3u?&MS-96GdlXYUvPTjwRnNFihMq}pzv5-xzSf5;dyP^-J}eYtpMW9|xJd}aqUBU6kQ!9+8`!Ljz0lBQPbXcMe27wOAtdl3Uhacq3 
zTmD?#_CkKJEMa`SuV0<{{xKX_G_9fQ`K`7HHg1uwbp!3m*fJYr_1jD@vT^)@UrMJV z+=d#Y-*6H>qKr2SKD)iPmId`$h&6unPo&MSu4MCd!@I720U%t#n(5LW5`y6t!68PS za!F13a_qQ1=~}m$$vA|0Nq8BquRqiZxlrYmULj|2G;iU<_nSh{RFJ3odB-pW(IxsiH%ByKgh!gxVCr$|9?ZbC5#;p9z_(ycVxh}D?ip`fB$$Nj&wS=yp5nxJ259(oFJPFKwYOr8as_LJPpl>_xzkfMi?&yPm>Am zEERP%+H;)Ipk%6RAK1x%NM?P$@(SAfMbDiOAZ&V=3s>LzTI_`_Hqy2nRbT(Lf-pJ7 z-5CsXVJzo85KF1mT`3G+$Mpxfx~eA?Rt&+%S}=F-DSu{cOso_=Nl;@qta=KF zcXoPqcsQOOqOcBfhi~K$-zn>(nNnTcudb257cAiNYP~A*ONV@y!ZH$s9Z+=6(c;o1 z0&_@g#T=1Sevb3-bL-V0Ebx81F}mbUS?D0s2?cNsXxE)C0HZj4R^jg9lg2t02$;Cz z){=oR5s=^JXZX%R&qqfCaFtT8z)SrxCYe8AcC#v^#?o2%#nmT zQRa$L5`XKsys68M85ERozp4tVbJrUy8{aw}Gk*vkda@zvR7l=JTb*5aV1ExPvkbzG zw&aHS^Y-#rq>y+P1{N`O`Swq>%qb?-r_Y1zj;pownmo(OdS8!Ez*BOaSUM&`>#>9T zaRENQ-sJj{{Ls-t`8Q!vXA44yv`P~m$pQ6)Mx7`VY`sV*t9nUrt(H@=Rs`+8q2W}keoe7hIVa(k*O2a(mCd+F` zj960%O_DmTPN=?rz^a@?nh0v|t*m}IRbGuB^pkP4plFoDQCL+u(?{sP+SFoJcq8{@ zS$HYMqTUa>kUDOT&#kc0%0yrCcN`rc+EjyPI9~Pdj_=XQOlEot*Y@=tBbytBu}G1( z3h`dT2K;cV4?>%;d?FmS?X3Shfl|``-WwQFI)dQO3gx{8)M|+XvZbl1gpbC+6u*VI zPLmA^cyw$m2@!7rY68$86Lb(8%a+hjUug?N$d$3T6z;DwB_iam{nZg_Pq$1J3U6*a zQF0zro8Iy<bT@ z)D$#m-$oDe(ltX(#dEvAWHh$wc|O_F>vL?s=up|Zv($xwV+Q-`A3~3P*M?J$)2@Fy zVS|x>(B`(9kUBn%2!M2)0yv=Vmf7LRp0oYjTKxV7kS-%m4DIR_{e@` zLWuvy@7-|jc|7nPA{DYhqZ`^rfK5@m;`9?9s(85@$esmv->h3sKyoo;LM^DlfmPI^6^fmdypyDz?@^ny_qyFt0&XMz-yIC&)wYB^F zz6mA4at#ujL0FrdOCiYRCljr3*2bT$-jWklkHuPYgm3xNZO5}u071n8KCflYQE!|7 z*m~ds;Ir$TUivC+YxiJ0qQ)+E1y#PVsNG)lA%6oOa&b}r2aD>h#NiOgjXU+BdA0wD zE;iSrvK6M{p&g+k87t$R#A6h?#&ej8Rft^_I&|M;z|4rbMvB^9kRpxJr$HMw4`~Ig zCw*{9E^u(I1vJrgCcTOP=;JSw(4bTU466BZgN%5^w>iZouD2X0$92avGXFQ#I}`9_dJw`6UZ{G$te zX1jOwc52>homBOWhPQVDob;c7KIP-EC;LcDG8jL?FvQ1fw z+#Wa-7J)>i;tV={y7A zI1vMLF+>nr31IUaG4rl)OX3Z~=m0iTrwC*lOyYE&-&7vEKP?Zg<2|B@+44yN99)8X zwvV*L`i^q#QQrVy0>OY|ZEHeHOG~3PhtB-M8!KSI`$>yQ^TmW^!zvLc znQW~e!eyj@A5qhu4=?p*z)GCd)LBvo1&6xQ^61Mj(nbKLUs6rrcXYLo>N}SXeo)s^S}a#KPxBjK zeiZKjMplVgzaEWm`J^_Y;-@!P0Phtkc5NdDLgb8)7v6*Oy>51_K%nR_JF@7{v%6L<^fmR?sPkd1uX0Y3@qaiRiX<}L(j zv;I{lB)Ie9jSD77;Xwvy<=*<4(hzWDi~~)k+sU{PNg1NSnL*&LA547p>Xo%@yi$}& zj{uM-qsVrqSjGOk8EAp^zmin0Eb&hj+tZkaG23mWjoKq&YFfrGvVN3$M4GZ%S`Oio z$i>C2&e*%@xaeZ>Opbuy&ofIqdqJW13|zR=8M{K(Pys0p8!nl}DC8uCi1z#2DCKO` tHgzr1J7CxURl|1V>;I<%@V=M99s3>u~fyncd0wzXFJ~>P7$n literal 0 HcmV?d00001 From 4677576b601b8b11facb05e3cb7c1b86cb22b0ee Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 15 May 2024 14:21:56 -0700 Subject: [PATCH 120/244] images Signed-off-by: Pawel Gadzinski --- .../tutorial_generation_gemma_with_te.ipynb | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index b023d6ed04..fbabc65d4d 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -483,14 +483,29 @@ "id": "8d3945e3", "metadata": {}, "source": [ - "## [Improvement 4] Reducing memory usage with the fp_model_init()" + "## [Improvement 4] Reducing memory usage with the fp8_model_init()" ] }, { "cell_type": "markdown", "id": "2dd0cba9", "metadata": {}, - "source": [] + "source": [ + "
\n", + "\"\"
\n", + "Fig. High precision vs FP8 vs FP8 with fp8_model_init() in TransformerEngine\n", + "
\n", + "\n", + "As we have seen above, generation in FP8 precision results results in considerable speedup. Neverthless, memory usage is no different than without FP8. The reason of that is that TransformerEngine stores parameters in higher precision and only casts them to FP8. It is also true with the optimizer state. It is needed to maintain accucacy during training. However, we can get rid of high precision weights when doing inference. \n", + "\n", + "Transformer Engine supports maintaining only FP8 copy of weights with `fp8_model_init` decorator. Let's see an example\n", + "```\n", + "with te.fp8_model_init(enabled=True):\n", + " linear = te.Linear((1024, 1024)) # this module is initialized only with fp8 weights\n", + "```\n", + "\n", + "Now we can try to use `fp8_model_init` in out code and look at the memory usage." + ] }, { "cell_type": "code", @@ -505,8 +520,6 @@ "\n", "from utils import *\n", "\n", - "from utils import *\n", - "\n", "hyperparams.model_name = \"../../../../gemma-weights\"\n", "hyperparams.fuse_qkv_params = True\n", "hyperparams.qkv_format = \"thd\"\n", @@ -532,6 +545,14 @@ "benchmark_generation(model, 64, 256, 128)" ] }, + { + "cell_type": "markdown", + "id": "3e30ca5a", + "metadata": {}, + "source": [ + "Total memory usage dropped by the **a%**! We can use it to increase batch size to obtain even larger speedup." + ] + }, { "cell_type": "markdown", "id": "c6e87275", From f1e727ab407c09dff8d0c4176f42a9994cdb8328 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 15 May 2024 15:12:12 -0700 Subject: [PATCH 121/244] images Signed-off-by: Pawel Gadzinski --- .../tutorial_generation_gemma_with_te.ipynb | 192 +++++++++++++----- 1 file changed, 145 insertions(+), 47 deletions(-) diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index fbabc65d4d..7973688450 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -108,10 +108,43 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "7477e469", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n", + "Gemma's activation function should be approximate GeLU and not exact GeLU.\n", + "Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu_pytorch_tanh`, edit the `model.config` to set `hidden_activation=gelu_pytorch_tanh` instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.\n", + "Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.59it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Another string ... 
\n", + "\n", + "I have a new 2019 15\" MBP with 2.6 GHz i7, 16GB RAM, 512GB SSD.\n", + "\n", + "I have a 2019 27\" iMac with 3.6 GHz i5, 16GB RAM, 1TB SSD.\n", + "\n", + "I have a 2019 13\" MBP with 1.4 GHz i5, 8GB RAM\n", + "====================================================================================================\n", + "I love the idea of a “\n", + "====================================================================================================\n", + "Benchmark with context_length=128 and max_new_tokens=1024 took 8616.48 ms.\n", + "Peak GPU memoty usage: 30.96 GB\n", + "Benchmark with context_length=256 and max_new_tokens=128 took 8430.52 ms.\n", + "Peak GPU memoty usage: 31.83 GB\n" + ] + } + ], "source": [ "# Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", @@ -126,10 +159,12 @@ "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", + "model = init_baseline_model(hyperparams).cuda()\n", "\n", "print_sample_of_generated_texts(model)\n", - "benchmark_generation(model)" + "\n", + "benchmark_generation(model, 64, 128, 1024)\n", + "benchmark_generation(model, 64, 256, 128)" ] }, { @@ -190,10 +225,79 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "4fc5e1cd", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in TEGemmaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GemmaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. 
Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", + "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use 
_jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n"
+     ]
+    },
+    {
+     "ename": "AssertionError",
+     "evalue": "Data types for parameters must match when outside of autocasted region. Found input dtype: torch.float32 and 'layer_norm_weight' dtype: torch.bfloat16",
+     "output_type": "error",
+     "traceback": [
+      "AssertionError raised in TransformerEngineBaseModule.set_activation_dtype (transformer_engine/pytorch/module/base.py:443), reached via TEGemmaForCausalLM.generate -> _generate_context_phase -> StaticGemmaModel.forward -> TEGemmaDecoderLayer.forward -> TransformerLayer.forward -> MultiheadAttention.forward -> LayerNormLinear.forward -> TransformerEngineBaseModule.prepare_forward (full ANSI-escaped traceback omitted)"
+     ]
+    }
+   ],
    "source": [
     "# Import necessary packages and methods\n",
     "from utils import *\n",
@@ -204,12 +308,14 @@
     "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n",
     "hyperparams.mixed_precision = \"bf16\"\n",
     "hyperparams.fuse_qkv_params = False\n",
+    "hyperparams.qkv_format = \"thd\"\n",
     "\n",
     "# Init the model and accelerator wrapper\n",
     "model = init_te_gemma_model(hyperparams).to(torch.bfloat16).cuda()\n",
     "\n",
     "print_sample_of_generated_texts(model)\n",
-    "benchmark_generation(model)"
+    "benchmark_generation(model, 64, 128, 1024)\n",
+    "benchmark_generation(model, 64, 256, 128)"
    ]
   },
   {
@@ -290,15 +396,20 @@
     "\n",
     "hyperparams.generation_cuda_graphs = True\n",
     "hyperparams.cuda_graphs_static_batch_size = 64\n",
-    "hyperparams.cuda_graphs_static_max_context_len=6\n",
-    "hyperparams.cuda_graphs_static_max_context_len=100\n",
+    "hyperparams.cuda_graphs_static_max_seq_len=1024\n",
+    "hyperparams.cuda_graphs_static_max_context_len=128\n",
     "model = init_te_gemma_model(hyperparams).cuda()\n",
     "\n",
-    "# Load weights of the model with the proper scaling factors.\n",
-    "model.load_state_dict(torch.load('model_fp8_state_dict.pth'))\n",
-    "\n",
     "print_sample_of_generated_texts(model)\n",
-    "benchmark_generation(model)"
+    "benchmark_generation(model, 64, 128, 1024)\n",
+    "\n",
+    "hyperparams.generation_cuda_graphs = True\n",
+    "hyperparams.cuda_graphs_static_batch_size = 64\n",
+    "hyperparams.cuda_graphs_static_max_seq_len=128\n",
+    "hyperparams.cuda_graphs_static_max_context_len=256\n",
+    "model = init_te_gemma_model(hyperparams).cuda()\n",
+    "\n",
+    "benchmark_generation(model, 64, 256, 128)"
    ]
   },
   {
@@ -374,42 +485,19 @@
     "\n",
     "hyperparams.model_name = \"../../../../gemma-weights\"\n",
     "hyperparams.fuse_qkv_params = True\n",
-    "model = init_te_gemma_model(hyperparams, fp8_model_init=False).cuda()\n",
-    "model = model.to(torch.bfloat16)\n",
-    "accelerator = Accelerator(\n",
-    "    log_with=\"wandb\",\n",
-    "    gradient_accumulation_steps=hyperparams.gradient_accumulation_steps,\n",
-    "    mixed_precision=hyperparams.mixed_precision\n",
-    "    )\n",
-    "train_dataloader = get_dataloaders(accelerator, hyperparams)\n",
+    "hyperparams.qkv_format = \"thd\"\n",
     "\n",
-    "tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)\n",
+
"model = init_te_gemma_model(hyperparams).cuda()\n", + "model = model.to(torch.bfloat16)\n", "\n", - "print(\"Calibration started\")\n", + "# Calibration\n", "with te.fp8_autocast(enabled=False, calibrating=True):\n", " model.train()\n", - " train_dataloader = enumerate(train_dataloader)\n", - "\n", - " for i in range(100):\n", - " step, batch = next(train_dataloader)\n", - " batch[\"input_ids\"] = batch[\"input_ids\"].cuda()\n", - " outputs = model.generate(\n", - " **batch,\n", - " max_new_tokens=10\n", - " )\n", - " generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", - "print(\"calibration_finished\")\n", + " run_forward_pass(model, num_iters=100)\n", "\n", - "print(\"scale_fwd computation started\")\n", + "# Compute scale_fwd with enabled fp8 autocast\n", "with te.fp8_autocast(enabled=True):\n", - " for i in range(10):\n", - " step, batch = next(train_dataloader)\n", - " batch[\"input_ids\"] = batch[\"input_ids\"].cuda()\n", - " outputs = model.generate(\n", - " **batch,\n", - " max_new_tokens=1\n", - " )\n", - "print(\"scale_fwd_computation ended\")\n", + " run_forward_pass(model, 10)\n", "\n", "print(\"Casting weights...\")\n", "model_fp8 = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda()\n", @@ -443,9 +531,6 @@ "#restart_jupyter_notebook()\n", "\n", "from utils import *\n", - "\n", - "from utils import *\n", - "\n", "hyperparams.model_name = \"../../../../gemma-weights\"\n", "hyperparams.fuse_qkv_params = True\n", "hyperparams.qkv_format = \"thd\"\n", @@ -458,10 +543,23 @@ "hyperparams.fp = True\n", "hyperparams.fp8_model_weights_filename = \"model_fp8_state_dict.pth\"\n", "hyperparams.fp8_model_init = False\n", + "\n", + "hyperparams.generation_cuda_graphs = True\n", + "hyperparams.cuda_graphs_static_batch_size = 64\n", + "hyperparams.cuda_graphs_static_max_seq_len=1024\n", + "hyperparams.cuda_graphs_static_max_context_len=128\n", "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", "print_sample_of_generated_texts(model)\n", - "benchmark_generation(model)" + "benchmark_generation(model, 64, 128, 1024)\n", + "\n", + "hyperparams.generation_cuda_graphs = True\n", + "hyperparams.cuda_graphs_static_batch_size = 64\n", + "hyperparams.cuda_graphs_static_max_seq_len=128\n", + "hyperparams.cuda_graphs_static_max_context_len=256\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "\n", + "benchmark_generation(model, 64, 256, 128)" ] }, { From 20538a54ae06453e4a75059ccab4c771f3c3ab87 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 16 May 2024 11:41:49 -0700 Subject: [PATCH 122/244] Added nice images Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/media/calibration.png | Bin 0 -> 109690 bytes docs/examples/te_gemma/media/graphs.png | Bin 0 -> 28406 bytes .../tutorial_generation_gemma_with_te.ipynb | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 docs/examples/te_gemma/media/calibration.png create mode 100644 docs/examples/te_gemma/media/graphs.png diff --git a/docs/examples/te_gemma/media/calibration.png b/docs/examples/te_gemma/media/calibration.png new file mode 100644 index 0000000000000000000000000000000000000000..b0da2fd1d348317e0575d0832655b177d64fcd69 GIT binary patch literal 109690 zcmd43Wmr~gv@VQ=0wNNEfS`aNQc9PA(n^PP2uOE#DBUO$lF}j|-7g?5f*{h}AoT*0 zlHZuFwf8=IpYOWP`FZ@Y){^z|&N-iFJY(GBzV9)E5d59gK}^j-$=b-lSz$IKHS`9aj=L9osU@+K zso;jU!;ARj%v#4{QlQ}pc%N5J2P^)4oMqDbvs3q%Mn)?q$+flof=6qm*92+)KGRRq)7?VvwA@@QvT#fc zm$QI?fLYzMLsjDx%D;;<^Ffpf`46V_;$r;r@^UwCV{7ZFMq`n`YsV*LU_i67vch-7 
zLS~bZl8PlB^XggvMLDfT#Yr6zycWV@t!lfb8T%V|`Sb40on-+cHtag_PE`6>Fcb;N z)xX-68jP7^5ObJk0E--|M>yA$3?Dxgp^!$@u0W&!$ow84fC0U7&X2qdm>&2lV1dT^ z(OQe+B~32W8w_lObPSj;15YsT;ljjCyoU0ozFJ#mskoZw{jnvz+ZA=jCv$(7e%B-< z#g-`8Y}Qnz;)Gcu5i&8Gb6FO1Eh<|Z%0H+Z(a?9>?K(?en`bUuiIWxs#i{;ev757t2u zn-iRmCI0RwAfKV??DNq}FzfV2LDHcbxm8?m-tG3ck6R0{t=tRQ{JZTO%Vzyfxq~Qw zoKUJeT9B{Aeb2*vZ(0TdP8BFSyRuddn~2rLYfg&TaL#eZ__Nd0HTV5VKhR?gg|eX? z*`S?`4LNX2`+=lVNHX6lDYeEjN(NAqTKBtAe*z-cHE zO2DL<0aV#ff$2{@DUfI52#0^nG7lTg6n(}AKemWvkN0laU2pbdvQ3>R)x^@ikoQXo@`U~zINs#;Tkwc%|YD^R4 zQ4lx>LOduGu|=?0HCWmBRO^2OFTBf5 zWKJ;TS+(7Flt2y&1Fva$gxbsm5PX{v!$g&;8d zW1zmmIec8Y(WOPzqB4bRx1#v5ErN*&wb?n%hytTE_HqP>7gCPgxMszCVD6VepjQFF zd4%3EzxL3d`7u19Ky5`B_5f;?(6@tzT^YFNDNqPI`Ztt07Gkq&H` zOy)aO?>9}YHT3`qWM74?QK|Wmd?*4B073un2}O7b_;R}CFSIOMVl3C6r=p-01k5_wdhYzPJ9PLIDBp|6j{y!Ub(l`Td~ci*{7#hH zZY9OVahV1setLwZ^oF}1j~FNAXTlBT%-GIwyXiE&wbmYOr^(ATpduT3E@B226oVS^ zayqEfbBHgZ{ota(%oPJihJrPFUGE7A+#|YLRM(eNuQj@)l0iE-EtjE2thL}}^d7!7 zpw1*pzOoSaI)Mi4$VIFf6{eexVXSL4r=i?su5#%=-)(b2qpAkvJMS zOR+5^j1Wjj(2Tbo5K~ST4FGp34Y)l`5TxAWq!Q8OOV+=D=H&;B>i5B`_gt zi10hm!KPL4gG>w;w~ZCEdDNRX!f$=gD}2Yke*OCVMUar$Es*##?awlqZVmwc&WNRf zTr)e0?3QtE$Z_oK=^VOO*Cp^r4DyZ zJ22@YdjVu&4G?E~J=h=kwt>$lGyqbcOfzMQ{Eu1!)PARKE zMUMbU|8_L=H;HtT2;|C7Hhm0y;Rq!aIUwQzlzn30=MRE9?#IAOa7u)LNj0(ssEjjJQV z6^jS}!Oi@_j$uX@$d7Bl7JLl6K!N~W=b1zFf{dftalkJhk+z$kml6rf8N1UdW_*D( zk}(J*3|H8u!~g2STqA&s1k#`CLK9>{lb%=4p3}Z zxJrZ&#rFYMG&UKpDmun>Sr}Wu`Et~Q;tG;s1Bs%y`e3UJzdd`p3mX$+l!l(pG>AmKD!|dohs3uTBtyWFB@xlXIIk`$5MU_R zB?3TG7GxXjrkk;$mLyQX3eW{a*J^X7b?et4n(KKXmMbi7&Fb0(;A(ex6Q;_HY^aXa z)2N63*(-yfU-*g3Rupm`1AzDM{Nw?D9&kzCr=Uey0>fdSUpw=<@cY9RMm}*lV{K5M z{RPkvvYQlBp5Kn3RcMCt46Rqz)*;}O25ioCW<%8mB7p~=2I4EM+@bvlcoa?QD5s_j ziw|;4M7cV3+HhkbHNr@M`4j_APY<3i94crQE|3xrLGNiT5F@$$we-D1JD1l}=7 zo+&^?T4p&aE*Fi|40{d9;i*Rga;5D&i;pZEFp0~7{NQ=57u=Jw`f;+DOOsL(tKFa7KXAz=bJ0$o;qVUf0=LVQNdR zTG`lj&q}n_WSSDoJuCkwps)1}m4mo(^((7uP!wm(1%aCaEC9R5u@j?e!BYsLB9Zcj zLKqb{j7riUSxNL>w4rIiVvHEMXuc)(ECahszu!&!?I+%4;Y3CTeDA>J(h@{M@?D&Z z7_rfwG?)Q$kcsKfQBx=^EL1Phecf%1Ow?&jhw;;=UTu+#X1c`Wt_6@%CZ|T_(^Nyj zBG6FU!#>ce^BB!A-q(PcfMX9@R#tX+*@#4t*U;#hjUCjOzX_Y+g>=^-oc|Jszv>YT zYN3?RAyX>^$!Zmd`$~;Tez%KPHo}W5$&(gxvqqKs1QOy za7UFo37TUOki1WRA-rdZ|C!%i9XRD77d#6$yAom%S$HBfSft^_5jhqVG8~}9$Cd_r z*8>U0$0($v6|LYt2!TN3ZhvbSHF8K1B;p^T!qC-GcF(Ojsj%_yh#5i73O5xW*t@OaA;piBYJ0Pt~|4-E(Qr5ulUI_(NoAuUv0&C#WzIdwOtZvDVEQYS8Y0Ves<9Q7lQhC@g-5^x>b8 zT6+biJj=|M8=+)n{La7hA64$s1UZMBX$&04e>0xBgxp>QMo7++x$Zc0@vA>ReH&n% z7vBfR06R9S2g3jB(l6Z4oi0aDQ9ayv^F#PlTowMKZp1IYkEgr%-+CqekLDdqFg!l?4nmYkoT z%cpQ2$G$8EX{$rb;}X+*-Lwm+k0v}wnUS(oolnGf=I(%m{L@~dSm(V_brd~Ykhh` zklSuu2m>mhI-fIi(7L>CgwGAt$`CXWNMGV51~d7z?3IIqVUQ4ggSPL}U0RLAkD`)y zvI=l+?oA7L#iuKjpSqmXB@mOM7reXD3GNO|{)dnL4dA~+?z+Fhs7`uv97?mTQ|&Q0b|;_Z0d|v!S5AjX6Jon1AnV zWz>Eq;~wfa2_Am%{fkL7On8^}$n5*;3Ok3k+V}+_f>QyCrlO?-Lfe@E z0qU~gEuW-@&Q5G#dV$*0;P%Ihe8N}p-JSs1h>Z@KY#;zwkM|%ScUa zn(OyuSzo@dS>Y5RmFiex%9c9BesSp_!V6!Ww( zb(f;~J7PHED-`^S?R(`TDmJVO-TE_`PI(4w52j4g0v09mZ;Vk;Q+h1*-txJ7@GDEO zFiG~-B|#eTu{=7rIEudhetZKlUm=kk8oWz_w>EvB;}P(}j!uScbqsKHCrh7Wce8csVwmF`-h7m;eyzc%u`m2qRakt*on`rXy6%CG4x1zZ2 ze!U4&A>e8o&-Mml({uJNe(2hokc8#2S~ST_p`gwz1ueFhJdQv6^rBh5u;VJtadFkotlg}~9LjE10B-KNpWKLc2{^AU7 z$oU|v9Hb~PFHfR>_TfzxABw&AcrSn2C`6oc`>@upsI#{U!=B{e6=aGc3T-OlPa6+u z_VRvpcPD10Ur7%<-r4wu<-LxQB0Awy5cN!0cJ7vAzIvllyld^Po>~zentbNR)5i2$ zzjlL%ow~RK!cjtAv8BuCx5y6nA3h}XIi|!p-$T}VhK8a5ozJ-D$qk$$=6PWKe0E&Zeix&P%;v+ z*}qJ4cP(F={ywRo_7mG3NBmkg2%BWBk|^gYE0sRG?Gkv0!0a zL?>$?Cwbs=XlTPs$haXFGsE_6Gy~GM>=GW84DZ_7#9=2M*M4j6N`Icj-X3{dk|J_9 
z9Pp1P>wtfT{K>)%!#Tt7LfTJ@(PDeTP+t5czW5I=lpzTaEk8Sx%2Ls-h~|wxdnh=hn`}R|GQdEe{irsn5Q5dbOPaO z!@?=cQlSM84oby6o{q)v6#+n_pWuF_l@N#kK{r{Z(*|S zWtEx?SAI};eIYb*1=lyzN<7TeY&m0Rktm~1=efbua+Rzd9AhDA3QDT(S~(+b_YJPz zIq5dLj$gF`A-58~NM4Vm)Ax;Y_0<#@Y|k(z8*_d3_PCL94y;0vJX{jukqjvC4Ih*G z2VX8sVlkVFdwafCV^o+;{JxhF3HWQ6m{XvNM(|!p@W++6Z!x|%H)odm`SA@tW%nk> zQ^a^HiuS#!R^CKLT71S{vV=+q$45q#0@L@nr&dB^*zxvJ838-EGnX!3HsQX00Vz~} zE%6dWt#3h;x5%>TlbTAInVD(+r%b9hapt}EzRLmgSnbbeNwQiYl6Yg#UDSL3sxMMJ zmnMA6=YcQ2w?0FD3n?ptEozByzTQKy*e8gF;A!;3T?xGF8cSV#YB+eUrOAk zTb3n@k&OGY9HlsX#$6V*Z71iUN^|_3%-5T0q?+LSN!3l`KbZkVjQ9Cpzi|z^MMe;q z9EL8X&N21#ud5I5)kLKu0=lizl9I0;vlvD)pM*8U;)d;ct@G6v6%s^4K$rbHIB@Kw zZ`VZyC64b3cxY0DVikDr!Z4s8s=K!rH&u*HZU)B}E~Jh)PMKc4tG*zW_~u!W?5)io zfA}y(k_sk3=CGx+GXOH?8#ArA;5KJ3246~iZZeukNl%bt*jHcpG3F7D@IzsOpW5Dm zfiBL1g17vc-Q9(kI4m1)pXDlM&ILHu7>-of2Vm3B>Ao;cTwK32DS&%)3p2LBY*~AN zc|1cBPgJCK#xvLD7TIq6QfapJ@h{W>gIbWj^A%}QzT8lnzMh^JTT23}Rp#h<_*!3G zkH|d2=2+~1K7(CVktUt_eDPPGM(_O@Y(~c5_VE)d2lon#Bj%u}p%G38GOnmA)Cr!Z zoW|*kxAh+Ytv-`s!OtmTZ^+&9x%i@Q#y2I6Vw%&L<;SdI=4$5**YbX%ibTnGS6}k0 z7tx3o7omQ|Xe_34sc&{mWnB_GhTqGI){p8D6>it9y6wH^&ZetJq0}K4W2H zU#^FQPGbn3qzIGTeq@b;8gW@p`XY;u(A4+Wb+ET@ZAR1ZsTT|l)3x2_ciDbk{Q9cU zg_7wXkvDgP;-AEa;QG-A#g&*pd-dHug7q)Ar)bsBgplZHpW0foFJEac>nc@JpuG{h z`Eq^XwP}!m#VSQw=cnz2pK#Y7&;KdFf9)7rv@*!aBELNTb>6$JFs6dCIO;DsA0Kvw zW}qa1Afgzo%nxbR+KACI6(bk7R3pj`CbqDI1gXS>Fe?hDgGN{OA1+5Q2F{Wet?;-X zUkB%eHp~ZY{O0l7op#QwfQa_l?Rx`aYL|t?#Qd^8lf`py4z^rj&cLZV@|LB+YZ~t> zKf?8YrZ?wretu>D$|-*@RG3ttTU1zhft=#m2*v%Y&VMIs#6FT7vYjpWM|5$QT_k6G z9Pq~V#`KK*4X};aVl-EU1Pkp}pLyr$Obw>~0ol*X6Wm9WpRb?JTFy>A(Z@w5;MNctfW=Pw znf&VI=Z|G{IC7$;*1DEu|D<1&3+du=(B&z0BkuibASUHUef2;ToEV$roEpYlo|29d ztNy7{itp4fxi~2)@kuQ#N_TTCM}SY`9$^aDXIWj4!wLXAP613fr2g1RrC8mR(BL6KUP$PI{Gak`al6_x6Q2dH-o*OWk#tGmtSB@sx9VNLs!-xy;$XpZ zJj?|D#E=Vx`O|&Fjg6np$ce$55ybuZnGeeg3i01Qr2UF)0IZEv@gKm%a@!|JCF$8U zME)7zzE37~k&%(ZzGyNsT{5?0g9D(sYIEWnIvKwk4Xj6Si=CaF@pvlZ->Q;+I~ura zUfqz!U*LI}l-795hnG5ZR7g;Etm!?;pf}>w4!-=|YouA{f)08~ur-VuK3%y_L1kFe zk`HIt@@F96m9I_)v`Rg2^&hU7w~FQ-2m1WLH9I19OMbA@W^5p|RIyH~}8AOSTf61xAoz-yp0TV4+*LdTmjR9)I)CL~pETnSSQoZP#`XQ6mk`d;m-`MJsY7+na%t{o+9!jVQn+b> zr-IJMj}Gx$q;+(cM({Nc=Lw>2o&`!$%?^zExv{;(;o~lQF_b^u9og!HPa>ARD)4Kz zD#6Ka5v^G?W-WjgPck*P{gG9vXPW!)FgIOG0Az2@ezmqutFSUxIITX(CXk*V2xvDS zvAvL~*!(%y?02;xS#mm4ntN3Kg{GyZe4c)&?q5?YIn_P=B+Ck250X!xqT#&v;kI8g z=8nIZ`-2J-iNIvqTUz`8v@zh34ku-N0z13t9tAS^ity;@cd$8NPhS#j1h1sQ(?Hkw z=_?->RNda+Cq*K&PAh#OA3t%+J2Qklzw_AS(nIPSpfPw^OG^ulBWFZ1^O16!_5X1d z=tv9Gc$FADq6>HY(X9$wzKti`eQ-iHQoVl!B`Nvh@K`DSMzy5yJsIgM+7MH`xmSC7 zZ}?FYrK5uWLl^s#kI>U+JA25uk@Ry0 z=)-*UWD+(UC-CW1><3@j3F#^KYlS?;bSJpr6_I)wEq{7Sskw=o;-YMkPZRY=X&M?$ zdKa|y2fWwOu-ZexnE+kLoJUNwtYiwd5+M{F`jB`)2WUhB?svBrBIm&kHL%G-qcr#{ zIc6sx`hDQl3E}dzDNPc85HU5U{W?@R^59Mc7y=N}oP|Qz=$M$vzkdUwq6mSjH1Ra2~Nm0zInHPwX*aebS~ks@aK?klkQ@y#VoUalkSg6Kg+%w zt=>YD4QIV<-2a`B(@0EIDy}PQNML$KHPX|2wcu6%y%N)$1zi7xD~@l_Xbf^yBMzO* z*M1jUkQP%NKdnm)V`b-Ij$R$k*3lVYrV?-}P4y-3F)s z)}{Mh`-g}ANECOG#qBs)GhlFVP=D2?&dp!+7CSW+)mu;+eKk>k0p?(1bF<;k9};@{ z>=kN4^xR=j!R%d~nWMXczkF^+5aOqb`O2i0*wZ>B{B8~uaIh8F=YM2D4oUQyFBi%bqMp5#YG6+OqI~_ zTuI1IZD~nsiT%Z*kBuRf{P5)_J0}2@2`HzXw^{ zrr}?jSD*ZRi*NfXPu4fN7;3`UwuI81!MQClnX6!5vHdPW+b7)Prr*9nUM;X>wZ_ zK7M@jGu?I1#h)iZ#9fO$W})dZy(HBjl+c}_hVF1u;RkesE*^ihhCJaKX8L>-2SosP z+qizEC_UWX*-`vV&&3s!d=#`bF(%x>*pWH-o3Ltcmx;@Mn;mA3f2Za`)8+1ZHbMy6 zgRZv~zO3G>^O#GNg7-$Sr$!bVA)K9+J zFiajIe&MV8^rM?=S9*)IIA`}4G6>`OoV^Hd>p|Q-sh!IakyTwaL|4p3OWPXHRU72D zXPHPDGNGA>X}`oNM(4g4Au!P=cs)QeL$j14U1nu^HP@4+m_~3^)PhE-nZwGn-rMVU zvMkLLnB9XkPb(?E`*`G&x8>}P4OP?JwT3zLw?c{9rUU-$7Fz$|FSV9hb}h8|+4Vft 
zUp|C2U;iLyJ}4z6rO4obVa1~D)z(jRszb)aMit?qI+Od4V%R;UpZ_3~WB$CQaK;p~ z``DsCCbP~$mh9f|(irqAj_iQ!zkhr-*dpl#jxLW^H@vGQ&6FRFW8ynXw$lARih2K^B0zBm_XN zRtV@T4Zr>F;}jZ-!^+A!eHigHvMa8^pON|%&Ahm{4GylDn@B#tK%qsiJO$1*ha0oB z^O_-Mo>y3Mti-mrw>L86{N}qpvv|(Yxa<%Kp6RW%iRNm|wS;ovzIv7Q^nGZmT6xFK zE9saF+Mc%_>MQK7jdeuYV@06qw$g)Gl60tZ#Q#jFjs+}4F-P;bU$Wr#X#Kj_j@CB^ zmSNIwjM15hNDzX9M?JmB=kQ!HBFN+M%8Yp-lcMw?P2W-uhw;6Sk&JKwokgQSR9E3I ztGgtQS6-N{ZMS-y3T==>yyh({%QZaoYU5X$T@T)a|;m z>(S+IxPz8slVb-pzyCyD_+r>E4u2=2=J}5cm&y6KOMgzY-QEl<%|%~!{M1zo*OHkMD?WV%yJ|jtVa< z?3!ky)&9)u>|IwU08Y?AZqu5k^!tDCzx2~jG;UR)Pa&2E5VeQX5oD5!;@piK^b9TE zYBCvp_v$v+-ns&6N8g${iL|sxr>3VW2;G9Pz8SHwKe{xrSE}*5d<)Gs@j z4_a`RI5sbnv$DRos@&Gw8>snOp(4u@1t$0H9_f1t-_4w%G9fi$v(X}K$$Mv5*s%l2 z232!yJPMvVRlcNRWwH%fm(YZdj=v|6l(E@NtzT48re|Ubga~oYiJ>0ws4X^%`$0|3 z&4vOC5NhWYCFns_%lhVKTgw6-1rF_NqrgYH`-(2r6&bc?Ke1grR~vh{1lVjR?=39Q zDdVgSRmN_xMfDb_xjrrId9^Y@_B)cXB{abA7pH8_b?x$wpr+!MX{~&&eqQ%$H!9m; z4=(i2Qy*e~+MeUG?8$z^NlzeY@Z-*oO=C8JRAZ^UQskM(#rMw&g(eHm2nbHN1C0{z z$7nyH!tq}0B~%-OaeXfgd-$@>gScGbRjmU+qY})v0tcb(ztNd z*~yQoGrPdHoetxaqAdyW)SBKrlfYuboxU5usE|dmQ&T)XzOuRVC%lARh3eXP*5({l z{c`#09KWn!G90RADnl+}jfO1tezU=LhmpT=DjO!hgi@E#y=;BkATV2>unc~VcPa*n zRSGJq9;mTsi(<}Se1(bqFOsV0i-})QUQSIxaYZn>UhlHDBBjC&*=!v5ahEz`UgJf5vFd$rHP z^_g4MAN}X%)ut73+;!YW11i`EUcFk=Q!g{iGhgdT4g-&OVpijWJI~EqyW2#wT37>g ziICw}&QT4?=7MB#XS(qvsj&9hAG0=~d;ISGd%d&MlNHd8uJF@yQ8%sSE6>a0*K$4i z!s}pTou1|>e6foRU6yKT<*pV2zkisQUY~NlBiwSDC5dy8a#>^t;{uX5UW0N4bmO+R|^SdR&R=vg8`(s$or+o_!gSDhhv z5{sd^l%bRziIIG}C^S4g1+c>VI$srKYs}XFj=k2CX;f^f;L5%`j#Vr3J~Xl;c0XJE zsL?eD&FR(N{kNiTQ__uS+JkEiEHp%8o*b7SN-fEE7SknT1nMt0|J~=RRLq)+ zoaN&Ex_9d-V=HxUq@J0P@Q*hwbboheR}=$x##S<;!&TQkx;i^LK=%4a_iqp$zm1D4 z3(qSz_1ronzdF0BEgJlm3+sc{o|c+5UwYEhjq1T~y5MAY#xWO384&OsGV(8kB3Z9BwBWenfdzKci zm>=JDo4^eUQf6ja!&Egi6lz@T)^KBG!|uQa4O~CX z8mCC9?769_xz};7yV)uNz76|vbW!a<=W^u>@EP49nK$n*_#*of7s)hv&-LYMM0LiE zTTC&u1rkW!B;$o5Z|(5XTmA%m=k_$Zd9JXEOu`yQVDh|UV_$PR1^pQP%i*PK{pA+4t*z^G-0vlo zoyO>W8e~@FAU}wciO0Ez8GbzIVRP?=xw#KyeJ-lhYV|fWG+f1Ub(|iBs&*tkt0yNE zNPplK_LOgpctxEjij|E(Hgjy=Ys&psl_7qhh1T0oQYP49(RDxs(+D-;s&yWG&|Li; z;yPw#=8f_Ro!TW%Og?vX05m%*4g4Q%LzXZtmxt}u%yk11EGhbrb(vtv0kA?}Bpf-)p&wB)fBrRdV>A!!g>%TBe8b}BU5q$NF6ufU*i(Hxw&M$-H*8{E} z7|8wHi+e89^~HYF!S3%+O#Y#JjL`<+OOXvQ2(d3sP5mKsp6^O%8672+Xpi@##pVBW zw1=O{2o5ogX;N~MQlP#Ly|^c_QSr_96~0^!fd&Wm8yyTSBX!Bo3&cc3XX&UQ!OQJn zCevd17bPJzy+0ZDk%T*>@C|i=-XQUg>R33HKbZb97goAQm~llM6)qlvG{!F<5ov2{ z_hl;hA(>^0x6#Hs+ zoWA=DO)7Mshub@RdIH~TOj9+xXqn*R2(POH8v2)zBxzH63Cq*>%8tu+vcCCl0NW#W za>ng7O+JTZDtnXB6sAJ!YtNq5e6LhQZSU9#P<05+o(x{5$&%0ng9Vl8w774fDXK@; zr=r|a>}g>0BFUZG{~ud#8C7N5g#qG=h=PEU($d}Cf^>I>bayvONeI&2CEeZK-QC^Y zoVl5AX04f7n{pSx#x0Rb!m-mkZ!| z7OO3|-=cAzMPZ19Xtw&g?Y$=f04a_4Of~tE-9@~>b}$rKKNxKZrXXd`Rf7&5v0HWk zevCeZ{{i;#>E-n`OqW^5ucLE)p?v{x%?Bq^VU!l&tUigzk&abo9>VS_-4+d8Q?dQ_ z#Lb)VV!OiwbP1@o_Wm~t_6q+=rzl|B?{ulDKU8g}4bD#E-R|(!vU}}igKCu~OYM-f z>aK_aG!Dj?@NM^#1%?RU@%+x7)ZQwYE8nlx16)A~SV=tG91jBi$UY#g_eAxd6s`}vT#py=2^x0$OJXrMU;tL4$^9oJ z36_&n2=ZT!EPap?ara{NAjGB*@*WD{i~8ruhSX`T!Hu=2`Js)r*Q9~BxFi4=u}pF* z)7@?yVBR;~$KkO_7HXw6>|feW=D_@G8VUT)5!ei=C%JfPU36l zcaF9ME#ze$v|23?FnMdVr&_*Hra?M7Q!ByOBj^-MLwM>Q{{mcx(RxOJwd5g%j6g=| zVRa-VSoKeK@u(-G{4G)+m#Q1SYA{7jNkQsw0oE>w(w2`CIkr#H~>?7+!; zXlRI@iRqDeyYoVW-?Q1H2!l@h6E5x_;H3kA384|Nlk?>oIYALmbhOW8By*1dRt^?x z;cdIC4$|K6ICL-O%q>}uB&L{Y(!|r{Eng@OJE8`cIqQ+SuOel}tTRtMWRUQ+Vb-s` zY5M)!pyw@LWK^VW@mA(d-A`j$dgBLf(h)3?DCIw6)6?5G!#@CyDkW7qbenEdy!Xr6 zm$OK%0kL)f?W|;1XfBGr2O$MiV0g=|4?3w&MQ_&?Y;0^~OYo?HFI#S>93vyE?!sKl z$258jFhFs7Dvw-bqn$K98^urF-P(`_*Z+Le0oxJIfI_XDIWh{$T%8?aYiYu0Btl@m 
zb|YRf=hctn3ALdI*RT3{blnRxKbse4U#pBPQaVL?gMQROczJHB=23tipZ%S zZY-{id+{6V!NP=f!tW?=g65yKzFZ8wzoSVNd+p2q|Az_@pCde`@7+^*5xX0$EI~Gw zDziHUi&rOO zY%pzhUe2Xft(PXdbI?I`K}Kn5Zva|SJf{;W=mNWjoBv}gSvCN~B#XKFL}U^)s@k}h znqRI=7#o@y+|sn$>~*{Jlf9!Gy9t|n4~Aq=Corq8G^WrL*5DX;mCYmd)cAOnp``l1 zkn4jX;?@s|_q}(5fG^k$B=a0DXJ22%^s#7kN(Irm-%;Rv0EgJ=WT_u8xosUCl}ZFz z82TH6eZOvX0hPMz(;Wk#YG#5eF%dCu3^}NS@MLOyvaibUZt7OKh6f&DckkyW0QqE# zb78AJxbF|=Ft4s0^0zOow!=Xea&CK$vt%D{oUuMPr#6{sX%EJM%(j$51K_7|04Xg% z417!{bJv06$UeOywR)9ky|bnqJ&>PC1^zXL+c@0TZ}$8%Kfi4^UH=vJKVM;l9{^PR z_+rHlrJPe7-)hGy+#`&itO9;|$jnR6)p_gKQX>MCd_puS4K~yL?D>H0bAXqOo*rQ# zdWmkvDss2xXaTh<4~CZ3bz^}0?XlQalEKCtLGW8Yuv<(INhL3^OIcrv>#4j9pUQ$h z?(zJ-s&aVk!_=;K838F{XM>FJeOf9gkr+AdoL1X^<$CxJfUYG+qws*3w1+4{)MUN{ zTgQu!d$C@@d(me#n_m)lIdEzU^;`La^~F*3z(snl5&DqM@n}`~^Ic1lUK+p$h(Px% z7zpt!you**a2wL9Pf9>}H_<%YcwuR1ag=a%h^4v2&G=_~@Koe4m{kPBOvdw>Gee=1 z6n)__{4Ay6II@f+^y%WSs+?+*ShhUK_O0{YeG>?MOAsLnt{#zDWgU9UBT03SnUDFB z!6hY#;C)$wZ|U`4WBP(caq~Bo+wEP3Ovjw#k&Gec%^*g}nb6`&k#5KZfJt^GpQqmr50R}B3#?4d2QKdAG)|Ig;f~l!Y+CGT71itjGdx#sPuA#sI%D#& zy}kV{ENrFKGAB0=46B zIlU9C8*lhkWRSwKV&)rcxe$`r)&L4_TfA2gNgf(5k!~My60Z?Y5Yu1Bd}|#$s?$B9 zuMy*H^#PxZnHd#|Ew2JAEZe2g5%M7@C@2x}`|#r1u;Q!-jOY$h781;3Liv+}{mw<0 zi|dQ>l~`P~Ab||>ikfPp#f(+g^CuW!Av_md%}?p6V(Sga5Fhb)=qKy{s558Fg#>^@ zg^Ys(vv5M$j7<#S-2vhF9pOZ; zK71h0%>``VL_{JDaNWKZfcD{Df_{N|W>!Cg^8G7xC}RwP3?(Dt1IqIZz1amJVQQe> z50g;CAbGQ(&$Hsws;!6IU{VYwYNGp$hA|kCZvw)eV>0x;mGBztXMZe*L2c;id@x>A;pwb!$iIO%g|WH*Tt7Q64+S3|GJaWIT@4Cj zl@^ir0#?cFw!fO)Zvk8XJ1~;~bchTxWPg8IGczo}`7}EQ6rTB!6jmovBXlDRt}8B* zBJsbevC^xld}o0FgGSYuAG&=0D6RZdjw)jGw1Jcgl(5wpLlgBblV@$ro!|4faF=T9K^o_v+c;rfc=P!^=#2*|8LaRnan z z*&HwYB=B!)YWwn<`|J2);3+z%TFpMM%)h6=zDz1lI^|-k z956fjfc(pzrxdX12WA6nYFL4<+HjPM-v$vm`EN8TWhi=!xiX}cv#(pBfXpiZ7&XR2 z=|zKzh9s01B%snbKSt7_I<){yn2*?O@vT?|ke_EIr^R3`P2Cj3e49fluU*Nzu#L%a)N85)uJg4*UxzXxsnI6JN=~zTk~YSCzVHsYL@x;_^xhe?k#rAdutX zbCACRsy#)_-6Lf@AH;RnT7Htx%vTMN-Tm6=H&81TA{OA>S!q;{@?l0t2LTTZ$7{z z&gABpSnA01$V|2u{seqUxSxP(2$Wtw!IAsAV8!z%l$ot3Q7@5q=8@3fW~df!Pk!O? 
zwB0TLx|R1o;5HJ=AUP4(+y2iNfNDTBF&@j9qjyQ*T&+ZR@?oI)J(IqX z)w^$s28tzm6M$$8if4zTnbfK&<}iHT{<;hlOfHxI0Ow9FymulwZWkg<8&^o&wi-*v zQb|8JeTji)SC0b?4Grdn0A+glSKvVIiTU@}{SnelX zDeCL6Q0R7~x`3r{tl_&v^^M=19}+2?!pc%{PH~nOMl?TZkifpXyD_RG%=&et1E1fS zTFCEtXX?q9I(u&vR#(~?6g!;@e)2?O^PQ|}Pu!*Y4?YgtniE2qvVA{+5(Nr83}A8& z8dZHo#fvGdXY#(r0XX4%0fMv5KFmZ!K*nrBDket_>e$%4=q_Y7Ulw%5anLHKD+o!F zh_weq7F}Ia={B9+DP(^Dw(44Hyt1rtl5yx+*hm0X>;cJ>?PMsZQWBdjL=8T7eQhh9 z5ds=omzN_BfWtz4`(ID{dL9v}X@=+TsKo$y?gB6x%-obzYG704q9?BY$&ozb3w~w> z_Tb=gS+F_><7c@)cn0{UAmsK4n?18xulBZmLlWV_u5QU2O+~q8e8(1Nw`WkHceSdw ziF{@@?fatRXxOHrWsjrp=Jrkr%V}c~FO6yI5sxEPCB6xgN*mSxPazKxNa9>vTukT- ziu$+hPFCyNqy9c-=j7;_s=*o8QhrJLNA^ralzPrcx#1gbY+w;wUF>NgEx;NfSIVUZg04$mR(>@b1sH$;juW=*a?WxuMn2f}%tRjprsSaI#c&Dgq@}(y}KGHpAt-4Y~OyAQ;ncZ}(`KL|U} zsO(mO z$hGd|luTjzkY8j{x@S6HK>}29y$cIiKA>qeWzTfxE>zCU&fy4@e*2FE4k_jRHxZ2_ zSb~ft&$}qfK!L0w=k3(j#` z)JH@sZKiNnCgBaaEW}MC>dxgXbVf_dR9>&PRKD4>$Ax>q9f>nc3NIwNi?!hXB@{Gd zZ9`8lt2Jz9p(osRNl%-?#ZQygW44KOG#~mif}y;0DykP88#sJbA?qk!3MI@QX^=34 zFG-2UoVDcGj0W#V)A+!6IECQ5V0H7;{|O*)f(sLi&AP8e3-+IKA>}pJz1?Atr82Fi zZ+`s)cKDg8s^={3U`g7!$A}%Fvc-s#hnwbq1bm^L`{CV?`+5}QsDI-{kFdQS%f&o{ z-ngUtZG%;ZEbNmE;Dr5nGU*$^ccmO_{@Vlddntae{Z;8@U34WG2Z_IkOzOGeE-Q5+ z2i7thc04=>(R>wZ@x8)V&c|kq?AA~M9wZchK%CQ6O+#AEXBxLgpjmW$TvG*nf(Gwv zBApw|pEWfbt9+!K!9heVgl9(#M>xr1Qer!5ChuP-KU>)_dwY93^RDmV^jKaAY2f3G z0ubR~Kpmaprk&hm=XX|SlPyn?Q!J4miplNv6;K8hlL~==4CtROpWCe8!+QoOQnnXF zqcK`eOC$jFetk0#TyJl<;SqY5mOi2YSg$oGpP~KptWpnz+4f6k55kU+ATiKxL<5P* zi2E#gVwFWUBWzpua6&OG+a)7qqRMzFxmv+nVF;z8T3%}00=G|1%EDNULK4nX7aSfK zzb{l#)RjvRBdarC>0XXJ3%0ubso!t(e)R2yHkoNl@(EGl@VH182O`u2`0U)A@q+Wl z>2weuf5hdM=xwzu1kSL1xggSOGb(EfQ7MVd(dEE=`SFAStJ52y$4B)~y(606NTbFZ z#I%x+fF<+er|H7F^_?dx9bOt5p zwAx1pgeOzZmOtvo3n~-hjyZebSw`}$=f+S5eN4YY-y5%nT~nv=`D1(@{lel(09Bdy zar>aVvS{(MPe@Xx)ARWA;uo$uFf~E+VEQ4Mn)=-qAZKpQd}kaN$E%HvTMljwd4n%~jtdG3>uMW|XJmC^udgfo zGYD=2?Xm}$4ev^42Qa$OF>LZx@Ta!^8|;bOSvmFzG-Tz@ibeD;8r}S)lex+E-PCDR zgxt~bd!j=$;_?;Ij?&U9pDselc_^DmV#LcR9&Jzg!Qiaore;2Da(ujfg<&_wG@J1{ z+fZ6nHBn#JH1*I`;Kg*hD-L1tiQ_@#YYLw^967HJJ^)?!_*ef%M6_q1yrF!A!{c>D z3=$N$3~+>cAO1U2sP={14Oy6IRUO;{jE*C4fb?iM`rx{ExR0FnR}>sFOuBow1A^L}7n%d}i?SI_L|!2?j7DR%9~0JkpG=9Ako^GyG-Y~85{qusmo1+=s?kpeKMM#$ zMyaj8^~MuvK%Xw1);a_j^UP=YYHKL|{{3LF_KC{u27aDYiu5~u+(;q-R0D6FV!+rJ z-;BjgVACgl#jm@rw3O;J*j{#CQj#1XZZ^67yq37Q2UFJVE0%|#+J$wp?ilOy8<=hm z7RsatR8in1f(&N#K-Q=#7C-uWp@(Fu^knK*J|wIpT*dy(_^A##L0MD`8>SgjYp^8% zXa{&Yg9S$L#|0BjDt{t)ds04ud2$zwzkoZcPCQ!lY={?$3XaO~&U8E5Bl?QsYVBB+ z5H$G2$mA{y3qz)WYarmMD0Te0Ks1c$rQJTt-_-(AuhcT9K|CWFtbmqOS~6(biF9cx z!^ltRm&Q_rY0xzoG<{S1ivPatw zdwu<(s|KgxPPSd1;>$E`KD{m0=lN7yYTWmQ4=kj6t}|kmrUO7ihLTcu)ElYN)#UV+ zCx*LJ{qf@~Xh8oFDPO}0wAfIe`n0roosjl#b2r`o{?~V;@&}*K<-fy<8O49mmACr* zpLn$a3b!MP+cOTm*e9y8T+lwV612X#yP;K@P(i~%@iX<<%=od++&Q^>xH(|pu6#0f zqQ?GIvX}1B5?OGnXUPspsv*?sMhb_2nrxhi0Vd|}fkpuxH?vmeF*l=Kdw%zv{lRNo@L zdao|fxX;zvzHw|WfG)F}lzw<&j_D*D9WN1rW~ELJaDNW_csv`$%fq!jid6We(TnAi zJ_-}1a(=?;XiP+^^yG}zo#lfJ&eA2UmfMZW&Aqjb499M8U7b?DgB3VCi@C0TxU7aP zwc4YQ#fhKpoXmZzWYzjB9_&=F?*quJMnm!Kz(D-``Z~uO@&ek-mngzUKglDx`1(~N z>NPNl6&)2-i8W_G@{XR0R-{x>Xr-#DBR46uAg<)bl(un+ z^*frCU`a~-z1w3z`;oP@%0f@8++M4uhS=McWHYWz4gCU;kr9@NnHgN*bL|fckbsNU zy>y*$v)p2v5P18Gad}nv6EBd7mbuv5mwvlsm14mSko_2sg}uJ1_WHh2EFz}7PY3lP zUW}%R`e?|VpWw(_?Kf*DoxR*1IO2t*!4+BHBs$`BNlr~!_!mu9yt&Wsd5ZN^e)%nT z)P3VfNM2sye!Z~7WDs>$RYF7UqH}hc5qGuD%JOw%{tPG7MZ36~_hxq|$ji*q24uX@ zh1J~buD6-m+$#5UH;$nTPmt;Y;Z3X@#^=+Q6l}9isbT7Wvs+9zRvLL05@KRIWf0{B zhbJ2}%(Mp@a=8j-_nKm*8W^qCF5gCekJ#a*%r#&<1V~8!v!yzdMQThD#7~03N5KLF 
z_AJfjWqRvoyw($Fs{xJxsM=~v?F2|puF#^ATMjNX*$kF}JG)rb}%{}~J|X4kJ(oULhp0w71cM(o@4UbgS?%oo@-mg?dm zHOrzhGRDsrI>2Ng`!J5rgnc1%8*mgr^p||2lO^Egv*3n(H5sa({FzOhC)+BfWvFpA zXd;cz7XmuUR}N&&YO`nfBHJ(1=?oIkzyja~#Fz=44ReWRBZF?FH?O63AZ}PeM22WG zc3lCo%$p7pIXK+b*gu?qP%PZ%r@@zb;dFY(!=q((??TxN#B-nVF-F^Y0$tIu!v3VN z&@8Zu=;(yU$&G^%a&X`+p6&C6<3|JplZlGY?=o>BzHWuwGl^VKGyAP7pR^aAG{P?f z(%_W&1ty?6?$1_o%)ao7q&(0l?qFdiZPmd+`@AYj79L?in^ic><54|A{C9j#~~ zZuY#VpX#gd5!C(=Njm)CnlNdFg?=JNJo`im;UXk zEV-V=14O@i*W-%#qtC^*qQuHfe!{&tgK=*KWv^R9@HC)3nc~~_2w3g!V}8OeOodu% zamBne{-bPqX*w3kG3Z=Sb$!6}pFT>ZBz~ix(oRkc-{D+H`)Uoz-42>2mC|NumU|HB zH3w^!gFGS}f2k72OcZ=ay0OseL?X(~01YPO<{H#A3hm9DDs(BsCl2t%fEH2)A|gPn zDTuq`c{Sm65bX{C<5Z>=goCo(E~gu-Rgt|Y`XL&c`05{6KyQd@$Pf1yl!x!rDxZAy zks`euHgo<6=$s)31wIZ-tl+5CH&BzUbZ)1Kger zYlS5PjEI0rOv=jpGXuqIySKa&s#{B2BX+c-F*oo5CYLIXsq-5?$1iyf}?B~xFd4fk;il12uh4W)u$knryDT?LX!~da?C?i*VV$efL zQs@7~#N`TMcVy=wBle5x0v3VK+g_=Ut|d{crpcF@x1L&jp=9ahe3vSt5h+}^wfDp+ z95RK7G@|u!0;6G`_4lJR+@8z*BO^v9)|dIG+Tk$fJDFKT&yB458~9%^u?LX-MG#3R z@+N0Z4)@j+S0fM#**=eDzbSX$_@`aj0QAxKl^6~d3or^yD8yJ;I0>%N{*|$x8jPVS zZsk=g4QO0lU2($r2BrsRZEu}WfK=xWYC@Y^|F)*}?97rlj@Tn5qddRp)ir z+32n?pa2S%zYO4gvVM|u62eW-j3EgEcXX?}$ozSL*NI;*j={l~-@iG2D{;W96((PW zN`8B2qPlPKpz>KjRLVb=jTxfuHqaI}LWfdMym&sjy1e%&$soY%ER4)iBB4Fu5fgU9 zTHq{I7kbP3{Bd!%fJYMhws_(mjF&Qs4E+U`Czr{7SHz;tHQ2)NP1zYe?kKsim;|dO zR9H)*{#SJ{mJ$P>hf+)SF|7^V&wBOS=^LU|_PlH^k2`u07)L-w6#@YBNnqly6nG`} z**k-mU+LHqlhN*PLFcEHHoLvqc+ZvK#>%*hA5tbQyM~&L?b~$kWL?`8KWbCy$f~+; zAkU%1rC-f0Q4npLW!>WQM!a;A4Wwk{4C&DIy+@;aFj_d@*07!K<-f-s+U{SK|$z z>(DpSa7=RNu-sg9T;Ipbchp5VX z@Ta8I)S;{(sa`C*T`R;7hT{$Qb#61TKb{Nzgko}rdCEP0FhlRVa#b+h zSx(f3b|0&6R8^yRk3Ud9+1ocV($TrQ88c+QP{71hQ)JXOb}mXhQM!+GvP!~Bn;gWX zQf_bld%~4f3HP7&U_vo=?4YMr@SmlgO5<=-@P8zBrC-QP`9Hl#Pz$5KUDvlQtIROl z4(`H}>{_W|tgo+UzVNQTtBXj3HD^6uJ9!4;D*OzEB$Zcho6H}Wxm;GcGQ*`&l)Ax9#dk|>0~f&Q#BPdWanA`vv}Z= zQ&GtiYP7OYZ~S|r18Hl2j*v`gF8O9(Y+10bII3`iNq>CLlvuMbZ%~fL1~Q5CSDIY6 zwztQDH9UBG6mciWM1b}oBYvX5KqyB?M1Q6P(;1@K9@RWh8)KZqy=AxiQ&Ei_Elb$z z)S+|@k><{`14%r5fB&HGncL2W?$=W`e$CVikM?p2T)?|hWl0{-n_t!<5OJG8bV)oG zRo0cZLI1&7>x<*Pd}gk)>l8gS?f&Bz1ebT|yvBrPjjU8WdE`<4L}Wubx|KKJ8)JXz ze_qO3wOMM$e!=#`;%sP`n^TT^Y`fmLSI7;Agc(^jo(&_FMvh+dJk?*1qP>1@*j+I5 z%rk|@Jr*d?buBC;K^&tZs4e~_hq&`NTsv@7LFh+lrNOTqQFcndnH#RF+WHyb76F{s zIQ0%Mj7ai$8~9s;m&T((OP~8Ombv-?@$Pyl+4j=OaB=5Y9K|2}3%eTtiFizgFd7a* zeI5Tx?{M`=51Nxa`E>u3LIU}}y2rXVHjk?n!a!Gz0}Ood&(xw7sEMxk+CBX*^hu3& zDz@<2Mrh3E8zt`b_)gyqe}OF8`Jd0=x|cc4;~#}D1nK+*joC=+><<%@Q0vavds4VJ z>dt%ZLqVM`SM=bx-Cs>C0|tQ?=D+HN>e*UV`DH(sXUX1;0XmfQV~@_KrAGMnX84tt zt+@!4P#o^WAf{fq)#qgl19*I$%dhy#B}knsGl}YmO68Qs1Y9>495ZQyFB>36fZ2MF zh)cOxoP>HvY95>_f5}7t@QLTZ#VVF7FNrBnuAuvL_KwOfwo9|NxIU)dy5Vk&ZU&B* z%^OpvGDXa{hGk3h$WW>;QYK&0Z{vmIt#kXrV zodFB%!~z8~e}ZK_zuPoH%*B-injXMc3Iez*$c$}Y)UcqH^OsK|SNvY2i2_%`-l}uO zbEto)mImu>x&5;7X<}4cmc4gu7Qynd;M4<&MoLOD6s|_Er@aL3rIe>jac3juklQ;# z*fTE?t~c3i0mYdJa^Tl0Dj3nis^J~J(j3<>&0P#XZu5>#ITWRx23r>fvx@f?&P}JsMp1$0I*tUc zleIKctdQ=4hU-B-N;&0Df567+r8$6>rLu*|ygcEZW@O=QH-VC&3*Bk!Xn3a4T`jcQ)#8?-|_tIhC!Kq;E$x263s3 zb%iZqxpsX-dw~zueTxp8uyk^Fe7-&A+$H>#gX3HkPb5E|^tv~+`&m?^ z&`+<_>wiDgV$-}+0>MR%Q=p{VS@S+Dnfp=tyw!si!+Y!P@PEw=EUX=UvfM}%lnrk0 zB%eH4zCToD^crhB9K0aV;vtMpvaPj08-#U<0K+c03p@M!rbQ_K`*QaVdRB(v+lMz& z3m`f;pI&=>0o_;y-^p657h@Tgf`^Qce|Yh7k?wm?&|Q)aWJz|-m82U0-KvdFG-6_6 zbMy0`ad1MLo0|(d9}Hp^1a0qDHt5SzI8OWqhGw zicXCkd3L9;n=Mf83^QXn>Khn{B)QQZc&fDy%8lKK5(pBMSEK{BHMc)Zl|J|! zs-W^^J#kt?b3jN>WpHJF^d|?w?JAb2r;KuXEA2hjoQ(PE^m+~YK01rFTq1#bUw`S! 
z&^pD~#6ZAT*mADM7Dmg>YY-fL^o)${Y_f{TpA-+JfbYud+2!Eq7+-I*0NdPK^?~WW z!Wx)&pre3AK|eOTse6D@Uy((}a&x;@dtYd>b(&FTmL{~>S|GNS>rb`Hytl`B-)X|_1CEUF*Q9EGcp)% zsAkV=KK_>lFqn-d5o=0QrJ|rp1+>1qRH{pt6fM~O?Y1XFuCPE4F@_LwDQC!|EjS#B z*_@uzv((DQC+VDa#1V!Q?>5%8I0mK?;{|6KKGMOXm6oKw0jsK#&y7s`MQVjN<89FI z->)iVz@Y@;5Z{2)mYdtaqdOeFNWK&*+(36P3BA=F-4mxVvf}{3!qV|pqvSV8@6rNlt9xjl=f(O^^M7p%f%UsyQSVA7mo z1!3D_j8Gxlrg|;wu&+m7e?JMa16gAhB7K=LnZFiBPWL~#tg`Cr?#@nV;BONCfY)DH zPGtPRIK=}ko5b_RY8C3+u_|}HVX%&M;sqMz=Ek@WjW3FHZu_$Tj28I3n8e9 z!rGrDjvu%5T`snQmaITAu6>ShrTg?-ESb|b;BM|Mw?{TruvW9CDyuaNknZB}xqQ-Tb(l)_#@pB^7h~TOyLzv0NR01OxhD@>iB4|fcie*6B8~L= zVonER+3+2}#{Gq>F*8;r6(>-vit(0T*8i^VEh7`-InqVUA0D^6-*wH8`8z*E{wyF? zoQH_TiE?_T?i?o)MEr}2O8KrR_;Ghs^P5&?SeCtI(=(6Zhl8Uc3#X^%kOP+EeC^9= zy;DHi{e&OqBS=U9I<6Vkn7I~1vg(7W?1JY)c|S8VGtgHD{yvW2_I~*t}rq-moy(^tMtx8O!F+Alz;K~T-n*dB1lh(0nZ-Fl;NQa{ra`3(xb7Qc zA|M2HaQg&r4^o7d>fBNF-{So+_*XgxCrXYLi&limWBMkGbYyqj_mSKC<}HeX{l`Bs zr|3A`dT_oXk`!mxF)6D>mKcfEF{P}!I-Nm&@E2($Di*xyHqGNv=;WcEqQ{7#LgysGXA0QH4Y+CApLtG3|@RZHF1I)DKSMt;!kR4he!v;8;1$+EGfxs%T~yroCdn;=g1kRoiVZbV;`n+PJr_uEi+K{d zn<&lCU@_S%Aa?ao5yK(jvj0-*zED)0sO{>hVZ=(rO6GJ@xAGVrY2Gr}ZF+?N)JX3q zzXx6kK@kxYT-@QwJel+JbJxdgBta>uKWn_J43Zx>k3h9Co;iIjg>IYs2wy-E0GX&)JnD8K}zYTZi*n z+f<^RNzD5P2RCh=y7y{}3Fge3(h2fWN25|#Gy)q2!n*00uz9}d~ z0m_pcc%ffyn=9%TobMAAdXZHtubgQ~Q?iG5oBt`+Q-T8@EX9E3#hC~f@-Uws%Y6Cj zfTbl3wi4uWPtTUwni9RVQ|b1ZV|V&8cLic(gv6Q?%{7}Hf`HT>qdR<>QK6VfBrbUd#$<#3b6r1=jdPG`uI?{k>nh6JHpo)wKd z>5yhu0(_jUL)qq5eTZNWDg900i((n-2XcNFQxJ^l=O?N`0-~K!Ks{M*yWWZdk1q|b z{{p$DM5EJXH2?B^hOm+9`DU6(bCnWZGb8cInLiZz=Vm#g4uaIeJppry)w-7rG>XMW z65sT_p;EJX3G{Zg!SCq*^XD7w7d-L9^t>hSZ(Cu_ORdiMg>c)ZOBJ63geL+;`)g0_ z_S$rto_Gvzhk>UkCwG8%$+Do776R8}!QDF-qVd08qMY^{)Nyp$8IG{`FUo<4IO}&2 zD1gwckZ&i1#flKk)Arx*pwP}X*|Y{)#~`?H()}w}$ziGY7RcZ3;kv0l%POA6bp|5Y zgPfLFYBfTMFCa*xQnchbQn~vg8lU6+X6~CMot9J)?ApOFl2J{+Fkf6!8sFfn)YI=; zS}Een+lsgW1BOOM$k^Bt78cZo1JM!9e8CNAy4p0x7dzI!aIt06cs^GTidELH!!R{K4I*KCh-Og4M~!Q z#*1p>P1zR?T3_Z(WPMK)h*O~j^CsZJsdv~zUt{%A|GdHV?p=Y^Y|PiDQsA?I0sYY-i^ni6=uqb7KL zM38!Hu6{_i%$`g=CwZXwk65*+Dph5uBanqk4Z)$=7kZ=k7sJg%0M-@FMgq6$AbICd zS2qi!xN+Yfkv5Mu1i=sf@QBD{_3$O&8oNCK#rU7XXoSQr-R0M90%pN^<6bT~&**Wlq(T~&T43X~dy+N?W+5jl zug1x~T(w}x+#V;(`OIB}&wDar5gtl({{nq5o8P`iptHVtQgko|Xu6_eVg%qwzPodQ z0vUCZAPra%CnjAItoR+dCgr^0cxHGBDv1g0rmr+Kxcd4;VYnRaMLN3`v=(SJ)7kyM z_;QCbG96(ePnWj!CP;W6Gil!ZdFDv`E81%13p8yRsYy)zwUG(n0`i2tHQL)#xvoa) zAsK@VwtxA8C0p#KG?St=TDoae<}pcVHs!Q-6~y#J|AV+tAWpugBYA7ld4H;j3h;H< zr|pTksFHvOhy3(PELu{Wo{2$QEW)B{y-hW>TzpE*FdI+cClt)U5MxiE_PaChewX{i zoYH-`v%W{Y%=olxeaS$j9{}FkTC1g=i2|IUz%T~Sj#G>I#-9U9!K0mrnO?1+_DcpS z282XJ6gn;3Kn3XH<|dyZ&^~~H>X&M7Ny#ZEEid8EEfZci;It*b&Ztk0{N|{nuzL$xiVR;;SMUv1f0+fjM90HzH(4GH+#0o z66WOOlv~U*zA6}(>KH&0xr(FcFIfv1=9}H8W(By-?;=PMj~3UF;IW@u4<%Olo9jIf z5s`y94F_Xlcud^(PyhX~<4a0RU<7<(y9)ch!NKp}6?~VLH1T*{1na%=dBvEC#r4ap zDD&uf&G$dbpX8)loA{LMX%I%+fCz8U90Tk_2Bde9tCZZ)(3*oeka`UyGW3f1WVO*3 z%#nA{x1(<$lhaJ(iiR+TyM_7rD{)FiwxG`fLSVC za<$*8*FA1)Knr+YGA5>|jOTz&#nA{D^WDti;$bk1yf2a?6Gp`AjBUn7-GEbj!3f9a zA>DDwwh)@)o3%PKY}7sq=lurm3bdUf&8mQZXJ3qH?*EV$XZPascys|bT1LGylbT-3 zH`#{R8<{tk$7j}e_S@7@CE5+b3Z=?LI`t9wUN73~3k5Z%lcZNy7o`-CtNlp0t07JIa zZ%*X=z^MJliNjAGLyUz?sYH8+xVK6+x-(TiIi$gZ4MrDG9i4>z14EHCC6q;sX+J2e zGaN3?&oe+R5}$+t(dJ&@Y*4I1hd+Ot@elZVpG8`yvrUpcNlQB+o|m6b*%bW*cAzDx zw8TeTf~)P#9ADym@=Cb?$rgB1KmpCo|92!JUngS};2^1>;FTfb zPt|TsV&VnF6JN*{fdr4wv;Ei|>7TAK&4a)9H?|TI+e>wMhQ${Jdb4jU=8#7U2`MR{ z&7|=Q41`|W*%1cazu~~|)AH^FQB*I%1SJCaGgH^W>u*F`cO7XMM);?wK)sDc3@Rmv5@ z)xG74;t9zr5t)ds&~#L<3}{Yv&Ns;_1T(6oNHW8PwjMb<1e|}UG#7v?U}E|Lx529? 
zG|1a#|H<&EKtclh;9!5k{2S7|tJ0ALxw%?w5~O5{(#mS3NzGt(0HiLE=K+%w0N=l( z(T`=KL{HLvf2yw6l`Q@_Kf$!lMovzrO&CxRMHYu(A6+`9VZg{E>uAM1gEV zGDV22`7-9U8j7sO1DDyfdDJZJ-LbttvOWMxySlsOs?C&u{JrqN<}LqS95{8+ht&p* z`BTn5oDqWlgsV1S^Y}F1fj{>)xaEfr`ekKhP+yq|UucZct;JM=PwF!YfW2F%n=j07|IM^Q_(qK`od=IJ*z+ z;AVy{Gge{l4Nr|uD$UC4+zwHJs`KD~Xk~>^gv&Ns(WJ*t#5Y{;3*6m^s^3?4lC@!{ zLD4<=dXyKK#GD@1*_~5@YaMQrRVnAT+3Ol#Ewe1r5k{>;0lhMeb>9?*uK$w36ZGQN zN0`naHV9ag$Df|hHoxTbEiK`ImksDC%d4xm<8-{%PnI(w^pp(DZFZ`BrX4oxW|7_+ zKHJ#5?C@hCh3@F+D0oX97=<4JTl2Qv)+akn5G$!R+5z_OB$$^H28pbDY5&}Efyu0K zA~=}zBA#tQcdhPh$vzF>Eg3X|BeU>-E~4S_;jt zB%8?8Vj52qBcBE)y%K_gZ$Eta0L2X)s7*7=zFKKZxjl*j?@F$XRgR3WMV6GhmI^`8 z4CgpfqS+#?Rz6k4Cdy8%hbopEVh1Sm_WuTIX5tf)?7p&U6Ue#jwS6w4jLple)?17# zfbiyVIvz3KuoU$7bar0okn4OU+}HYQWo1=ipex1%V&H)o0Qj*Hu&~5{@(mcrD&o#? z6#}_I!FSoQ5-qN!Wug0-jeqH6BgD0$ijgF`Ar!he75`InUowT{O}{fB6%6< zBxIYl(c8gyIZ2M=3Qgs86}L)#D*4tLz^hmh8`~-_6Gyr_HNA~uZ>^_M45yLBR&HuZ zF1OAgd1uSG86cBw__!l^XIV6{!j^!EX<|m%;>gUBt&h21>Lf^UxGa~PkeDbBj1N7+ zcWXCg=TG<%5D?T(>l++nW}6_g4o4RMt`1F7aoGjTREa^k&3o`HfoyMY4=k*7h-Vp@ znW@j}g@uJHwrqi!e#q7fOw4Q!9Y=BNH&0Mr7t97A->l|;Jq12HK*Igr2k;|f=je2P z$tEgQ*BbWq@fRR||HK;%auCZv+!07hf-0-1=vZA%^YQUb5YTett>&e>JYLYH73p zOL>igJ)nm>&P($G0T{R{Dk{K}TD8fA!Tt868%VZ@Yl8m<1$`hTB?Z5XiwIcXegOd& zcjx-l)YK-Exd=cSQGh5xP7KBfzwvlHvgvQG_L-)prs(Mh_}b4ZiUh!O+n)cs4Y1jOO#~(9Co2OC*OSU2Py49e2K{`|&~Z zdAEVj*TfmX4T$>96(d-r@3d*5vUTcyg*ETx%i}b8Sl2Z)6tI4RCbAQe7CIhcC@(B~ z+vxeos*o%CCTUAiFP8wy|+>$|%muV%ku>yO+Z0V~RCx#9fj z{xTC$;jRxL4%r8v{PaYl@zAt)vG;}(a|qZQXM|Kvl@hpA=rGVn)CEf4EN_pT@|2OO zsm)VA+jCrc$I&J-(=(@oo%ZprC2hhTn949%Z(DfbT%dTi?7f--3@s)+RVEl4ZuC5eK^=G`TNG{2*8LDv${tu?U0xGL5>h=SbMvzVsk?u}G z0YO0!q>&JjQX1)$l12ms2@#NxM!LJDr5mKX8{Ycv|G)9xy+g;i;~MAev-jF-%{kXx z9no$>g=O>)=_2=I^4=I}kzm`Y_*FcM$%}}Om+wYpAUfU>ZT>-%QQ$UwQ)iup(<1v% zP75uYX2nNC*DFc+{lU==n7gT<<*fu3u0wO}9odV|jg8Tk6&YNjJ^K>6P-0XfRz}~4=|2HRO@^rCNqj6OF6x}Z&!JI??bKZP$QGRHpw4|jaM@&x= zYJGav4Q=htYZ-C(FP(|6M*4^Q(`<>)QR}t_?tvJ^pvfNvAJ7TdjX`t-(1|K;4{?DG z9Gz)sk{mgB&nk_(~w;HYF} zIaBEaijFZ`qkgB{v*?JzO!a8H!yUa z%AfchuGVJfgaP_n1sp0PAB6!_=QA?Hwj43VM5O>mj5{L|Xwd}F^ zHLWDVzzj_>NAiyf6il01#J?{%zw{)Sbmr(=f5=k+YM1*?;bgf-maQG|PR|uB%7gOfDY%{!i*IA2_nHKfDpsnrODN-ZR7M!P!?6f{SaNrAF;u z)5u?<;Jg#CACn&$e}2U2@5!OxOKK7S#`ukDrYjLT%sc!X{hd8)ApHqDBV#FlRwr-Og&Q`lIjh=bHm&7;nm&t@4+b%!%h3j= z7l7$dQD^`(13?P`4%71}AMauCq{t@kBtuQ9nx#|tXnE=1zdeOsT0q+0{5zp}yD%ZJ z`3TBSiW2H+W`7Mt~u9WORh$y7I$+=a0OG|GqCN99h#8Gcr;fay+#iwO4Ls%Yl% ztK}XYp-`J9MjqRCvfn&umpHjB+T8Bbs*~0F(Od~7{{%fQju`z9X&%Gb0>Y#CB1hQJ zlREdHGSWoh^?{?|>d@+YD4Y>y75?HkTnQ$tiUx!ar`C7R=U1QK-DkbBPpjGf@S%8h z^3u*foc;1zsIsa`C139$80dEUVUxa{$PV@)c`P0hCl)`|k4{A3SCxwQBkk8Srw`SQ z)uSh=K=urXWzHJ$tcec2-K_4|rnd9gn)Ym;RaU_~uwhjWYy`RR0`&CxG*!8GoM`Ut z?U8Y3@)kx+m`)XpeA8Y{ z>9-2Ea4}G}A8k!H79@Fq54H{M4kKe@+aUv8 z?N4oM!+wdc3Sd$w=eI}9 zMtrurc&vT@_NG&l)h{Qz4x#vfARP4Z2K;+bTb14;b%8`&;B}o z5hxXp{P2xGK5aX)AVBjAhhCaRkr&<$C98Il`_Z>7DW-zWaHn+6_w%P&*8f8_H*3r>1Bt)*qn***X4rLcrCQGzyt9u-G4*{H>yQ(32 zCp`4FzRv^8!o&B)$bC1)FaK8y0LJ`J6%|`YMq*%C<3HE$FoQ$i{g4>!XBfo9z ztaNrpiL*FITSo__nyJ+w>OI;nRU=PgbEo5dR@T$9-R#*&Ll$TB^cz~%}h5@9M3m(fON?1HT)kryIdr^ncjYFsW5x{ZA#>P(uQnm1dg6U+# zV8S)WE2hz(&j=y5`+#o*WQgK7!B4~YW%?63iKYC>r%uns^c|uE#;1wh5imvq!0L}l zvhEJD)UQqJED;xWtlxghkC4G4BY28_uMBuTC>06WVwqeBc#G(Y{NRvUndP;rz z?9ImeZ^lUOHl57fII$@xPOzM8kdCvD*fLj*dHSzsCdNAC=jE$DP@&NA61JgH#}*xM zy4YLFm)&M}j`-SBe34_E7eqJ~v_YqGjpOa|{1Ytyr9=-qp>Ipwuo6Ve0wL1-$ zXWqjA0+2Izy-Z>P)c14lTO2H`6c_>czgt%o9AkG6R#hVI5#5|i3>biI8}%SnHhc+& z*J8o^VH!rzJVSxi2fNp8StCdDM9#unDM6>sBN6=AV45+)tScVw<=}0UBn+r2p91%r znFUq!9g%TuSXKgZ*Ux}Xe}DLbj&_%*h%pRtuf5kNbztL8HEUU2)hnX(K=1;*^^{Ljx!Hw&^~u1eUvhuZs6 
zF<*;`idt6}@tcR2I1K0;#`9Fi&QRRD;@|RW>LM@AA^)4P@>F-RU}v{TTL2kg_r$Ft z`zD^+Q-k|(rhP56pQ)XX3;oSIZk0A5u7X;~)@Z$T@wc=Aqy&6shfjWSik`WADFElr z4_lL`{~Ne>eis#0K*>5hJZ!rF15|`WM6#Gz9^=bZ0*3_?Xo4uQv9Y1vCTcz27`x#W z9*zf|p%h@Tk`O`7kuG)L8b zjF8N4@HQPv7G)(>WQ5yG{3AQYOzp?%8++RkCzZq0E<5X^R#4`1bU0qkt)P$hf#5XfcK{<>@vHty#Aflu&km>fy^JlRU1)dCSASD_fE1hUYrbIjZWSX3PUS zC-YKP$Yqc1*Ng9+onfd^+>ru5sfVxS-6pZGXxsZ9CAGcy&Jr?HX?OfkFK{RMHVM1)LhXfG-$(DhEmqT+Hk`W-Bdkrr?S?f*-(?yZl-$H9B(p$^k z)H?_{y)#OZxGScqP=^1Ao@zmT@8OtbV~qF3q*m#cRb!(eB$j3bv-F}%i}N?@zn9BV&>afq?u^Ivl_%WM_@ZYay9;leT&EmXi$<4B?H}DEkQx8vcF>flFa$}ivWkk~RVQxwh+UNY#gIcm zs@V@gUXq;jI13|ffxQK1ZHl_fiZ3&yxdy$hekZQO8s+*K0tzSgTUajc$Gp6)w_EBT zh3gx0ywnaWzV|!?2PI(KlMLd8IZL#xuI^=G9zE0V{JG{nZeD=%Fs8=K2P>9eXN_b9 z|3{GO7|l#mOPkL;g5bet%BH`!*Ai?ZZLXiRU4k3>e0Rs<=Ygiu{?Wxzm(rX2S4$%j zfdzDvkqgAB}p?+XcOx{=?J^IfzX&J=+|g|z~=;Kn5L=97(u z1u0-4rSqvFM6VIyS=#u`IjbYb$9LaWv}N5l{GjgW18VcRL#|iXh==JACbAw#Z2P&mF`61e&g(|7#&_6n$x6W;dM9LH{zg5sqGbwTY^3ZNY%_qKCZM8XldHpQ`3@Nv z*rm;mx2$K~e!^t%r@LE?x{;9;OwnM<-vW@lCGUoieHIvLNP{_Vr7vme>9$uF&HyrV zlBi?Szf@EV%b{_8`H}zuKrb>VRA=0}ghPv1*{5pyuNLihP|98SIJAI~Bi*=V)&mqz63GWPnd@ zf7<)N;F|V=@~jhCGyW9<&g9&1z8@MLrn;iX?cYB$d|uf+Ls^PaASf2sMVjB2>D+{O z984fi0bgRnf|Yb-5b0YgJs_T*)%HZcxXMLh8%P)>OpjY;+r zV!e=(B9zaT%euV~X$#QsaGtI_>=+o~Kqn*=dv+MTJxYwcFulv|*Ga$~PFh-ec6mAD zGjjIyN$meb&MbhK{ud>L^Nf^hV*P^|STGcrfRq$k$64ay;`q!*@jd@zvs7VD*&n1~ z$#3&3xscFM%}Rr0bMxsCKXbgPj-A4$75*Pp=5CtTcW&}y*8_i%iWzvYE&zU5)N2*O zk#4Ntei>#AI$B!+?ulBzN@#LoQ2$ijk0UL_PZ+8)oRuqfSH~yHz~F&Z^r>B|n+RH& zNl&q}Z;Kx^SAr@yN4GL;5nElkCjB94BnPNGaXh|pf4!|Mfkd)5rb$rh!}=0I{R8H_ zFO*7HqWeBSUD>}Ba(J?t5%LW)!Uj{P*2Rat%TvSN$y+H0)&n2Jx8N)QwHA1#s+8tD z4Z?;@0;~n5s(`H}ASZtYobA%mlBTBSXGkQVOv-8^yg5gRoSWgT^s=S78BA!O!^8-n zJe|%#C!PHB=L0QoE=ERlxO%}p)GsNC2GBz#dnZXu;P2zN%K@HSl$SrW!rOAnnm^w5q!Z)0fiMi9U!59tUZm zt}+g}aKP2f;jLNF^SO9l@goTA11C?R4VB<~!GE0#yu4;xQ~qkfBLm6nCdHT4{Gi?m zximl>Tvli))sDtHeNPa?3~I2_{BPZ7xpU26>6iMK@q_tLY+e$7YEa3E>@$1>(x)5h zM`dMYMDhEPFqn^=B?P(X&E%ghCp*b(k54wnc-h$5F)=Xi-MfbXvp+Q$IEqZllEiN=(CX!VZqSz;5##qtgaFYy;K-` zW;f=W{p?7exkt(M=2vl>*FObZ@U!XDr=yN!NI8r9RA9LL>{g;{$eB?&aPocskKuktr>Sfn?Jdd+Umxx?5X+ru@r{}~{vZZG70kCBn_6GSGkHHF*`wSIV*1di3k z<)u;YHcUMSAc4##Yuh)XkoFTERl`^dFtzA5nVVjYUZfy0L-q$$9zUj!SvsOmqqfzREbIy0T4sreqLBze%5%^ge3lM6z)1XRY3=QDD(^*Y-zUQnDVbyTJdxPtOA8SY)Kyz8XqWHtfZ-Pt+tP}t^F`?RA6 zCrf%J5YKdB?^o9C$4)PRGNa?;ivn~7r7z_PHmIGm<};DO@yilGpwKyiOtyUMs+OTk_MO!2-$K)Si1kji<+v|fsKrjgs{3R z;S!=RMS}o}6J-UsL%`y&+-RV=0WQ9DuJ3N=mUpVL<}6Te+^_+6d3!QHhw~$g+jqs7 z&b2f&VwSq&v$;}^4Gjz)N3*o#XBdN9t`i(d0E4blxwd0K>Vc%%A)29pIsFfy+ToMB zx8oy6n7&*Ba#Q-r6DXsC;dl-}U4S9^?12Lj2|LVVSNjCHiT9`i6A?Nr@$Y}LU3)1o zlkVoto3q>7oL@#8(UH+C1O(QcwKVeWzw5w}?=$_i4m7Tad(xk$;wdG6A%@0ml+0yy zQZM+pSb>c-RVkU_3M2u@m0;@iv%FM?p>n{e|-J0v4uIl_gnC zNE%M=Gjs(Yn*%=HKu2)FJod&$Z>XKX%N8~uylF!KHPnq!4}kx808}X{De2nZU#V-> zzAyR!EH%PxH0k|G=>`f8YdgE9rY0XAqs~kr$z(plz{WR7#o|!bM4|Z=*4p1(_H|m# z-VzSTRC9Do!_Y$Ya^Dr^E8~J(0a)N$YUUOgrt>exgZW5$b$j+sS+$5Oi9kT=cV28d z67r-dhUZNmlScv<6sMKMCaNpX4D@)W;B z9t6;q00M}3ZD#fi2yYY=6d1jjDI@}$qS09;MTw0EEMtW#JT66GMI~Rn2nS?2TZC$k z5IG?GHnm7InCDkm&!OJSogIc?FgZ2##E4}s^$x5|ds(ue{I$;GDn*O8h9D>u5*eC@ z@N|NH+L!i1{E4)5fl(I;h~=c=)&XHTdu?s4RT6TPGe&;@@bFT5 z1fAxqSIER176e8;fqnrA0^UqkohUXVL?)PEu=}Pke@8jtLc;Z!THWaK+UBCepttGn z{U7Ei#5zAy<^{M_cvW~=OA>Rf7DoPl-87IoALmkF&$fya9wf4R@#qF&0?wT=>rtwn zZM^xFJ37phA;$}qJuSk=BHQOWQJ>E3J*ShTJ{_c@gE=k=1`f{O@Vn5-pjs~3Km1f_ z6z8Gfx!L#l%MUe+T_fxk)~=MVD5v9(N%Gyff^&GaG$%CYHrC9okEcJ!quU%z8u(mV z)?AVxf54EV!s4771ow67Ekt=`6=J7(pVxW*^xNU|es{RuBK#PtD^;2J?J@Gswr6qn z-$p>-uZwpjB%la?1En~qDR`Xs-ul=>4%-7^JDK-Pf?aw&f!1 
zyg|hIkUu!kai)R2EcE7&%B;Azh1K8=~Z18p+OZ0mjNnx#wW{SxDD3IX#`L_Nqe z0VByR>xPE|509nU1|K%qJtp6sWh4pu8FI8s;Q7^+l$9UD^OjA!WV`($l*F)kqNb*1 zK2;f=m6e6YDT8}gO!H%wFaCUMq%iUg5{s#b6==o9^(0&hs;jAxdJB*%SFEPTzrAhS z&jw#j=JdqH(ELp!c6i;NBv13EhK`SkuN)&>ctTI*#nmds`f*@*@TWBS_p*hMawNh@ z)3=>~02CP?-v#rXPM6}XoaO@=P1sduA*T-wXsA-sL9D%*!P=kv@gOmF_>(vi8JGmX2tvgm$S8il9QDz~Op znvmlL=`Em+Iu1#|9zr!ZR^`Zv)YEiEF)Ju%;VUpOGY10h_8Wvw2;^*@Tv*mId|eD$ zUnUL?26$B9v?HUVZ~dKqECw#en3#U}x#NrD!E)=$w7vsB7cFsFPA3=U4s%*5~lI9YH13KlF4 z&aSVo@8rsWucKSO}S%qL25GZQ7=o$SmVY>cVQb^7ANDgn^Q3q$Q_B_wFP0s=75 zaL69Pz;>d%=NPmUGmA1ra!O3TN(0GKpEOIMWrB)$B-c_T9Wpuyia+(Ilcmt_-@lK6 zhsSvD-lGOAzj6@*A}rc9a<13)^drMcntZ5Q?RYb&X~+eh{6 z#jUFCuU}K=O< zu!rljt-U=!bgNwI%q&=)88}zu_p(`#f$a|3JOfxA97+KkLc&uM6%1PGCdjI+KYZZR zT$-C}SkJEuOG-+TRD=Cz%Rx%idEqCCL2t@OA%1zFf`x?@D`L+A{bDMW5BepeqM}wN zf?C_!RMgec{(8*hQwo!jkfd8jOJ2Czwy0p(48it!rl8Q$)1xRU*}q0aM6_~4GE~N= zq=Y|JCiD$F2MDy7b-&|F1{h~q-mLufi`j9?UZduafB=hD+MOpw2_xATPvYV;bPEGj z&=zoqE{a2z&kj!SS!oJ12cDV5>FDahsi1-7_^>&yHpk*u&W$`9Iwh5Q6aPgS0$R%i zz);;E#m2;-*B2PT7zuMS03TG!MIexMs~H=o*FJ$-i5~%iiI^ib2oJ>novsq(-&}Wn z6%Qv%?|1N5d@pfOZSLFy%p3v0J&)7QYxnoc(Q;|gwDI+A=Wl@H8rkXs=4v*4~v`PIbZ_6mmb~L%<5bSAVW811GgS2h5t0a&{FZB}4)q zhN#0vrnQkIM*l^-m@hoS?D=*bqksSjRJ$<1J@YLHU|d> zh%>9Qw?Q|~qVo%1K?=5Evzjnqkz-?HAJ&HzmfjPxf00S0g_r>8%9n@;lt|*?0L)vA zH!rI&?Z5wD>^Dhr_lGRDkr*=#&J8f6PXQ+JygZD#s>FW;&MuhP*vR4l1XBRi5$C94 z7(Io9Ud{sO`=iZvR$6Tswn>FtAwWi;rvc?2;CMh@@j}q@%O*^*f2OF&Kui00mB+W7 z_WwU3@m96WD9QMeI=9Roj#potE+@&w+#`xgNSIvmZY-Peh?n|w8)av&ixYYUepo+A z`IVND0pg6OFrx>#Wu044Bqb#e8eTto6gu{@Z!q*v-xDf#w}{8Ex%2 z@q28{%qZ`Cw&0-%$;p3N+e;Q((mvL!KKOextqWhleD4MglWrq!#J6u9@HsnxKXHcs z`1RbF&q{Et<!Ol~UDU__zghqf zDlf-rXLOhy=*LQ+@Hl&j;-(K6#GToLXyoiQK)^&GGq}Y{oO1f#ZhPW+v5SuA@9)o4 zW%0|tGrtGBwG|c;uB#^yr1~;m&>-V13>_oRQ-r6nv!8FqgvK2i=x8G*CVsS8$Nb-! 
zVDQDy4G&>5k5{BOw*GL9iDk#9srhYgxZzvgrB7jj<%NK!sK>+lEY7Rhsl#4@@~eMy z9~3X5hi652u~p7OHy>YB+`oMlVOr4B*h!_qDSJUDVm)_nJZM1fC`Y2XqF5)#oE>A6 z+UrGzgW!gnbQ6Ryr~Ph0Xx^*`3JtC`$q2i9%@ zp`k{w6X9)X;0`dHt^as(a>AsO2-kp4%=?Ed<6q%HYu)W{J*x`8GdV~ab99?%$#k^n z3ZGSCqWM6uZGk3lqpm{shghcny4gc$L}7`c=B3J zi$S+2(tggWdKV3XiChdlMMHn8EZj7Vw6q@~u^`XE#=;s1r^2~?TRr^|t@aOfYNGFQ z!nq$XC5xTxQ!&wRQ!(imNEa1b&4?q95?oXN)uiCm30fx_kk!Lcf^(c;F;ko9yuXA} zvKer8yxj{lgXxIk)(NR(z&gDhQINdmTvb(7X-P>`IKZI2f%`69-vmxa-^;&)MHZzR zB=asM<3Fs|&o<4Css@ZO`csBN89)s~54ZlRtuNPqU2B$8a!BMKTHau3l2C4-y*3%y zq4^?-LEG#Hd^mCttGKv05Q7VAYc7k4kSjn?x9$07w-YYi+CycGPk(_-SfXWK8L$W>9K;oU48nZg?06;_df8#&bFP?Okm3tv!2n`4mY-~6%h1D;MDCN-9 z{&8|y8PZW-E{0Y*+j`Daa&>_{&Id@nLdR`A;IebCQ&zUOd%aL@tjyF&q^BQl-zSFE zKmZW(w9Im9jXjCm(EpDvoiDz^a~Yyi>b3=y{?&V8l`1~#-MKiTq?*=)>@E8oj@n%iogJ0um!}5fA9=rvo2 zOTw!XIHu$OO{KFG??IFoJ3bH-iFL1|sWs14BOWr>cm6SLTT;ZHnA(FU*Dm%=K`o8M zYwrsdLFwJO=492or>0nl%?q7x2j)2PGiRanR(YuRGG%3DB!pJ9#YJq$;kIUbRT={K zqk)A*_s?gMWPDG>XTHuuv2^x`3w%zX%JDe_0qcj=fm;$15|>iB>a41{_!d&IcCdu4 zS!yNv#;|Be)j`%Lu4=G8A!VB3R($AT~w|eE^asAS6WDIXtZS zmnjL!p5S+1FVuI-mxDupl~hujH0`e(=YvgZWnbyuc%=4(*wEy*;jHgd+g>7k{5cD2 z97a4>#9j}{OJ#M}g-kowPr;=&_gLA{y}UQ>qgL5-X?pVS6RB&Mjg=*P#Ey#hP2FR< zn5!wt40+$%@~08!_;!!j-Y$`MrQ^-neCiIj>n(~cXM3tdms91ajJXWFVeV^*E~krq z*SD&3_Mn5H1j%-4LBVIpSAg7>hKeSdLte*QJv&!WyP&*$^fe_qIyzX?!Rxfl%+P=& z9`Z{L8);f;!_ScaDfhy*M6$SW;6f+`6<2MgDeY$KQIMJuz;}OTWue}KZ0LXLthD25 z$ho_M;Ch0;720YM_+(`5Ad`Zs#;Y+Jt}lFxa94EF4Oipht;LqJk8iAgf*tiizK^vR%@wi}_Oufk|^;fwyZ*sp1EQtwUe`mEk{#RYK zyHj&e-5T#Xs&L|^_+gU<$IDI>y;Y3IS6Jm*er?8NXZ7z^&^9fFgF$xxE6WPMXORrd zIuYfXz?U|dHUuw|-h^4+_kiA((C&jN;~(Yx@JYYGOB>u16pVoxO{iQB8=9&tL&Qh; z^Jj^|b}~@8=PmmB`1puZfo$nzIr8hXH-M~5^Y>d!=m_@p^ScA$TL4z0IOHKcx$*PR zbYga-1slG(fLm24`b^T~tZ5~-&*PY*FgPpxnn{V@GxRk`PZH(3TYCDJye({P4sB82 zz4BW!BBr3CTTd|*NMC!$F4IbyxFW%rmmP5`d>xAYDzp+{7PEl|03g8U^z!)mT30AG zDM{Zd9ug)3A|h!(rXUe2FE3|O`h`Dli>y4zG%IZBA%_NWAgHG_09?rh)yFwNNG0Rr z5+~3@o`Z)04ba4SkmtqeDJWP_=2gbqXJwbf#*(pWy$qgPTa(+Ku2~&O!+z*`q*_S@ z>BwN7ahPQP8Tqe}*uDAOKd$eX^~LKF;xi62bY(SIShIxPANa6KY{*3Fdu#kaN4Q%C z&h9hD>>N)M-aOt*OUd~XD(}$8o%68gd&#EEWu;T0LBE5Wh%?O1YnF_tG=yFH*@2HR{~jLXbf3$4(>i*D9}FaQOKjX<2zx4aIoYQ^Qebp?7eZ9=e$~NK z>HY#Zsb4wu6%7n%BO)T6$MIomXq-|-E5Secp(`yRAt4x8SofKjo=UvyYi&*cn2Jdo zM~Dd#!3C151rS8mx3zuZHf&{rQYj5O)*vhl4i2I~QTs1F-V!NVyNQOT2O^q)QV#Zq z!Habl9vz3risYe@sd2{=r#-Ksq$3yCyNg57cZ2dhtB<{?17pElxGoWKG(N#(1T&%R z8zZxRrbGRX_KCY0H5sh>?ngWmjxB>$CpZIPIcYT}>(%$Z#K#9mMh^0Cs;H_qZ`WP- zoa`>F=f;b>pM|n0z#VV@x!*e_%(GnZ#pSV{I_GHPIdAQ7_`TxaB|q%y?p4~pjyF*7Zyc9(g-A~&b>1v#mM6%dJDpRyFL@egG@2WW(tD>vbUEbz@ zXR?db%VWXiv+WMiw&FDw_Io=jolP>nU4~)Ls@!Kijw?~MrdaZ)vsYFH#f3krf1+ww zF>blOM>Qr%S9mG)-g>2v7C7cEz&{YNFxsqa(l)qJt%My|20l}|xWq(VE2JvTorw(k zVe0B6fd7h2#b*z_3BUk?g@xtha6})4jMq3MeoyR0*YPzJ$QFwu*XcZzv*AySI|BPYlEq^Uzy zRYb2c>~;`Yo3{A z7ZWYLXm6waG-?Iw? 
zgLcF5OeX&r5G@fI8X119h<4Mj&<1}Pa0u~-Lx6z35Ew5SGO0QlASW*mdN(>y@dNq+ zn;Oc|^vcTd*P^VftkB{MnP>y-dcde)%o@5;5Q*fZ5)y+zAisw&0D&-*Gceesx&y5yXZ=3^)nClMrXAwfv%-y%7aUdH2fqxR4z4iIMxvwDLMksY z(;r??$R1rkw=elRGIJS(w=CA;GiY^VcDhPtVZZ=G;^IOSuRBxXO|GJumhkfLed`ZO zh9gbs zq?(n*^1m>fkn0hfl55d3p#6aaP*D2~_uKn>xw22Bko|5SH1vlYdN!#zszP&{7Okvn z2hLj4Uy`p%$=N_sWrSRPP_-+otINQGLRuH~N6`I-n!5U$Pd&Cz1JIi;6O|Ame7&cq zr+LbL;OE97cR$pM$Vd&DZwR%hI$XHeF`qv_0yq~P?=GZ(=mC+ze_~ zfT!>+qQ8A(CTzuBto>%C^itcu^XU_3XL^pFWlm`=Ep4EmpDEis^J#-Bbbk7ip3&l( z=A;gFhe(WUo%kKL^Cbl;>W#jrC?2EMuAEDEbj7fDot-QvE1MbRj-956qr0HIUgeC4 z(GN&VyY=J=g}BSP#PcMV_ch1kg`y2V4l`>s_-N`b`L1-e|Ixwjsd>r^hfu(;ZK!tL|zp$SUMH&20{PiLHv>@ic%youi+6F zD{J<2X?lA4$@%%ZR6}JF2}$SkrhC3h{;;^kHV}W_UR|6bCT-fN0d8Z_i7*XBmel_P z9P^s=AYL0A%K(rIM&b9MU7_Tx(=ilaLzSlr62ap=Xz#Pf1=J(0Xc}abIMZG^aarbt zx$RG}bat`Ea#@=CoO7O=J&b>I{Ke%hE-YdYf7+zTF7wx`^004slMR&H9^L&>gKw#H zKLr=vnBlk#T@9~RoQZG5JUBKUZ^|D?eZ~0I;}DG{I>=f7s#)qqZOcl;B`#{0KUsFj z>FgkV&grQk>P7V!O^kFeii+0x`E8l-jo#{fk2ja;iByxNzIll^#tKD0zG4sy+`Y9~ z9+f`yPOIQG3J9l1I7M9!m;h@7beqilm73bEN7~vHa53Ugixf`_AS(jKZ>fOz(f{x5 z*Vij01fUggG=VU&r7^mM3JmD@P06*7t7)Ne1Rw{%jMMOLO#pV@OpHw%M+bMODG~}f z*(mDmGv%A~yuLixntHYRMH<@{YC2>A2HO~{#Lyn4z3i`Xbp-Jx=ubb-YGL$edgv5J zY*x*$v**93#rhoQ@b%K>83J}jp~%jIu7U^}6}9 z_=im$(eb=(1zqxoza>~T+VpnI6&CBRFFD~Q9cvcBY&4G)b$IH@)udf-KdcDDuNRO2 z4t~A^n@*vnrDbEXA`)C;)~%?pq2CsU;ecG-FZ+G1t(_gN3`Tld8hnZ5y233O*(!s^ zF17A4K03;zRDqvD4{3OSz(#7g_CNMrR}=55($oF#7@Ev) zyzTwfJY#LETIA4ld&}cB2>`1hWG_7xqg*p48^>`)G1eR?r<#1oMOGNL;)K4;*54+2 zFW1q*Uf0QeDKYycw zg5(f9J#X#y`w%{dQ*j=+?T1-p4aZ+))UKRfJ!G5qEt!MAA(HjXVm$anV;2(v9Uduu+WZrSiV%o!kA$(e5?#p`FW{AmgZ4;HmLAJP8JeMVtMBZq7hCZ7W7gKRz z>6U*w%98Y^Wtc-$8UDMHEbO1 zQZmbZZzq@OY?|rsHmL`V?Vo8w$a!Vpw&DE%&bLtDBFD^MrEq|kYJtbUACt)p zzNWZqGW)t>pOPXX`stg*6BT8RjED_rJ$(I6b~gxBzPVhp{<%IgX$`&Wn+;ttuBg3s+eX>!8TE}olog-;rV`dzJ&_7 zklzzzmg zZ|uDH&AWGv$#pL^>ugA^s`r$9&L_|NvJ<*H7c;h}8-`nhAlu+6vNc`aLah&}YTH>% ziGKc_%NrqKReNel+*{2Tb4TiIcI5}!>}Vk0F7wVJlj*j9G#dDB)bXQXnlE!1n9fsjvpW;iJ7W$WCb2Y z38-+eR8*}DVK}}#o+*M^@<1&+@3xRebI3An0Hsa^dOjFhiS&CiGTOXQCm|vEusxj! 
zdR|!jpz!eU#%Q=`7Ya2de(#;*Ii36}E3~2s)wYl_NcNZAN~-^ZzGZ5iRAwyEnaHcx zSkVMVQ5zjyoIf4&IjZ-^4H4Bl2Q@#bUeK4HJLRykbMf-dO>i0$QjMZ`i84<9sgFgs zwzHSCFlh=8o?f~(cJB<4)z2u-xu^tp)^kSde=E<>!d)<;I|G#I&UOzBRh%|A^Yhlr zuZKbNb!d@3yGI&oFG4Mn7zU1OAYm9|sLRHrm41I8h`FBe)5^V-HkAUK0`8SsXd+;$53RH%Al_mxunb#}?1FI`AT2)e++|F98&5ZjKe4 zIY+}gPEV@)=Xrfj<9T(u3`MP@vNJ$M7F$(uKwb!FU}2?2e=m>f++*Q8rPmzOdShUr zXH}E_BWhrthiYmQa76*ljKRkS4tLa|hlDDBao3uv-9)bH8nNj7}Yx_Q&M?paNM#4$cct4 z(j2Y4--hK1EML-is#f|k2)-X?gS$p4d+1M+ggyjSLJ_;A5OFf-c7<_|3l= zXJMCL%9-ygO&t}{R7ueH*4MDw`o>SfDk{&?$)a^JE=S;)_S(YaiJB4dg7NKz^jK?z zB4eUK|9e&@W|pQulfJ*dU|G7Flhw#R`Rkfb_zCw_IET{~`lMR|^B?A=lVrY5d6okE zf~Ecn73;l%jI7^rPl`0GP0yA8Dhj~5-{I!+Ez!4HybSH-Ka?GM0NQLAeYNoC8!t#H zA?Wx^a=y_O739rc@TtE(i z0J+NG_y;A7;=mtDuT*?G+n;U?-v~=jW7d_7WMO4BJlYt8KYng?wZnD0Cga1ilnApD zZRBypL*u04Ujd&m2w1eVv~XzNSXt%2PJ=6!qAByvCFJo;578$5iXtD537<|<@Oo+J zyZz{Rr&;5|54`<3aT%|mplq%>(z8uFS0-285PrMo=var_W4Iu`JX=?`$vv{V2u?kJ zu{VBww@usP*%LGR>&v>ujq^)Rw2H=(sLQ)}&@&QXal1)I2+~I&s)-9O?>8 zv&wB09rdZwCl8M346tB5D!RBV`4EQH*gC!Td)wk|&YO`qr}SDef1qVTc@%{0(ywB2 zCAN*A`Zqr}l6dACwmT_!Jh!^t8rWoZaXdz&L|!)e-Xc0W8ggkRC`FBo=%Gpft&vg0 z*RL|5*$1VhX^FOvdqQF&Gbg8PYX8o~elL_(e(Kq2fEz=3395sJ)XK^@VE%nZ?0eT6 zYlCHvqNAde)YK01d!D7eG%zy@OiQB&gdIdzYjuWl@Nw{ChL?OAm+n1e4Zha3=584s zB_l;gul(dczrE=`)rLpL?R`CIhZp|*`{e9d#f-aShda8A%m6*t=j+QW(N!%iG=3K$ zQ9BhUq3?NlXy*wccK%6%=CkICx4&CjJ}2W&nHEUdO%$^62K4m{sE|Pd0EE6e2uPlS zJmeGvCjsH%f@J)rZ{PmJyCrJBCJl)U@`CX-HD_RG{d2wIEFQc60EsR*+P5|~=5$O; zaB_zK`}Y=PoChEErD%V^QUB9_ADHBXn||}!(J|^Tb=4HiY{-Y-$^XgyW{6VHtKW;DCc7NuD7I6@m+&YEE6T9l!K>RlEODsKgKv4W zexMb#j$``Y+kZDc7oU^SMqB|qi84De7Tlx!m}>m!IP@GrF&6nk9AqxIP}S-ooJt_K zJLsKpSqUPUtd#Tt}O8Bm3Mr%?R#N!D0q@Io1bE=4guP8T<2-6*xg5&BNmq zkf^)e8tMWQV9T3 z)PqQvnS~_<_-nDL_ByGQRr*t9hTZ?G1)!&o`21PP>hBLoK|Q`Vz`Z{=KcDg%iJ3(K z68ac@`)p>DB0e#ZI=^}ktF_g_B=0T&?6@GGYbmK@PG@~q_0q|GPpsNU3}bCoK7TIJ zX5S_~rI8(*DBU`UxRYH~(EYa3$yv4Ltl{6hycB4Ua0E%D8tXm;Vxz2P=jEt0Wb`}P{9vtG= zMn*PA`kERW=YtHAE+ihqGt)gb;N3e?d2wfcIm7c3^0Dfj-8M09`>ck5>dq_{o}R>} zzsocCvQ6lNM>e5L#3M_zBR#m|H{dE3$h)t zokxQu6ojFS5)w&JnY*N-p@0nhko}32q~zKq@^O7lS?oz8;9>}1jX>TCOJ9w}4jSI_ zyhxz48R5PfkN88= zI_@YVLRh8rm$xa&$#?st@PK#=g=$0mpo$BJPe7nHYaJ5^S$bjNZ=gFX#AG3b$Vdb` zv$YcK7K&0%f4#)xxiVIby#dCg*G2}9wVZ|DQBsohd`qQSE%P{cNUu2{ATc(EbxRMW z9LY>>XpXls?(x1wrT8p4F`G4R;3cY?8&0W>gY}Q{pFjKG;V)l8&;s!ZCnAJO_zSa2 zI+ISsNf}Zn25K8%03mKqy#D(Ge3u>zt<23O{}=6%Rh7-A_RIGBasG-LqG!ELBgA<6 z8vc`GpSJa`!F%Fp(Ebru>Fr@8Q!{GP74C`=ct8i7?w3X$OV>9-4q+wxJ!$kN(*@&y z6-)makx;OSM#tJTh0Tn95TQ;USOp5GnI3eVArN4*O0?nL6LMd|+2+)F`T@9u^D|Uy z4$knl+^z!FE~_#QUEe*L+Xsz%J6Z$bbmW8A-qCAS z95>kW{Z{?d7`fN4vgyu^UWw(fc&l`ICH(X2vtZ~l;l56CkzRnABX7oj8*NOK;X%^K zpqTLVKLKdqdnCOM)4&eHy}eN{MSvtmf{DNvLGK$tX8KD{;XERjCRxmkjwx?wKscMJ zbCB!ua5whVH)S;~1p`a&cgy3JS~XUIua~XRSGn16fX&*_Bkb+$6}&^R`YBp>bFyu0 z^>kEt{Ik8tC*0st!?!o-pBRnkzItlqIqXK<_3I~&h_BP$NDomMV?yoL0z_r}?vFS9 zx}R=$MB?f>V&eRp5VkwW*QKWJ2ncOyoTNe+MhC^89~!o#D?q25H5YSW!iC%2@`c-% z;=Kwmd93aZbqubeJ;+q53raHB_lLG%s(wHD)#79?f6m-R|e1pOw$ zx;K`aN)g&dYmQM%JjujtMs#5;`iZtUmJKc-2yK8t!s;Q7_5ukRpJJtet#4NROL_@~@%dkU zPU$X`Y>W<=JgiNY`5Yb6I2SmjYiRBTC3U+BQAQEmKxD5Y{q*er3aOzeXx+x7-6J5W&0pq+_zB z359}H*`I`659F%K7GDs_Sxyh9(V@zIj}a$BF`_XcGM{DKlhP^86sUx6USQDBK;_$s zOO@mXR!9`~^r$&40prKcK3|t0&m$JvMK6DBYiqdlLR|whKtmaXK46m+@I@ldGl5Ed z`Hmv4yxNbz2%X~K5L~FI;}_;14fbTI^iz@*8n01JfEvBQ@4Al=CR}|UT&xa(pXBLB zz+*iZf7kMz#utHN9()$B!qp8o1vkv%04IP!-abj4La)@c7ks11=HzKk|wOD1x zBvIBxC626BA5Pld=Juxdjglqa02_Gd)tvnVt!r+Ogt0KxZO-qtNO+`tJ3I;5{60%v ztHRn<3A`V~G2J%MVBgHNMIm*Tt1W5Q^RAQXKv4^|ZL{(lzrZ)Q<-9!nabq@o_4mtt zD;j0JD~ANmD-!436_f88q`#<*UjTbJ(npP$*2iwJc0VL&`bGxQpv)B&LAP_4mUTyV 
zlSyZKoOkKBw#Ql>wsl8{9j*!9ue!J(ln z7I}bY2F#_UrV*X`1tu$(j}9L&boQiKVGR|-Yq)XF4K36Xmt8rBF47e;7UbpWfj|MM zz^xU60$?}nSorC$sXuw5o2{UAo&Ikr81{w~D%?odkMhZoI^#P&sk>;MG+~||;1~#mt8>0COoCvyuLo(9KGmHDaJ>zD&{KfpDM^q4j z#ZLsrW{0Qs7;|x5Bh}g8T-bI5=0_f?cHbHx>_GaU0`w>5IfewOToqH zU6_PWhWqf9#u^@*Ha7k-WBm{92ns|7yJ~J44#m1D+~=?HGf=020t|qzHVRWtl0j+z z1SMr%ZS7qI3~;Bbe|5^GLD7ZLjkV}jx`o6}cYN@z-Bk{C4GjpT{RZ0KFVVHL=E~Be zA_AT!wY~{2s9LISdZM7ALCR9GPpiz~TQlrRf4`pHCvuPSF4^Xbi7l2-Z^n<&EIG$0 zCkt9y7IFx3dvdUp1d8P`ANdvSUpdh$jGkfc{7A&?w|W^Y;ji5aTTzZ-C#qjPwP_pG z(mrYVs~>7-ZNUo&CT(1wApXEuaXh0mEG;e zz0>Sg$K?R8*$8G7soql3WL~FuHl%(ty8R+8D0&C}pN$dToSGK`hb zgb;&l@96LO>LPhpy&bkJ{8rK{KjuABf49-}gBiW}T^reCyyajJ_W1Ie;bG4bq^k7h zHr*+G?A<=_-X4;)J)V?wCF+**`(@tSQF7Y;?iClPCHcu^$%Vy|52_(w0Xa)bmNm{m z4xdjUNnmF*&oVsW)chwn$4fkhxv5XqGro!MTTc_2vS3L%fP>;~KDl;Ae75wLNOCnd zF;j5!r(VO@dvFyE9j#us98=FRqPfk_ua4IIr!$4-Fj72}6`y2@|00(jBZcELZ2MiQ zk+(ceU$vU$vE^7N~ep*QAhYAO&lK$3#K)NU!$+9*R)Mdd8)?O+c)Gev(3fGiaYfb3cRQ{N>A6cnI^ zUlFm|O0bNsc}AwO5sh-xHMWEQaqOcq*p+Jl)hlJJo;7 zDDNW}YrUh%fBV0JnIN=ACR!-nyqJ34*=o#|Ko`iCrl$Nv#Ursq47M?1Il3eaQ|+U(zuXFf&mNXc#k4SQ z8IL&bza>?(;e`$49>C*}0GwBZPJ|=*%K6SdMzf*f&rqj9xGVz$X`mC;ho1k=Os-~h zsq}IU4oY1aMX&(~l%!qX+=ONXas;P*71Exa^tSVUQ&oh$nj9J@$1+g zwmKhICXuSR^|=e*KEyeP-Q;S}IS&$mvJcWUge*I<6BOu+{@ya&DbKF*v~?)`HZ?SU zD67^OJ6Qw?3e}B^SFVsj|F8$QNk78$3>76XaAe6W3QaVc`*8tm29mPv=OkrDHbpHeSFf{Ft+6o8FQ^yTa@!4WQJ}7>EYvGzY|fSn zscZJQh?KqgqO0KB9(%!Y_+{sHGLm;>eB;Ynh0vw*|RD6#+qe zA>R3P{03yEKNybJeWXdpRB205>jK{lG{>_7@M^SK6RBG`)4$QZt`7Yd7U%-a4CoHP~09rs1}d?C&BQrp^LiCWcs&1csh)sW&+(bqx~LEl_hI?xRtSe>)g1{ry_ zwl;pU)DAD~`ATv%0!8CiiT(u@m2TKXfGyxC#n@>(uH3%4LEo_@W%jEco21pj^(XH1 zg7^Ws(U;6fMTDWR$btp>JoOQxUm&A8Lql^H+S<-$rrUG#^4_3lP^Y5|2?`P_t_TfP zt9bl4BSh64zFpmUTS43943YG{Xy0JY@1kqz&o7*m)^=gdF+YB_PFcYT)GD=|?lt>>U=gyduS!RlSA z9w=BIpS|G+n4#5Xt1aCZFCN({O4dg_9k3S_>6ZPYGN*q!JljK8IqisxtJROPjMlm@ zDgFUzosRDggxY`ywA*cpIlU07W4&RU%O*(GcYZ#d3)-S*ztATsGIqbiy@xerxt~c(t8G;A zGE&GR(Nv?-6YQ~^f1Iw!co?B7FG4b*Y!sg|OeF>lK*u`=pbz36jX1QbAd?F4M>wAL z2*Tt51VPX;RS($EQHn@yXw(6;7Mx}A2`GwO;onz4<_;;B*#$il6M5JjF*7qG?b?fr znXcw8-=d;H9{29PB@vJ19>{WCnAca@^D~HQYHet3rG+#)TL~K~&sQY~ma&9E}6c`_2^rI101sw+{*m0cgbgj~{bfc9U7kyy2%M zHWQY_<;MLxFxl*h25*YRdLtX5$^)f1{d;l6VVN?rx{Ca$1(IQYVR6^~r|3j>sj}mG zG7UK%S`n${$0q#`anvg(CZd_D+6Ui(^9VlMt*#kq)sZ$fefswa>H8Y&-Wasp?@ZeCh4inf=`SEd zx?Wns&F6FQ33dsUPNTAK9d}+6&@=!!ZmZF#`V^|+=39Xj56DK&G+m>nJgS2axWF{v zLc4U4x-dhu+Sx*(n)^p|7G@{h%BVzrKobIhxTR@xs3TJ{FhO`7p9oR2`f46%-U!R#yd%9|e+CK&e)bj8vZJxQEzNSb256m!P}B|^DB+cwWa`ucab20y^mi#ZZa_3+lXaa z5y0F&3O-E|%F>Rmb|ZTmay(|&)DK29)8~2Y9oz6s!mMc2!H&n~sG=3Dy(9P7-2cQC~0ko<4KtHk?j? zs3*YeMz3Zmf;1|gxY2q3%g)ZOW|QtMd#p#o&eqm^JnyfM0Y&f9KXu2Ft?#Dh+TsqD zUpTAcUCL92Gd&z`#cgua#@6&HK79o@V_21Ubo4AprL)dxL9wpkXG6U9#b{fhK&zCJ zNA@oHl@ckg#nhG26(0tl56087adPoYK#PNJMrMVJvbHAdBBMZuAp85uXzBYr&hW-4 z2Qni<%qQUF92$$t@aY(dVBUm!05HCQ6TlI`;_B*a6m*Jx^zym{jUNCo8~YXOw3{6e z@Q)Cq?*Vo=;memg9gQbX&H$;+ju4|g8%nJ92tEwuuUyRzx0JE%uXbsaUfsGz#o_2kwmIwZF^9jO zNZecRlvUUf8-CE;-A#6T&?64OY)LNpzv@9wZ>HlskMACCE=sh|mbSBO&qakFF{Jlo zFeD?T{U@s}!KPx5lF$C*J>}B$Uk7zTs`t?rO{JwT4t|tln?g!Wj0Wlj7M>*SUzb;) z5y2Z5@X*<;Q1C93-2XI(vLf#PtN4ljB?-1+%}J7Z?N?OCmQ#m4j&H5ezICpgfD_3y z!(y=mgM;{ife2_N&o3%UuBdRX7_>M5lw=b39y;emJsC-*jeAv5ny&IQRn^Si6*UZ| z>YI5rK;7;38Br)yHvii@H|l+4aSCg;zGBuj2U;d;Sibqv;be?Tq46}1bS|B}$z5t& zK~rY&(`?2Lw2}nm2l74rp3Xpjy^eRDz)Xl?rw0_fo#U1FP zO#&1vMAT3pcy#~%eV}8>h4vR7joiy>K*sE#t78XT7?{+wG<;uQ7|^RjR`pA1sYAu! 
z`CU6JYtf@1Z_`u4XccUIK(IzmzOXQQqlM*QEPQVz^f3CeMy0+1B64S`N?GQTIXfc8 z)8%5@;HlQEC6`N5m5T@YK0~Z3&OHY+Ya42zNqK_(4sVog`^?IQ9M@_Y4i4f{X3^ms zEHU94N^Nb8iaX0>Ce!necqJxF=bifq@=_Bn2%4Sn4sCh z0)x;ioykYI7#E9UchAo^cx4%)}Xf){KImf3R1xFEDEcd-{NGK;NG>VYa zzsL#m^4e&8{7gk9(SlH^#riwa-+ql?V{vHzqlk-@khF(8fe zZ7JOMp4!NO%Vmoamk1^6jU`KERn0^Kv60E~faCRAy`ACes9GY%jF6vU-qI&X-ImG1 zl=F>|!b87?)HQTTi~m~H_CXdt%axG+yeER%#k$C8dAVvhAe01kG`5YW`+tlLai%|pFvbdT026JLr%bS`B!G62sKlf|9qZO@4hVV%_TzlnBT3# z)fSChQ;m>}=i!@ve%98uhRTEzg}-OK#dIpc8YA_xV1NNL#RT z&VRFIl)R-hWGTx`++VUfOz@waliJfwa+yQhFAowH9UeK#Hj;HXPfHyN>TxJEPbrCC z!`#}P7lNaT7Co7p<><7Fo3g4psW)6*4$X{bK0-JAYk&ZSmQun4qi(#A#;eW^^U`^}gPf$2!3!8y~@)py{s(4+k z?P$=9;MYvOR3U=fTOKHx>1dA#&N$y|UUe^5+vE|2Cu1& zSSwdb*%gs15@yZlqIh(^?sVP#25N&!3$`x`>k zOxu|zX`JT-jmN(`t@()y*sf0lW9azmTvo?yVWGLh7E0RU*;d0yJv7R+uxGw|Iy1PH zLHJD`dy7G0AWw+W^n1LR_d_i*_Mh7_fk#lLA7ZmTuhLSKtOUrvZe|Z1Csb}#n08mW zYGX=;aEyXia65gC0hdLdTR-swCJ;)A?O*FDsm-?T!q2Wr_cNBaAM~SwihM~uOf|6kcnyu>SGp-JvGn#esI6R0 zTMn0+D%8`A9)=|_ClH(MPAq@@Jp6Ak_ zP$yr>-MjrvM)wRV`*!_Py6OgDVIh5FJdbyC1Eo2Z6`x*dxkpBp-Kn_@3pQo({pZKZ znH4v7SsYDO)1QZf$skA)^_dTbmjqy{t9DAjrn-%o=_TCvWhO;UGg#BEW;$*5>8)-1 zF!YaR*CI=$HO2Ylzb@&Zc8tium?%?+0 ze+1<~4mPzYzCc({Q9@W@cOjII;AEB30{0ECUlMm7_tg$> zrC}zdm>qfW7L{VjOwQ|=eJ4b-epoiz+%-Txcizzd?^XJM=cEXGP~(x3k<#Q>>zS;i zqSBi)!L2&puaUPp8-L%z@nGJ5O~7Pv2neAr%K<^v>TWWwou=bfU{7}%{-D7j_lC5S z%jSGJuu05ia5OP7tb1@;wg&I70aZ6I?si=>!ElhG#QO*Wm6? zk;>0mLrzwznVP|8Jeec%{G}wZ&O3Oq-PE+#E%`2yd4LV!14w7W>0^`N^VhYTb~ zb7efS#LDV-5SQ3~nBMAd=dB5%<5C-!EK77CZ~1krb36|MfJ9)@4|jNr#{cG3z7adj ztg*>je+ys73BATitYJDNBQM+Uc+jm5a^PCOjQPnM+`sQt{{o4A_I87~j7%(CTh{go zN=UtQ{Gbr5}rRNk-PxGl#YVryZ7+MS?MZBn4|F*5k{Fv< zZK|kEWCtp>!_u+a_Jd2fZAJL{o)dPUMa5!I({2h#?N0CsE`GZEdYsi$AM8X|PuKVi zo%Zxi$rpoWftSH&b{E_LfBD>QE<>Z#iHMG|sIwh3Xx}4hlo+e@4*Z;G$Uf2 z&ngdqmOP*H5;uVwC48sN;)=%Z?fsJ?BVn^H`>Q$brv`MW_}aq*zmYSC;iOg{^bW{d zO@4XIsXv#L)aHap+%4v>$iaSxb0tRwaKzYOS%UGznQYZzGiSB$L1)+#B}V+!Da#>e zUOXMoqn&&CF14iS-BUL?{SQW|OU!1wKq(P*VHg^U3h{Tng>_!>Hu{>@Qj}GEsc6VX zt?+{I)w`!8Me!~x-b1$CDn{v91RbhD>I)HR?P44>e8m?BeHnwde3_$_aCX|2uSe|r zU8a9rvlm0|kNKtE?KlZa(2)0tal+XCEVdLg8Y<4*MFNSAVJoFcUiD86J3HWk6$4sf z!DmwUBk_zUY^T=fcprhOE7VAC$;zpeR057D#JMX_QOO0=ZW5$AhMesVum^ka0F)Pm zURE%gSMAIeb=j;E4xJU5M0nRQ&cbrJ3d(o~M@D47RDR)t2dK#1esT|d4+D-EBWyJ( z?yx=9A+gnt85@p5{7w-2ZB~?Y$K#&cqgwuwnzMhjYQ$>(39_<-3?3Kb6~!k$u=#P? 
[GIT binary patch data omitted]
literal 0
HcmV?d00001

diff --git a/docs/examples/te_gemma/media/graphs.png b/docs/examples/te_gemma/media/graphs.png
new file mode 100644
index 0000000000000000000000000000000000000000..3d978a698d4050a0aad76cae110760c856cd1dd4
GIT binary patch
literal 28406
[GIT binary patch data omitted]
literal 0
HcmV?d00001

diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
index 7973688450..55e552d3ec 100644
--- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
+++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
@@ -44,7 +44,7 @@
     "\n",
     "We highly recommend familiarizing yourself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n",
     "\n",
-    "### 4. FP8 Model Weights.\n",
+    "##### 4. FP8 Model Weights.\n",
     "\n",
     "The typical approach is to store weights in higher precision and then cast them to fp8 before operations. This is especially useful during training, as it allows us to store some values in high precision to avoid performance drops. However, for inference, this level of precision is not necessary.\n",
     "\n",
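To make the storage scheme described in the tutorial cell above concrete, here is a minimal sketch of the default Transformer Engine behaviour, where a module keeps a high-precision copy of its weights and casts it to FP8 only inside the FP8 region. This snippet is not part of the patch; the layer size and input shape are illustrative.

```python
# Minimal sketch (not part of the patch): default TE weight handling.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(4096, 4096, params_dtype=torch.bfloat16)
print(layer.weight.dtype)  # torch.bfloat16: the stored copy stays in high precision

x = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
with te.fp8_autocast(enabled=True):
    y = layer(x)  # the weight is cast to FP8 only for this forward pass
```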
From 5fa76f4d072c0778cb09a24da6c875cfd311c19a Mon Sep 17 00:00:00 2001
From: Pawel Gadzinski
Date: Thu, 16 May 2024 11:44:12 -0700
Subject: [PATCH 123/244] Added nice images

Signed-off-by: Pawel Gadzinski
---
 .../tutorial_generation_gemma_with_te.ipynb | 27 +++++++++++++++----
 1 file changed, 22 insertions(+), 5 deletions(-)

diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
index 55e552d3ec..c542a32793 100644
--- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
+++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
@@ -38,18 +38,37 @@
     "\n",
     "Transformer Engine supports cuda graphs from version 1.5.\n",
     "\n",
-    "##### 3. FP8 Weight Calibration.\n",
+    "\n",
+    [centered image, caption: "Fig. CUDA Graphs speedup."]
+    "\n",
+    "\n",
+    "##### 3. FP8 Weights Calibration.\n",
     "\n",
     "Assuming that we have a model trained in FP32/BF16 precision and we wish to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, we can compute the FP8 scaling parameters. This calibration allows the model to operate correctly in FP8 precision.\n",
     "\n",
     "We highly recommend familiarizing yourself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n",
     "\n",
+    "\n",
\n", + "\"\"
\n", + "Fig. The weights calibration.

\n", + "
\n", + "\n", "##### 4. FP8 Model Weights.\n", "\n", "The typical approach is to store weights in higher precision and then cast them to fp8 before operations. This is especially useful during training, as it allows us to store some values in high precision to avoid performance drops. However, for inference, this level of precision is not necessary.\n", "\n", "The TransformerEngine offers a feature called `fp8_model_init`, which enables the creation of models that store only the fp8 copy of the weights. This helps reduce memory consumption, which can then be utilized to increase the batch size, leading to a speedup in generation.\n", "\n", + "\n", + "
\n", + "\"\"
\n", + "Fig. Saving memory with fp8_model_init().

\n", + "
\n", + "\n", "#### Benchmarking\n", "\n", "We'll evaluate the generation time across three benchmarks:\n", @@ -546,16 +565,14 @@ "\n", "hyperparams.generation_cuda_graphs = True\n", "hyperparams.cuda_graphs_static_batch_size = 64\n", - "hyperparams.cuda_graphs_static_max_seq_len=1024\n", + "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", "hyperparams.cuda_graphs_static_max_context_len=128\n", "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", "print_sample_of_generated_texts(model)\n", "benchmark_generation(model, 64, 128, 1024)\n", "\n", - "hyperparams.generation_cuda_graphs = True\n", - "hyperparams.cuda_graphs_static_batch_size = 64\n", - "hyperparams.cuda_graphs_static_max_seq_len=128\n", + "hyperparams.cuda_graphs_static_max_seq_len = 128\n", "hyperparams.cuda_graphs_static_max_context_len=256\n", "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", From 62ec2f4886c71c922d5d0695471b348cc2f9ace7 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 16 May 2024 12:08:25 -0700 Subject: [PATCH 124/244] Small code refactors Signed-off-by: Pawel Gadzinski --- .../tutorial_generation_gemma_with_te.ipynb | 142 ++++++++---------- docs/examples/te_gemma/utils.py | 4 +- 2 files changed, 64 insertions(+), 82 deletions(-) diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index c542a32793..dd535bf738 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -363,37 +363,40 @@ "id": "e2d53e7b", "metadata": {}, "source": [ - "TransformerEngine includes a function `transformer_engine.pytorch.make_graphed_callables`, which functions similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from `te_gemma.py`:\n", + "TransformerEngine includes a function `transformer_engine.pytorch.make_graphed_callables`, which functions similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from `te_gemma.py` from class `TEGemmaForCausalLMCudaGraphs`:\n", "```\n", - " generator = GemmaGenerator(\n", - " lm_head=self.lm_head,\n", - " model=self.model, \n", - " inference_params=inference_params, \n", - " generation_config=generation_config, \n", - " dtype=hidden_states.dtype,\n", - " )\n", - "\n", - " (...)\n", - " if use_cuda_graphs:\n", - " fp8_format = Format.HYBRID\n", - " fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n", - " graphed_generator = te.pytorch.make_graphed_callables(\n", - " generator, \n", - " args, \n", + " def __init__(self, config : GemmaConfig):\n", + " (...)\n", + " \n", + " # Here \"the trick\" happens. We override methods from TEGemmaForCausalLM\n", + " # with their recorded version. After invocation of each of them,\n", + " # captured graph will be replayed with minimal usage of CPU,\n", + " # what will lead to huge speedup.\n", + " (...)\n", + " self._model_context_phase = self.record_graph(self._model_context_phase, self.hidden_states_buffer) # CUDA Graphs recording\n", + "\n", + " (...) 
\n", + " self._model_generation_phase = self.record_graph(self._model_generation_phase, self.generation_buffer) # CUDA Graphs recording\n", + "\n", + " @torch.no_grad()\n", + " def record_graph(self, function, input_tensor):\n", + " # function is invoked on argument (self.hidden_states,) and all kernels are recorded.\n", + " # record_graph() returns captured function, which can be run later with minimal use of th CPU.\n", + " fp8_format = Format.HYBRID\n", + " fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo=\"max\")\n", + " with autocast(dtype=torch.bfloat16, cache_enabled=False):\n", + " graphed_function = te.pytorch.make_graphed_callables(\n", + " function, \n", + " (input_tensor,), \n", " fp8_enabled=True, \n", " fp8_recipe=fp8_recipe, \n", " allow_unused_input=True,\n", - " num_warmup_iters=10\n", + " num_warmup_iters=3\n", " )\n", - " \n", - " (...)\n", - "\n", - " for i in range(max_new_tokens):\n", - " next_tokens = graphed_generator(*args) if use_cuda_graphs else generator(*args)\n", - " output_tokens.append(next_tokens.clone())\n", + " return graphed_function\n", "```\n", "\n", - "Let us now proceed to evaluate the performance improvement offered by CUDA Graphs." + "We strongly recommend reviewing the entire code of the class `TEGemmaForCausalLMCudaGraphs`. Let us now proceed to evaluate the performance improvement offered by CUDA Graphs." ] }, { @@ -405,7 +408,7 @@ "source": [ "#Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", - "#restart_jupyter_notebook()\n", + "restart_jupyter_notebook()\n", "\n", "from utils import *\n", "\n", @@ -414,21 +417,24 @@ "hyperparams.qkv_format = \"thd\"\n", "\n", "hyperparams.generation_cuda_graphs = True\n", + "\n", + "# CUDA Graphs needs all kernels argument to be static - not to change between\n", + "# the time of recording and the time of generation.\n", + "# We need to allocate buffer large enough to fit all sequences.\n", "hyperparams.cuda_graphs_static_batch_size = 64\n", - "hyperparams.cuda_graphs_static_max_seq_len=1024\n", - "hyperparams.cuda_graphs_static_max_context_len=128\n", + "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", + "hyperparams.cuda_graphs_static_max_context_len = 128\n", "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", "print_sample_of_generated_texts(model)\n", - "benchmark_generation(model, 64, 128, 1024)\n", + "benchmark_generation(model, batch_size=64, context_len=128, max_new_tokens=1024)\n", "\n", - "hyperparams.generation_cuda_graphs = True\n", "hyperparams.cuda_graphs_static_batch_size = 64\n", - "hyperparams.cuda_graphs_static_max_seq_len=128\n", - "hyperparams.cuda_graphs_static_max_context_len=256\n", + "hyperparams.cuda_graphs_static_max_seq_len = 128\n", + "hyperparams.cuda_graphs_static_max_context_len = 256\n", "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", - "benchmark_generation(model, 64, 256, 128)" + "benchmark_generation(model, batch_size=64, context_len=256, max_new_tokens=128)" ] }, { @@ -446,25 +452,6 @@ "| THD attention + FP8 + Cuda Graphs with TE | - | - | " ] }, - { - "cell_type": "markdown", - "id": "a2bd87e6", - "metadata": {}, - "source": [ - "We can also see how use of graphs reduced CPU overhead. Here are two screenshots from the profiler:\n", - "\n", - "
\n", - "\"Logo\n", - "
\n", - "Generation without CUDA Graphs\n", - "
\n", - "\n", - "\"Logo\n", - "
\n", - "Generation with CUDA Graphs\n", - "
" - ] - }, { "cell_type": "markdown", "id": "e6b171a0", @@ -496,18 +483,18 @@ "metadata": {}, "outputs": [], "source": [ - "# Import necessary packages and methods\n", - "import transformer_engine.pytorch as te\n", - "from utils import *\n", - "import torch\n", + "#Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", "\n", + "from utils import *\n", + "import transformer_engine.pytorch as te\n", "\n", "hyperparams.model_name = \"../../../../gemma-weights\"\n", "hyperparams.fuse_qkv_params = True\n", "hyperparams.qkv_format = \"thd\"\n", "\n", "model = init_te_gemma_model(hyperparams).cuda()\n", - "model = model.to(torch.bfloat16)\n", "\n", "# Calibration\n", "with te.fp8_autocast(enabled=False, calibrating=True):\n", @@ -518,14 +505,14 @@ "with te.fp8_autocast(enabled=True):\n", " run_forward_pass(model, 10)\n", "\n", - "print(\"Casting weights...\")\n", - "model_fp8 = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda()\n", - "model_fp8.load_state_dict(model.state_dict())\n", - "print(\"Weights casted\")\n", "\n", - "print(\"Saving model...\")\n", - "torch.save(model_fp8.state_dict(), 'model_fp8_state_dict.pth')\n", - "print(\"Model saved!\")" + "model_fp8 = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda()\n", + "# model_fp8 contains only fp8 copies of the weights,\n", + "# model contains bf16 copies and scaling factors. \n", + "# Both of these are copied into fp8 parameters of model_fp8.\n", + "model_fp8.load_state_dict(model.state_dict()) \n", + "# saving only fp8 weights and fp8 metadata like scaling factors\n", + "torch.save(model_fp8.state_dict(), 'model_fp8_state_dict.pth') " ] }, { @@ -554,29 +541,23 @@ "hyperparams.fuse_qkv_params = True\n", "hyperparams.qkv_format = \"thd\"\n", "\n", - "hyperparams.generation_cuda_graphs = True\n", - "hyperparams.cuda_graphs_static_batch_size = 64\n", - "hyperparams.cuda_graphs_static_max_context_len=6\n", - "hyperparams.cuda_graphs_static_max_context_len=100\n", - "\n", - "hyperparams.fp = True\n", + "hyperparams.fp8 = True\n", + "# We load calibrated fp8 weights directly from the file.\n", "hyperparams.fp8_model_weights_filename = \"model_fp8_state_dict.pth\"\n", - "hyperparams.fp8_model_init = False\n", "\n", - "hyperparams.generation_cuda_graphs = True\n", "hyperparams.cuda_graphs_static_batch_size = 64\n", "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", "hyperparams.cuda_graphs_static_max_context_len=128\n", "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", "print_sample_of_generated_texts(model)\n", - "benchmark_generation(model, 64, 128, 1024)\n", + "benchmark_generation(model, batch_size=64, context_len=128, max_new_tokens=1024)\n", "\n", "hyperparams.cuda_graphs_static_max_seq_len = 128\n", "hyperparams.cuda_graphs_static_max_context_len=256\n", "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", - "benchmark_generation(model, 64, 256, 128)" + "benchmark_generation(model, batch_size=64, context_len=256, max_new_tokens=128)" ] }, { @@ -644,20 +625,21 @@ "hyperparams.cuda_graphs_static_max_context_len=128\n", "hyperparams.cuda_graphs_static_max_context_len=1024\n", "\n", - "hyperparams.fp = True\n", + "hyperparams.fp8 = True\n", "hyperparams.fp8_model_weights_filename = \"model_fp8_state_dict.pth\"\n", - "hyperparams.fp8_model_init = True\n", + "# It impacts the behaviour of the load_te_model() function in te_gemma_loading_weights.py file.\n", + "hyperparams.fp8_model_init = True \n", + "\n", "model = 
init_te_gemma_model(hyperparams).cuda()\n", "\n", "print_sample_of_generated_texts(model)\n", - "benchmark_generation(model, 64, 128, 1024)\n", + "benchmark_generation(model, batch_size=64, context_len=128, max_new_tokens=1024)\n", "\n", - "hyperparams.generation_cuda_graphs = True\n", - "hyperparams.cuda_graphs_static_batch_size = 64\n", + "hyperparams.cuda_graphs_static_max_seq_len = 128\n", "hyperparams.cuda_graphs_static_max_context_len=256\n", - "hyperparams.cuda_graphs_static_max_context_len=128\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", - "benchmark_generation(model, 64, 256, 128)" + "benchmark_generation(model, batch_size=64, context_len=256, max_new_tokens=128)" ] }, { diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index b316247640..2781633ed5 100644 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -231,8 +231,8 @@ def print_sample_of_generated_texts(model): print(text) print("=" * 100) -def benchmark_generation(model, tokenizer, context_length, max_new_tokens): - inputs = tokenizer(["a" * context_length] * context_length, return_tensors="pt", padding=True) +def benchmark_generation(model, tokenizer, batch_size, context_length, max_new_tokens): + inputs = tokenizer(["a" * context_length] * batch_size, return_tensors="pt", padding=True) start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) From 2c0ea1fcb4e9524be4743f348e6cc3dd8ad6ef88 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 16 May 2024 12:16:00 -0700 Subject: [PATCH 125/244] Small code refactors Signed-off-by: Pawel Gadzinski --- .../te_gemma/tutorial_generation_gemma_with_te.ipynb | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index dd535bf738..952c6397a8 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -167,7 +167,7 @@ "source": [ "# Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", - "#restart_jupyter_notebook()\n", + "restart_jupyter_notebook()\n", "\n", "# Import necessary packages and methods\n", "from utils import *\n", @@ -318,6 +318,10 @@ } ], "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", "# Import necessary packages and methods\n", "from utils import *\n", "\n", From 0f16bf82c7cf095be8c390223de9a0e4b647b503 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 16 May 2024 13:33:44 -0700 Subject: [PATCH 126/244] Small code refactors Signed-off-by: Pawel Gadzinski --- .../tutorial_generation_gemma_with_te.ipynb | 114 +----------------- 1 file changed, 4 insertions(+), 110 deletions(-) diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index 952c6397a8..5abf8f07c2 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -127,43 +127,10 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "7477e469", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: 
IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n", - "Gemma's activation function should be approximate GeLU and not exact GeLU.\n", - "Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu_pytorch_tanh`, edit the `model.config` to set `hidden_activation=gelu_pytorch_tanh` instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.\n", - "Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 2.59it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Another string ... \n", - "\n", - "I have a new 2019 15\" MBP with 2.6 GHz i7, 16GB RAM, 512GB SSD.\n", - "\n", - "I have a 2019 27\" iMac with 3.6 GHz i5, 16GB RAM, 1TB SSD.\n", - "\n", - "I have a 2019 13\" MBP with 1.4 GHz i5, 8GB RAM\n", - "====================================================================================================\n", - "I love the idea of a “\n", - "====================================================================================================\n", - "Benchmark with context_length=128 and max_new_tokens=1024 took 8616.48 ms.\n", - "Peak GPU memoty usage: 30.96 GB\n", - "Benchmark with context_length=256 and max_new_tokens=128 took 8430.52 ms.\n", - "Peak GPU memoty usage: 31.83 GB\n" - ] - } - ], + "outputs": [], "source": [ "# Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", @@ -244,79 +211,10 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "4fc5e1cd", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in TEGemmaForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", - "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GemmaModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. 
Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use 
_jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n", - "[W init.cpp:767] Warning: nvfuser is no longer supported in torch script, use _jit_set_nvfuser_enabled is deprecated and a no-op (function operator())\n" - ] - }, - { - "ename": "AssertionError", - "evalue": "Data types for parameters must match when outside of autocasted region. Found input dtype: torch.float32 and 'layer_norm_weight' dtype: torch.bfloat16", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[4], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m# Init the model and accelerator wrapper\u001b[39;00m\n\u001b[1;32m 13\u001b[0m model \u001b[38;5;241m=\u001b[39m init_te_gemma_model(hyperparams)\u001b[38;5;241m.\u001b[39mto(torch\u001b[38;5;241m.\u001b[39mbfloat16)\u001b[38;5;241m.\u001b[39mcuda()\n\u001b[0;32m---> 15\u001b[0m \u001b[43mprint_sample_of_generated_texts\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m benchmark_generation(model, \u001b[38;5;241m64\u001b[39m, \u001b[38;5;241m128\u001b[39m, \u001b[38;5;241m1024\u001b[39m)\n\u001b[1;32m 17\u001b[0m benchmark_generation(model, \u001b[38;5;241m64\u001b[39m, \u001b[38;5;241m256\u001b[39m, \u001b[38;5;241m128\u001b[39m)\n", - "File \u001b[0;32m/perfhome/tutorial/TransformerEngine/docs/examples/te_gemma/utils.py:228\u001b[0m, in \u001b[0;36mprint_sample_of_generated_texts\u001b[0;34m(model)\u001b[0m\n\u001b[1;32m 225\u001b[0m inputs[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m inputs[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mcuda()\n\u001b[1;32m 226\u001b[0m inputs[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m inputs[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mcuda()\n\u001b[0;32m--> 228\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_new_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 229\u001b[0m generated_texts \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mbatch_decode(outputs, 
skip_special_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m text \u001b[38;5;129;01min\u001b[39;00m generated_texts[:\u001b[38;5;241m2\u001b[39m]:\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/perfhome/tutorial/TransformerEngine/docs/examples/te_gemma/te_gemma.py:257\u001b[0m, in \u001b[0;36mTEGemmaForCausalLM.generate\u001b[0;34m(self, input_ids, pad_token_id, max_new_tokens, *args, **kwargs)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[38;5;66;03m# Context phase\u001b[39;00m\n\u001b[1;32m 256\u001b[0m TEGemmaForCausalLM\u001b[38;5;241m.\u001b[39m_padding_to_end(input_ids, lengths)\n\u001b[0;32m--> 257\u001b[0m hidden_states, next_tokens \u001b[38;5;241m=\u001b[39m \u001b[43mTEGemmaForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_generate_context_phase\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 258\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 259\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 260\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\n\u001b[1;32m 261\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 263\u001b[0m \u001b[38;5;66;03m# Generation phase.\u001b[39;00m\n\u001b[1;32m 264\u001b[0m inference_params\u001b[38;5;241m.\u001b[39mthd_setup_before_new_input(next_tokens\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m1\u001b[39m))\n", - "File \u001b[0;32m/perfhome/tutorial/TransformerEngine/docs/examples/te_gemma/te_gemma.py:218\u001b[0m, in \u001b[0;36mTEGemmaForCausalLM._generate_context_phase\u001b[0;34m(self, input_ids, inference_params)\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;66;03m#self._model_context_phase = self.record_graph(self._model_context_phase, hidden_states)\u001b[39;00m\n\u001b[1;32m 217\u001b[0m hidden_states\u001b[38;5;241m.\u001b[39mdata[:] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39membed_tokens(input_ids)\n\u001b[0;32m--> 218\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_model_context_phase\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 220\u001b[0m \u001b[38;5;66;03m# We choose logits coresponding with last token in each sequence,\u001b[39;00m\n\u001b[1;32m 221\u001b[0m \u001b[38;5;66;03m# which have various lengths - they are stored in (inference_params.incoming_seq_len - 1) Tensor.\u001b[39;00m\n\u001b[1;32m 222\u001b[0m logits 
\u001b[38;5;241m=\u001b[39m logits[torch\u001b[38;5;241m.\u001b[39marange(logits\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m0\u001b[39m)), inference_params\u001b[38;5;241m.\u001b[39mincoming_seq_len \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m, :]\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1523\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m/perfhome/tutorial/TransformerEngine/docs/examples/te_gemma/te_gemma.py:80\u001b[0m, in \u001b[0;36mStaticGemmaModel.forward\u001b[0;34m(self, hidden_states)\u001b[0m\n\u001b[1;32m 78\u001b[0m hidden_states\u001b[38;5;241m.\u001b[39mdata[:] \u001b[38;5;241m=\u001b[39m hidden_states\u001b[38;5;241m.\u001b[39mdata[:] \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnormalizer \u001b[38;5;66;03m# static operation - for CUDA graphs\u001b[39;00m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m decoder_layer \u001b[38;5;129;01min\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[0;32m---> 80\u001b[0m hidden_states\u001b[38;5;241m.\u001b[39mdata[:] \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 81\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 82\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 83\u001b[0m \u001b[43m \u001b[49m\u001b[43mself_attn_mask_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 84\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minference_params\u001b[49m\n\u001b[1;32m 85\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# static copy - for CUDA graphs\u001b[39;00m\n\u001b[1;32m 87\u001b[0m hidden_states\u001b[38;5;241m.\u001b[39mcopy_(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mnorm(hidden_states)) \u001b[38;5;66;03m# static copy - for CUDA graphs\u001b[39;00m\n\u001b[1;32m 88\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1511\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1509\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1510\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1511\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1520\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1515\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1518\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1519\u001b[0m 
-   [... remainder of the removed notebook error output elided: an ANSI-colored traceback through TransformerLayer ->
-    MultiheadAttention -> LayerNormLinear -> TransformerEngineBaseModule.prepare_forward / set_activation_dtype, ending with
-    AssertionError: Data types for parameters must match when outside of autocasted region. Found input dtype: torch.float32 and 'layer_norm_weight' dtype: torch.bfloat16 ...]
-    ]
-   }
-  ],
+   "outputs": [],
   "source": [
    "# Restart the notebook (to flush the GPU memory)\n",
    "from utils import restart_jupyter_notebook\n",
@@ -591,10 +489,6 @@
   "id": "2dd0cba9",
   "metadata": {},
   "source": [
\n", - "\"\"
\n", - "Fig. High precision vs FP8 vs FP8 with fp8_model_init() in TransformerEngine\n", - "
\n", "\n", "As we have seen above, generation in FP8 precision results results in considerable speedup. Neverthless, memory usage is no different than without FP8. The reason of that is that TransformerEngine stores parameters in higher precision and only casts them to FP8. It is also true with the optimizer state. It is needed to maintain accucacy during training. However, we can get rid of high precision weights when doing inference. \n", "\n", From cd2566fd8b8bb89ff15d845ee839926ed86f7f10 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 16 May 2024 14:52:03 -0700 Subject: [PATCH 127/244] Small code refactors Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index 2781633ed5..d00f108d20 100644 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -26,6 +26,8 @@ def __init__(self): self.mixed_precision = "bf16" self.model_name = None + self.fp8 = False + # Weights in fp8 self.fp8_model_weights_filename = None self.fp8_model_init = False From 3501548c42da7271db0c95e278ac859a754d706b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 16 May 2024 17:06:09 -0700 Subject: [PATCH 128/244] Cosmetic change Signed-off-by: Pawel Gadzinski --- .../te_gemma/media/thd_dimensions_2.png | Bin 25116 -> 15653 bytes .../tutorial_generation_gemma_with_te.ipynb | 13 +++---------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/docs/examples/te_gemma/media/thd_dimensions_2.png b/docs/examples/te_gemma/media/thd_dimensions_2.png index f8001c669e2f980f4f03667042f7d5bd249ebda3..223859b741e5c727a397678671e19f7302a1da0f 100644 GIT binary patch literal 15653 zcmeHu2{4xb-!4)rA}yAPY)Ol?veP0|qL3~78p)pZK~gB9s3cpZNJz+*-IFa_B}>_} zCHoe#o$K!RpLgcWnKSdAdCz;6nfEvIo2TM_?)&@wewOQVUDtj4ysFZc&AT>JP*7|+ zt*oFkrn*g0RecbQmMB1b{7hvKxtaZUG_u}%*c_JP>7B{%!<_t_ImmNj~|g&%2i zDok_UZYWpHIaw_c)j`K~FQO(U%4>s@&*C{%J5gO_<4ng*eiU5yuKT&K_rDr4MU}&M zDf-Zr?BCwj*HQ{bR(s?B=#{lmb-df>YdG|ig}Hf-l=8lRbyh=7);lqd2 zUP0myykB>?6MvumcYKWEl73q{^3$ZGz|~bB84ia0{Cr+MKJOpb7FK*by}VdiSW1Ux z?#ITSGd2CeZx|TE7GNSQEPPB&je#R9aCX+^OH&is;8D=cAWwJm7Fuemug%R7(lv_4IZ{L_`!WjTgAMxZs~8-L`Gt zemf-O#p~C1Dl01;B6+yEUX=1=#WmZ!eE06b!vop@1L`dgf`T@FtgY33?~*KN5p^9u zd;9httx(~;y1Kf*mn4?*#P4DXULRbsZCv!NURu_3dP#B1S`R+%XD^S}nwYa}z+FSa z!unjKZU+UG_Du=P$%WPqaEpqHE-*|jEJ*m2a3(d~?)3SSd)2|=-m_<1`hv`xHf@q% z553nSFDIw!>nrbf?D~TTJLt%zy$)L$85#AfCg#P2g|D}c@NscbCN+6{XJ94^v#_y| zEwdPR?z~lDYGP_iRnWQWx8(Ku*q&qf2E+8og#~wgLG8D@$-ZKIL{U{&k(afd-C$z09rynIcRqi=@=lkZDaEbzmPn|R^PfyQRZ{7^HWtfG8hK4_SRQWQ} z@QjVk{!6)b13y0}NU+o22@Xzv^-7|TFB)rkb~2EKm9=cx<;VO|a$4HJhuhmVwY1dr z_328d`l=*J3&QvB-)|dpegFRbzJmw*Z_o$g zV0!uT#@zOEZf@et`VN!>0|T6rdw$Q&WtdgbKZ%ZZ%-64rmycgh%e0k^ZRvi_ykm1q z%k=Md`=++GYQ9MQ6RN7J!Y(=4+4mzN$jgf}{%$cIG8|IU(gK2lmc8Y-6ciPy_{ryy zYAY)%I1phOnK~TE%ur)eULFe)2w`;O$dSW`4+|eYtbpUb(vr$0FCTAE;OdBdwmr(W zxVU)G;QfP;kPwDc>8%^j8yTtLSp_2xEX+4Fs17#9H-___rMPTwpY;0mW{RYy(~%6uOJxKTU1zEaxFr1g370=MhOs$uj9;obin%QcjN;7IL3QNy4drBbiL* z<>htgy~V7qrKMp2U%zUieu&Wd-(?g3f*sZ_`E$|9$%&#Z%ZkBwbtXQ4 zG_zJG--(Zc_zMxdI0^m(2WU&W!>amv5o60^_P%BVA0HB1hA(rS=&B#@EFeM@k55hI zC)U#W_3MY^8lP4&0>^eVI-es%#`0@Gk=coq_h}CZeZbHbM4yW2fR&+8|VVpxz zR&iy*bZ_#0Y~HqQn}p}=zDBvFgW}yJvD!!>l!pr%8f2tj#y8tk{rvE@Hl2e^w9(@B ziHV5;*G5|IM@G)y?@~%{tqs`1CgHOzfuPvL5g|4_JpAmI!GSob4YNxPUYE z@sSIp-~D4C(#W!_(Bnl?Qv6(@l;aO-X}57+Y}t`3Er>QErF+W!C1O<{9n?K>=@Oez z;UpEZK;LI^z}(#Y)A#Qm@x&jht2f3k)>)y*7q0wnPc zyS(1z&XZ 
zTmD?#_CkKJEMa`SuV0<{{xKX_G_9fQ`K`7HHg1uwbp!3m*fJYr_1jD@vT^)@UrMJV z+=d#Y-*6H>qKr2SKD)iPmId`$h&6unPo&MSu4MCd!@I720U%t#n(5LW5`y6t!68PS za!F13a_qQ1=~}m$$vA|0Nq8BquRqiZxlrYmULj|2G;iU<_nSh{RFJ3odB-pW(IxsiH%ByKgh!gxVCr$|9?ZbC5#;p9z_(ycVxh}D?ip`fB$$Nj&wS=yp5nxJ259(oFJPFKwYOr8as_LJPpl>_xzkfMi?&yPm>Am zEERP%+H;)Ipk%6RAK1x%NM?P$@(SAfMbDiOAZ&V=3s>LzTI_`_Hqy2nRbT(Lf-pJ7 z-5CsXVJzo85KF1mT`3G+$Mpxfx~eA?Rt&+%S}=F-DSu{cOso_=Nl;@qta=KF zcXoPqcsQOOqOcBfhi~K$-zn>(nNnTcudb257cAiNYP~A*ONV@y!ZH$s9Z+=6(c;o1 z0&_@g#T=1Sevb3-bL-V0Ebx81F}mbUS?D0s2?cNsXxE)C0HZj4R^jg9lg2t02$;Cz z){=oR5s=^JXZX%R&qqfCaFtT8z)SrxCYe8AcC#v^#?o2%#nmT zQRa$L5`XKsys68M85ERozp4tVbJrUy8{aw}Gk*vkda@zvR7l=JTb*5aV1ExPvkbzG zw&aHS^Y-#rq>y+P1{N`O`Swq>%qb?-r_Y1zj;pownmo(OdS8!Ez*BOaSUM&`>#>9T zaRENQ-sJj{{Ls-t`8Q!vXA44yv`P~m$pQ6)Mx7`VY`sV*t9nUrt(H@=Rs`+8q2W}keoe7hIVa(k*O2a(mCd+F` zj960%O_DmTPN=?rz^a@?nh0v|t*m}IRbGuB^pkP4plFoDQCL+u(?{sP+SFoJcq8{@ zS$HYMqTUa>kUDOT&#kc0%0yrCcN`rc+EjyPI9~Pdj_=XQOlEot*Y@=tBbytBu}G1( z3h`dT2K;cV4?>%;d?FmS?X3Shfl|``-WwQFI)dQO3gx{8)M|+XvZbl1gpbC+6u*VI zPLmA^cyw$m2@!7rY68$86Lb(8%a+hjUug?N$d$3T6z;DwB_iam{nZg_Pq$1J3U6*a zQF0zro8Iy<bT@ z)D$#m-$oDe(ltX(#dEvAWHh$wc|O_F>vL?s=up|Zv($xwV+Q-`A3~3P*M?J$)2@Fy zVS|x>(B`(9kUBn%2!M2)0yv=Vmf7LRp0oYjTKxV7kS-%m4DIR_{e@` zLWuvy@7-|jc|7nPA{DYhqZ`^rfK5@m;`9?9s(85@$esmv->h3sKyoo;LM^DlfmPI^6^fmdypyDz?@^ny_qyFt0&XMz-yIC&)wYB^F zz6mA4at#ujL0FrdOCiYRCljr3*2bT$-jWklkHuPYgm3xNZO5}u071n8KCflYQE!|7 z*m~ds;Ir$TUivC+YxiJ0qQ)+E1y#PVsNG)lA%6oOa&b}r2aD>h#NiOgjXU+BdA0wD zE;iSrvK6M{p&g+k87t$R#A6h?#&ej8Rft^_I&|M;z|4rbMvB^9kRpxJr$HMw4`~Ig zCw*{9E^u(I1vJrgCcTOP=;JSw(4bTU466BZgN%5^w>iZouD2X0$92avGXFQ#I}`9_dJw`6UZ{G$te zX1jOwc52>homBOWhPQVDob;c7KIP-EC;LcDG8jL?FvQ1fw z+#Wa-7J)>i;tV={y7A zI1vMLF+>nr31IUaG4rl)OX3Z~=m0iTrwC*lOyYE&-&7vEKP?Zg<2|B@+44yN99)8X zwvV*L`i^q#QQrVy0>OY|ZEHeHOG~3PhtB-M8!KSI`$>yQ^TmW^!zvLc znQW~e!eyj@A5qhu4=?p*z)GCd)LBvo1&6xQ^61Mj(nbKLUs6rrcXYLo>N}SXeo)s^S}a#KPxBjK zeiZKjMplVgzaEWm`J^_Y;-@!P0Phtkc5NdDLgb8)7v6*Oy>51_K%nR_JF@7{v%6L<^fmR?sPkd1uX0Y3@qaiRiX<}L(j zv;I{lB)Ie9jSD77;Xwvy<=*<4(hzWDi~~)k+sU{PNg1NSnL*&LA547p>Xo%@yi$}& zj{uM-qsVrqSjGOk8EAp^zmin0Eb&hj+tZkaG23mWjoKq&YFfrGvVN3$M4GZ%S`Oio z$i>C2&e*%@xaeZ>Opbuy&ofIqdqJW13|zR=8M{K(Pys0p8!nl}DC8uCi1z#2DCKO` tHgzr1J7CxURl|1V>;I<%@V=M99s3>u~fyncd0wzXFJ~>P7$n diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index 5abf8f07c2..9600f9cf5f 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -392,7 +392,7 @@ "from utils import *\n", "import transformer_engine.pytorch as te\n", "\n", - "hyperparams.model_name = \"../../../../gemma-weights\"\n", + "hyperparams.model_name = \"\"\n", "hyperparams.fuse_qkv_params = True\n", "hyperparams.qkv_format = \"thd\"\n", "\n", @@ -405,16 +405,9 @@ "\n", "# Compute scale_fwd with enabled fp8 autocast\n", "with te.fp8_autocast(enabled=True):\n", - " run_forward_pass(model, 10)\n", - "\n", + " run_forward_pass(model, hyperparams, 10)\n", "\n", - "model_fp8 = init_te_gemma_model(hyperparams, fp8_model_init=True).cuda()\n", - "# model_fp8 contains only fp8 copies of the weights,\n", - "# model contains bf16 copies and scaling factors. 
\n", - "# Both of these are copied into fp8 parameters of model_fp8.\n", - "model_fp8.load_state_dict(model.state_dict()) \n", - "# saving only fp8 weights and fp8 metadata like scaling factors\n", - "torch.save(model_fp8.state_dict(), 'model_fp8_state_dict.pth') " + "torch.save(model.state_dict(), 'model_calibrated_weights.pth') " ] }, { From d5ef40c230c7e2a943d6b9c73f3fbb7156ee2357 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 17 May 2024 16:30:05 -0700 Subject: [PATCH 129/244] te_gemma fix Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 85 ++++++++++++++++++------------ 1 file changed, 51 insertions(+), 34 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index e1c041d585..52e85cea10 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -51,7 +51,11 @@ def __init__(self, config : GemmaConfig, layer_idx : int, *args, **kwargs): self.te_rope_emb = RotaryPositionEmbedding(256)(max_seq_len=config.max_position_embeddings).cuda() def forward(self, *args, **kwargs): # We need to pass positional encoding. - return super().forward(*args, rotary_pos_emb=self.te_rope_emb, **kwargs) + # this args cannot be passed to TransformerLayer + keys_to_remove = ["position_ids", "past_key_value", "output_attentions", "use_cache", "cache_position"] + for key in keys_to_remove: + kwargs.pop(key, None) + return (super().forward(*args, rotary_pos_emb=self.te_rope_emb, **kwargs),) # We need to return tuple to be compatible with HF. class StaticGemmaModel(torch.nn.Module): """ @@ -82,7 +86,7 @@ def forward(self, hidden_states : torch.Tensor): attention_mask=None, self_attn_mask_type=self.mask, inference_params=self.inference_params - ) # static copy - for CUDA graphs + )[0] # static copy - for CUDA graphs hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs logits = self.lm_head(hidden_states) @@ -148,7 +152,6 @@ class is monkey-patched with `TEGemmaDecoderLayer` class before """ def __init__(self, config: GemmaConfig): - assert config.qkv_format == "thd", "Generation using other qkv_layouts than thd is not provided in this tutorial" with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): super().__init__(config) self.hidden_size = config.hidden_size @@ -158,6 +161,10 @@ def __init__(self, config: GemmaConfig): dtype=torch.float32, ) self._model_context_phase = StaticGemmaModel(self.model, torch.float32, 'padding_causal', self.lm_head) + + if self.config.fp8: + self.fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max") + @staticmethod def _padding_to_end(inputs, lengths): @@ -180,6 +187,9 @@ def _padding_to_end(inputs, lengths): new_input_ids[i,lengths[i]:] = inputs[i, 0:(max_seq_len-lengths[i])] inputs.copy_(new_input_ids) + def _next_64_multiply(self, x): + return ((x + 63) // 64) * 64 + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. 
def _create_hidden_states_buffer(self, input_ids : torch.Tensor): return torch.empty((input_ids.shape[0], input_ids.shape[1], self.hidden_size), device="cuda", dtype=torch.float32) @@ -215,6 +225,8 @@ def _generate_context_phase( #self._model_context_phase = self.record_graph(self._model_context_phase, hidden_states) hidden_states.data[:] = self.model.embed_tokens(input_ids) logits = self._model_context_phase(hidden_states) + #import pdb + #pdb.set_trace() # We choose logits coresponding with last token in each sequence, # which have various lengths - they are stored in (inference_params.incoming_seq_len - 1) Tensor. @@ -237,37 +249,41 @@ def generate( max_new_tokens: int = 0, *args, **kwargs ): - batch_size, max_input_sequence_len = input_ids.shape[0], self._get_max_input_seq_len(input_ids) - lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # [s] - input_ids = F.pad(input_ids, (max_input_sequence_len - input_ids.shape[1], 0), "constant", 0) - - # InferenceParams is a cache, where keys and values of previous tokens are stored. - # Moreover it stores length of both already generated and input sequences. - inference_params = self._create_inference_params( - max_batch_size=batch_size, - max_sequence_length=max_input_sequence_len + max_new_tokens - ) + + assert self.config.qkv_format == "thd", "Generation using other qkv_layouts than thd is not provided in this tutorial" + with te.pytorch.fp8_autocast(enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None), \ + autocast(dtype=torch.bfloat16, cache_enabled=False): + batch_size, max_input_sequence_len = input_ids.shape[0], self._get_max_input_seq_len(input_ids) + lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # [s] + input_ids = F.pad(input_ids, (max_input_sequence_len - input_ids.shape[1], 0), "constant", 0) + + # InferenceParams is a cache, where keys and values of previous tokens are stored. + # Moreover it stores length of both already generated and input sequences. + inference_params = self._create_inference_params( + max_batch_size=batch_size, + max_sequence_length=self._next_64_multiply(max_input_sequence_len + max_new_tokens) + ) - self._model_context_phase.set_inference_params(inference_params) - self._model_generation_phase.set_inference_params(inference_params) + self._model_context_phase.set_inference_params(inference_params) + self._model_generation_phase.set_inference_params(inference_params) - # Context phase - TEGemmaForCausalLM._padding_to_end(input_ids, lengths) - hidden_states, next_tokens = TEGemmaForCausalLM._generate_context_phase( - self, - input_ids, - inference_params - ) + # Context phase + TEGemmaForCausalLM._padding_to_end(input_ids, lengths) + hidden_states, next_tokens = TEGemmaForCausalLM._generate_context_phase( + self, + input_ids, + inference_params + ) - # Generation phase. - inference_params.thd_setup_before_new_input(next_tokens.unsqueeze(1)) - output_tokens = [next_tokens] - for _ in range(max_new_tokens): - next_tokens = self._model_generation_phase(hidden_states) - output_tokens.append(next_tokens.clone()) + # Generation phase. 
+ inference_params.thd_setup_before_new_input(next_tokens.unsqueeze(1)) + output_tokens = [next_tokens] + for i in range(max_new_tokens): + next_tokens = self._model_generation_phase(hidden_states) + output_tokens.append(next_tokens.clone()) - result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) - return result + result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) + return result class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM): """ @@ -294,13 +310,14 @@ def __init__(self, config : GemmaConfig): # with their recorded version. After invocation of each of them, # captured graph will be replayed with minimal usage of CPU, # what will lead to huge speedup. + input_shape = (config.cuda_graphs_static_batch_size, config.cuda_graphs_static_max_context_len) - self.inference_params.thd_setup_before_new_input(torch.ones(input_shape), pad_token_id=0, reset=True) + self.inference_params.thd_setup_before_new_input(torch.ones(input_shape), reset=True) self._model_context_phase = self.record_graph(self._model_context_phase, self.hidden_states_buffer) # CUDA Graphs recording input_shape = torch.ones((config.cuda_graphs_static_batch_size, 1)) self.inference_params.thd_setup_before_new_input(input_shape, reset=True) - self._model_generation_phase = self.record_graph(self._model_generation_phase, self.generation_buffer) # CUDA Graphs recording + #self._model_generation_phase = self.record_graph(self._model_generation_phase, self.generation_buffer) # CUDA Graphs recording """ Functions _create_hidden_states_buffer and _create_inference_params from base class are overriden @@ -326,7 +343,7 @@ def record_graph(self, function, input_tensor): graphed_function = te.pytorch.make_graphed_callables( function, (input_tensor,), - fp8_enabled=True, + fp8_enabled=self.config.fp8, fp8_recipe=fp8_recipe, allow_unused_input=True, num_warmup_iters=3 @@ -345,4 +362,4 @@ def generate( assert self.config.cuda_graphs_static_max_context_len >= input_ids.shape[1], \ f"Input_ids shape {input_ids.shape} is greater than max_seq_len={self.max_seq_len} of recorded graphs" - return super().generate(input_ids, *args, **kwargs) \ No newline at end of file + return super().generate(input_ids, *args, **kwargs) From c94b36b2b6557138985b2e2827b0a269971c9b21 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 17 May 2024 17:01:14 -0700 Subject: [PATCH 130/244] bug fix Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/utils.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index d00f108d20..c9a9571b2d 100644 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -153,7 +153,6 @@ def run_iters(num_iters): with accelerator.accumulate(model): outputs = model(**batch) loss = outputs.loss - total_loss += loss.detach().float() accelerator.backward(loss) optimizer.step() lr_scheduler.step() @@ -200,21 +199,25 @@ def restart_jupyter_notebook(): warnings.simplefilter("ignore") torch.set_warn_always(False) - +@torch.no_grad() def run_forward_pass(model, hyperparams, num_iters): """ It runs num_iters forward passes with sample data. 
""" + accelerator = Accelerator( + log_with="wandb", + gradient_accumulation_steps=hyperparams.gradient_accumulation_steps, + mixed_precision="no" + ) + train_dataloader = get_dataloaders(accelerator, hyperparams) + model.train() train_dataloader = enumerate(train_dataloader) for _ in range(num_iters): _, batch = next(train_dataloader) batch["input_ids"] = batch["input_ids"].cuda() - model.generate( - **batch, - max_new_tokens=10 - ) + model(batch["input_ids"]) """ Benchmarking and example generation functions. @@ -224,6 +227,13 @@ def print_sample_of_generated_texts(model): tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) inputs = tokenizer(["Another string ... ", "I "] * 32, return_tensors="pt", padding=True) + + max_length = inputs['input_ids'].size(1) + new_length = ((max_length + 63) // 64) * 64 + inputs['input_ids'] = torch.nn.functional.pad(inputs['input_ids'], (new_length - max_length, 0), value=tokenizer.pad_token_id) + inputs['attention_mask'] = torch.nn.functional.pad(inputs['attention_mask'], (new_length - max_length, 0), value=0) + + inputs['input_ids'] = inputs['input_ids'].cuda() inputs['attention_mask'] = inputs['attention_mask'].cuda() From 1a7c0d359eada2b3324f2219bbec3722d15a2cb2 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 17 May 2024 17:09:43 -0700 Subject: [PATCH 131/244] fused=True Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index c9a9571b2d..5b4ba306ec 100644 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -130,7 +130,7 @@ def wrap_with_accelerator(model, hyperparams): train_dataloader = get_dataloaders(accelerator, hyperparams) # Wrap model, optimizer/scheduler, dataloaders in accelerate - optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=False) + optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=True) lr_scheduler = get_linear_schedule_with_warmup( optimizer=optimizer, num_warmup_steps=100, From afbaa3fd0cd727593a67c36097d5bb1fa66db566 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 17 May 2024 17:25:37 -0700 Subject: [PATCH 132/244] fused=True Signed-off-by: Pawel Gadzinski --- .../examples/te_gemma/te_gemma_loading_weights.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma_loading_weights.py b/docs/examples/te_gemma/te_gemma_loading_weights.py index 772f58320d..2080cfaf7d 100644 --- a/docs/examples/te_gemma/te_gemma_loading_weights.py +++ b/docs/examples/te_gemma/te_gemma_loading_weights.py @@ -21,9 +21,14 @@ def _load_fp8_weights(vanilla_model, hyperparams): vanilla_model.load_state_dict( - torch.load(hyperparams.fp8_model_weights_filename) + torch.load(hyperparams.fp8_model_weights_filename), strict=False + # strict = false, because some parameters have + # multiple pointers to the same weight + # vanilla_model._model_context_phase.model + # and vanilla_model._model_generation_phase.model ) + def _load_standard_weights(vanilla_model, config): archive_file = os.path.join(config.model_name, "model.safetensors.index.json") resolved_archive_file, _ = get_checkpoint_shard_files(config.model_name, archive_file) @@ -31,6 +36,7 @@ def _load_standard_weights(vanilla_model, config): for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file) total_dict = total_dict | state_dict + 
replace_params(total_dict, vanilla_model.state_dict(), config, qkv_fused_and_interleaved=config.fuse_qkv_params) _load_state_dict_into_model(vanilla_model, total_dict, start_prefix="") # Copy parameters like embedding. @@ -45,10 +51,13 @@ def load_te_model(cls, config): Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 """ + config.use_cache = False # To make TransformerLayer compatible with GemmaModel with fp8_model_init(config.fp8_model_init): # there we need only to create model - vanilla_model = cls(config) - if config.fp8_model_init: + vanilla_model = cls(config).to(torch.bfloat16).cuda() + + # and now we copy the weights into it + if config.fp8_model_weights_filename is not None: if config.fp8_model_weights_filename is not None: _load_fp8_weights(vanilla_model, config) else: From f1537202d49f4885e5d12035dea284cc8d4bfee0 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 20 May 2024 14:18:31 -0700 Subject: [PATCH 133/244] new rope kernel (not working) Signed-off-by: Pawel Gadzinski --- .../common/fused_rope/fused_rope.cu | 38 ++++++++++++------- .../include/transformer_engine/fused_rope.h | 9 +++-- transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/apply_rope.cu | 15 ++++---- 4 files changed, 40 insertions(+), 24 deletions(-) diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index 14f76175dc..f7d30174ab 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -94,12 +94,14 @@ __device__ void fused_rope_block_backward( template __global__ void fused_rope_forward_kernel( - const scalar_t *src, const float *freqs, scalar_t *dst, const int h, + const scalar_t *src, const float *freqs, const int *begins, + scalar_t *dst, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; - int offset_block = s_id * stride_s + b_id * stride_b; + int s_begin = 0; + int offset_block = (s_id + s_begin) * stride_s + b_id * stride_b; int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); @@ -107,12 +109,14 @@ __global__ void fused_rope_forward_kernel( template __global__ void fused_rope_backward_kernel( - const scalar_t *src, const float *freqs, scalar_t *dst, const int h, + const scalar_t *src, const float *freqs, const int *begins, + scalar_t *dst, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; - int offset_block = s_id * stride_s + b_id * stride_b; + int s_begin = begins[b_id]; + int offset_block = (s_id + s_begin) * stride_s + b_id * stride_b; int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); @@ -150,7 +154,8 @@ __global__ void fused_rope_thd_backward_kernel( template void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, - scalar_t *output, const int 
s, const int b, + const int *begins, scalar_t *output, + const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, @@ -162,14 +167,14 @@ void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_forward_kernel<<>>( - input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, + input, freqs, begins, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template -void fused_rope_backward_launcher(const scalar_t *output_grads, - const float *freqs, scalar_t *input_grads, +void fused_rope_backward_launcher(const scalar_t *output_grads, const float *freqs, + const int *begins, scalar_t *input_grads, const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, @@ -181,7 +186,7 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_backward_kernel<<>>( - output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, + output_grads, freqs, begins, input_grads, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -220,7 +225,7 @@ void fused_rope_thd_backward_launcher( NVTE_CHECK_CUDA(cudaGetLastError()); } -void fused_rope_forward(const Tensor &input, const Tensor &freqs, +void fused_rope_forward(const Tensor &input, const Tensor &freqs, const Tensor &begins, Tensor *output, const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, @@ -232,12 +237,13 @@ void fused_rope_forward(const Tensor &input, const Tensor &freqs, fused_rope_forward_launcher( reinterpret_cast(input.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(begins.data.dptr), reinterpret_cast(output->data.dptr), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream);); } -void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, +void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, const Tensor &begins, Tensor *input_grads, const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, @@ -250,6 +256,7 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, fused_rope_backward_launcher( reinterpret_cast(output_grads.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(begins.data.dptr), reinterpret_cast(input_grads->data.dptr), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream);); @@ -295,7 +302,8 @@ void fused_rope_thd_backward(const Tensor &output_grads, } // end namespace transformer_engine void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, - NVTETensor output, const int s, const int b, + const NVTETensor begins, NVTETensor output, + const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, @@ -306,13 +314,14 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, using namespace transformer_engine; fused_rope_forward(*reinterpret_cast(input), *reinterpret_cast(freqs), + 
*reinterpret_cast(begins), reinterpret_cast(output), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream); } -void nvte_fused_rope_backward(const NVTETensor output_grads, - const NVTETensor freqs, NVTETensor input_grads, +void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, + const NVTETensor begins, NVTETensor input_grads, const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, @@ -323,6 +332,7 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, using namespace transformer_engine; fused_rope_backward(*reinterpret_cast(output_grads), *reinterpret_cast(freqs), + *reinterpret_cast(begins), reinterpret_cast(input_grads), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream); diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index cb712aecff..ed7474f881 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -17,6 +17,7 @@ extern "C" { * * \param[in] input Input tensor for fused rope. * \param[in] freqs The freqs tensor. + * \param[in] begins The beginning offsets. * \param[out] output Output tensor. * \param[in] s Length of the s dimension of input. * \param[in] b Length of the b dimension of input. @@ -34,7 +35,8 @@ extern "C" { * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, - NVTETensor output, const int s, const int b, + const NVTETensor begins, NVTETensor output, + const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, const int stride_d, @@ -46,6 +48,7 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, * * \param[in] output_grads Incoming gradient tensor for backward. * \param[in] freqs The freqs tensor. + * \param[in] begins The beginning offsets. * \param[out] input_grads Input gradient tensor to calculate. * \param[in] s Length of the s dimension of output_grads. * \param[in] b Length of the b dimension of output_grads. @@ -62,8 +65,8 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, * \param[in] o_stride_d Stride of the d dimension of input_grads. * \param[in] stream CUDA stream used for the operation. 
*/ -void nvte_fused_rope_backward(const NVTETensor output_grads, - const NVTETensor freqs, NVTETensor input_grads, +void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, + const NVTETensor begins, NVTETensor input_grads, const int s, const int b, const int h, const int d, const int d2, const int stride_s, const int stride_b, const int stride_h, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 916908d3ef..31ef53106c 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -629,11 +629,13 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, + const at::Tensor &begins, const bool transpose_output_memory ); at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const at::Tensor &begins, const bool transpose_output_memory ); diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu index 455d152fe8..f54597ff1d 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu @@ -6,8 +6,8 @@ #include "extensions.h" -at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, - const bool transpose_output_memory) { +at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, + const at::Tensor &begins, const bool transpose_output_memory) { using namespace transformer_engine; TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); @@ -55,9 +55,10 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, auto input_cu = makeTransformerEngineTensor(input); auto freqs_cu = makeTransformerEngineTensor(freqs); + auto begins_cu = makeTransformerEngineTensor(begins); auto output_cu = makeTransformerEngineTensor(output); - nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, + nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), begins_cu.data(), output_cu.data(), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); @@ -65,9 +66,8 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, return output; } -at::Tensor fused_rope_backward(const at::Tensor &output_grads, - const at::Tensor &freqs, - const bool transpose_output_memory) { +at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const at::Tensor &begins, const bool transpose_output_memory) { using namespace transformer_engine; TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); @@ -114,10 +114,11 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, auto output_grads_cu = makeTransformerEngineTensor(output_grads); auto freqs_cu = makeTransformerEngineTensor(freqs); + auto begins_cu = makeTransformerEngineTensor(begins); auto input_grads_cu = makeTransformerEngineTensor(input_grads); nvte_fused_rope_backward( - output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h, + output_grads_cu.data(), freqs_cu.data(), begins_cu.data(), input_grads_cu.data(), s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, 
o_stride_d, at::cuda::getCurrentCUDAStream()); From 8f572e3bdec87d88a0dec69c844ce61cd4f24bf0 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 21 May 2024 11:21:01 -0700 Subject: [PATCH 134/244] merge with THD branch Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/utils.py | 11 +- .../common/fused_attn/fused_attn.cpp | 42 +- .../fused_attn_f16_arbitrary_seqlen.cu | 90 ++- .../fused_attn_f16_arbitrary_seqlen.h | 14 +- .../common/fused_rope/fused_rope.cu | 54 +- .../include/transformer_engine/fused_attn.h | 26 +- .../include/transformer_engine/fused_rope.h | 8 +- transformer_engine/pytorch/attention.py | 600 +++++++++++--- .../pytorch/cpp_extensions/fused_attn.py | 79 +- transformer_engine/pytorch/csrc/extensions.h | 13 +- .../pytorch/csrc/extensions/apply_rope.cu | 13 +- .../pytorch/csrc/extensions/attention.cu | 743 ++++++++++++++++-- .../pytorch/csrc/extensions/pybind.cpp | 3 - 13 files changed, 1402 insertions(+), 294 deletions(-) diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index 5b4ba306ec..7fe4ba3b5a 100644 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -42,7 +42,7 @@ def __init__(self): self.dataset_name = "timdettmers/openassistant-guanaco" self.dataset_text_field = "text" self.learning_rate = 1.41e-5 - self.batch_size = 16 + self.batch_size = 8 self.max_seq_length = 256 self.gradient_accumulation_steps = 1 self.num_warmup_steps=5 @@ -229,7 +229,7 @@ def print_sample_of_generated_texts(model): max_length = inputs['input_ids'].size(1) - new_length = ((max_length + 63) // 64) * 64 + new_length = ((max_length + 63) // 64) * 128 inputs['input_ids'] = torch.nn.functional.pad(inputs['input_ids'], (new_length - max_length, 0), value=tokenizer.pad_token_id) inputs['attention_mask'] = torch.nn.functional.pad(inputs['attention_mask'], (new_length - max_length, 0), value=0) @@ -243,7 +243,10 @@ def print_sample_of_generated_texts(model): print(text) print("=" * 100) -def benchmark_generation(model, tokenizer, batch_size, context_length, max_new_tokens): + + +def benchmark_generation(model, batch_size, context_length, max_new_tokens): + tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) inputs = tokenizer(["a" * context_length] * batch_size, return_tensors="pt", padding=True) start = torch.cuda.Event(enable_timing=True) @@ -253,7 +256,7 @@ def benchmark_generation(model, tokenizer, batch_size, context_length, max_new_t model.generate( inputs['input_ids'].cuda(), - max_new_tokens = 256 + max_new_tokens=max_new_tokens ) torch.cuda.synchronize() end.record() diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 64b8b865d1..c56e385f97 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -135,17 +135,24 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) + && ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) + || (cudnn_runtime_version >= 90000)) && ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || (cudnn_runtime_version >= 8907)) + && ((head_dim <= 128 && head_dim % 8 == 0) + // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // d=256 only supported for forward + || (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 + && head_dim <= 256 && head_dim % 8 == 0)) 
&& ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || ((cudnn_runtime_version >= 8906) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || (bias_type == NVTE_Bias_Type::NVTE_ALIBI && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK - && sm_arch_ == 90) + && sm_arch_ >= 90) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS - && sm_arch_ == 90)))) + && sm_arch_ >= 90)))) && ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || ((cudnn_runtime_version >= 8906) && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK @@ -157,9 +164,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && ((qkv_format == NVTE_QKV_Format::NVTE_SBHD) - || (qkv_format == NVTE_QKV_Format::NVTE_THD) + || (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 + && qkv_format == NVTE_QKV_Format::NVTE_THD) || (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) { - flag_arb = true; + flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { @@ -209,6 +217,7 @@ void nvte_fused_attn_fwd_qkvpacked( const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, const NVTETensor rng_state, size_t max_seqlen, bool is_training, float attn_scale, float dropout, @@ -223,6 +232,7 @@ void nvte_fused_attn_fwd_qkvpacked( const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); + const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); const Tensor *input_rng_state = reinterpret_cast(rng_state); const Tensor *input_QKV = reinterpret_cast(QKV); const Tensor *input_Bias = reinterpret_cast(Bias); @@ -273,7 +283,7 @@ void nvte_fused_attn_fwd_qkvpacked( input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, - input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle); #else @@ -311,6 +321,7 @@ void nvte_fused_attn_bwd_qkvpacked( const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -324,6 +335,7 @@ void nvte_fused_attn_bwd_qkvpacked( const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); + const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); const Tensor *input_QKV = reinterpret_cast(QKV); const Tensor *input_O = reinterpret_cast(O); const Tensor *input_dO = reinterpret_cast(dO); @@ -386,7 +398,7 @@ void nvte_fused_attn_bwd_qkvpacked( output_S, output_dQKV, output_dBias, input_cu_seqlens, - input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle); #else @@ -430,6 +442,7 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, 
const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, @@ -444,6 +457,7 @@ void nvte_fused_attn_fwd_kvpacked( const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); + const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); const Tensor *input_rng_state = reinterpret_cast(rng_state); const Tensor *input_Q = reinterpret_cast(Q); const Tensor *input_KV = reinterpret_cast(KV); @@ -498,7 +512,7 @@ void nvte_fused_attn_fwd_kvpacked( input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle); #else @@ -539,6 +553,7 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -552,6 +567,7 @@ void nvte_fused_attn_bwd_kvpacked( const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); + const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); const Tensor *input_Q = reinterpret_cast(Q); const Tensor *input_KV = reinterpret_cast(KV); const Tensor *input_O = reinterpret_cast(O); @@ -619,7 +635,7 @@ void nvte_fused_attn_bwd_kvpacked( output_S, output_dQ, output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, - input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle); #else const char *err_msg = @@ -663,6 +679,7 @@ void nvte_fused_attn_fwd( const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, @@ -677,6 +694,7 @@ void nvte_fused_attn_fwd( const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); + const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); const Tensor *input_rng_state = reinterpret_cast(rng_state); const Tensor *input_Q = reinterpret_cast(Q); const Tensor *input_K = reinterpret_cast(K); @@ -723,7 +741,7 @@ void nvte_fused_attn_fwd( input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle); #else @@ -766,6 +784,7 @@ void nvte_fused_attn_bwd( const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -779,6 
+798,7 @@ void nvte_fused_attn_bwd( const Tensor *input_seq_offsets_q = reinterpret_cast(seq_offsets_q); const Tensor *input_seq_offsets_k = reinterpret_cast(seq_offsets_k); const Tensor *input_seq_offsets_v = reinterpret_cast(seq_offsets_v); + const Tensor *input_seq_offsets_o = reinterpret_cast(seq_offsets_o); const Tensor *input_Q = reinterpret_cast(Q); const Tensor *input_K = reinterpret_cast(K); const Tensor *input_V = reinterpret_cast(V); @@ -839,7 +859,7 @@ void nvte_fused_attn_bwd( output_S, output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, - input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, + input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o, input_rng_state, wkspace, stream, handle); #else const char *err_msg = @@ -868,4 +888,4 @@ void nvte_fused_attn_bwd( } else { NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } -} +} \ No newline at end of file diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index c40dd327ad..7a41f3cd14 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -57,7 +57,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void *devPtrSoftmaxStats, void *devPtrO, void* devPtrDropoutSeed, void* devPtrDropoutOffset, void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, - void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsK, void* devPtrSeqOffsetsV, + void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsK, + void* devPtrSeqOffsetsV, void* devPtrSeqOffsetsO, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -98,6 +99,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // offset_q std::shared_ptr, // offset_k std::shared_ptr, // offset_v + std::shared_ptr, // offset_o std::shared_ptr, // dropout_seed std::shared_ptr >; // dropout_offset @@ -122,7 +124,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr Q, K, V, attn_scale; std::shared_ptr bias, seq_q, seq_kv; - std::shared_ptr offset_q, offset_k, offset_v; + std::shared_ptr offset_q, offset_k, offset_v, offset_o; std::shared_ptr dropout_seed, dropout_offset; offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -140,6 +142,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_dim({b+1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::INT32)); + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b+1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); std::vector q_stride(4); std::vector k_stride(4); @@ -246,7 +253,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( O->set_output(true) .set_dim({b, h, s_q, d}) .set_stride(o_stride) - .set_ragged_offset(offset_q); + .set_ragged_offset(offset_o); } else { O->set_output(true) .set_dim({b, h, s_q, d}) @@ -268,8 +275,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto offset_tuple = is_ragged ? - std::make_tuple(offset_q, offset_k, offset_v) : - std::make_tuple(nullptr, nullptr, nullptr); + std::make_tuple(offset_q, offset_k, offset_v, offset_o) : + std::make_tuple(nullptr, nullptr, nullptr, nullptr); auto dropout_tuple = is_dropout ? 
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); @@ -288,7 +295,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( }; auto [mha_graph, Q, K, V, attn_scale, O, Stats, - bias, seq_q, seq_kv, offset_q, offset_k, offset_v, + bias, seq_q, seq_kv, offset_q, offset_k, offset_v, offset_o, dropout_seed, dropout_offset] = get_graph( sdpa_f16_fprop_cache, descriptor); @@ -335,6 +342,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( variant_pack[offset_q] = devPtrSeqOffsetsQ; variant_pack[offset_k] = devPtrSeqOffsetsK; variant_pack[offset_v] = devPtrSeqOffsetsV; + variant_pack[offset_o] = devPtrSeqOffsetsO; } if (is_dropout) { @@ -358,7 +366,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO, void* devPtrdBias, void* devPtrDropoutSeed, void* devPtrDropoutOffset, void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV, - void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsK, void* devPtrSeqOffsetsV, + void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsK, + void* devPtrSeqOffsetsV, void* devPtrSeqOffsetsO, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -403,6 +412,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr, // offset_q std::shared_ptr, // offset_k std::shared_ptr, // offset_v + std::shared_ptr, // offset_o std::shared_ptr, // dropout_seed std::shared_ptr >; // dropout_offset @@ -427,7 +437,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr q, k, v, o, dO, stats, attn_scale; std::shared_ptr bias, dBias, seq_q, seq_kv; - std::shared_ptr offset_q, offset_k, offset_v; + std::shared_ptr offset_q, offset_k, offset_v, offset_o; std::shared_ptr dropout_seed, dropout_offset; offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -445,6 +455,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_dim({b+1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::INT32)); + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b+1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -478,12 +493,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_name("O") .set_dim({b, h, s_q, d}) .set_stride(o_stride) - .set_ragged_offset(offset_q)); + .set_ragged_offset(offset_o)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") .set_dim({b, h, s_q, d}) .set_stride(o_stride) - .set_ragged_offset(offset_q)); + .set_ragged_offset(offset_o)); } else { q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") @@ -620,8 +635,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto offset_tuple = is_ragged ? - std::make_tuple(offset_q, offset_k, offset_v) : - std::make_tuple(nullptr, nullptr, nullptr); + std::make_tuple(offset_q, offset_k, offset_v, offset_o) : + std::make_tuple(nullptr, nullptr, nullptr, nullptr); auto dropout_tuple = is_dropout ? 
std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); @@ -640,7 +655,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( }; auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, - bias, dBias, seq_q, seq_kv, offset_q, offset_k, offset_v, + bias, dBias, seq_q, seq_kv, offset_q, offset_k, offset_v, offset_o, dropout_seed, dropout_offset] = get_graph( sdpa_f16_bprop_cache, descriptor); @@ -698,6 +713,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( variant_pack[offset_q] = devPtrSeqOffsetsQ; variant_pack[offset_k] = devPtrSeqOffsetsK; variant_pack[offset_v] = devPtrSeqOffsetsV; + variant_pack[offset_o] = devPtrSeqOffsetsO; } if (is_dropout) { @@ -717,8 +733,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, - const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *seq_offsets_q, + const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -749,6 +765,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; + void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; if (Aux_CTX_Tensors->size == 0) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -805,7 +822,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, - devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, + devPtrSeqOffsetsV, devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -834,7 +852,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, - const Tensor *rng_state, + const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -877,6 +895,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; + void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutOffset = reinterpret_cast( @@ -892,7 +911,8 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t num_attn_hea devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, - devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, + devPtrSeqOffsetsV, devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), 
workspace->data.dptr, &workspace_size, stream, handle); @@ -915,9 +935,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, + const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, + const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -950,6 +970,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; + void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; if (Aux_CTX_Tensors->size == 0) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -1006,7 +1027,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, + devPtrSeqOffsetsV, devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1038,8 +1060,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, - const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { + const Tensor *seq_offsets_o, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1081,6 +1103,7 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; + void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutOffset = reinterpret_cast( @@ -1096,7 +1119,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, + devPtrSeqOffsetsV, devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1123,7 +1147,7 @@ void fused_attn_arbitrary_seqlen_fwd( const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, - const Tensor *rng_state, + const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { 
using namespace transformer_engine; @@ -1147,6 +1171,7 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; + void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; if (Aux_CTX_Tensors->size == 0) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -1203,7 +1228,8 @@ void fused_attn_arbitrary_seqlen_fwd( devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, + devPtrSeqOffsetsV, devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1235,8 +1261,8 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, - const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { + const Tensor *seq_offsets_o, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; @@ -1266,6 +1292,7 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t void *devPtrSeqOffsetsQ = seq_offsets_q->data.dptr; void *devPtrSeqOffsetsK = seq_offsets_k->data.dptr; void *devPtrSeqOffsetsV = seq_offsets_v->data.dptr; + void *devPtrSeqOffsetsO = seq_offsets_o->data.dptr; void* devPtrDropoutSeed = rng_state->data.dptr; void* devPtrDropoutOffset = reinterpret_cast( @@ -1280,7 +1307,8 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsK, devPtrSeqOffsetsV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsK, + devPtrSeqOffsetsV, devPtrSeqOffsetsO, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1299,4 +1327,4 @@ void fused_attn_arbitrary_seqlen_bwd(size_t batch, size_t num_attn_heads, size_t } } } // namespace transformer_engine -#endif // CUDNN_VERSION >= 8900 +#endif // CUDNN_VERSION >= 8900 \ No newline at end of file diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index baedf8ca74..90e06e1cdc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -26,7 +26,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, - const Tensor *rng_state, Tensor *workspace, + const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( @@ -39,7 +39,7 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( Tensor *output_dQKV, Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, const Tensor *seq_offsets_v, - const Tensor 
*rng_state, Tensor *workspace, + const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_fwd_kvpacked( @@ -52,7 +52,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, - const Tensor *seq_offsets_v, const Tensor *rng_state, + const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( @@ -65,7 +65,7 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, - const Tensor *seq_offsets_v, const Tensor *rng_state, + const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_fwd( @@ -79,7 +79,7 @@ void fused_attn_arbitrary_seqlen_fwd( Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, - const Tensor *seq_offsets_v, const Tensor *rng_state, + const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( @@ -94,10 +94,10 @@ void fused_attn_arbitrary_seqlen_bwd( Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *seq_offsets_q, const Tensor *seq_offsets_k, - const Tensor *seq_offsets_v, const Tensor *rng_state, + const Tensor *seq_offsets_v, const Tensor *seq_offsets_o, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8900 } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_ +#endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_ \ No newline at end of file diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index f7d30174ab..c78aa6851e 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -15,11 +15,11 @@ namespace transformer_engine { template __device__ void fused_rope_block_forward( - const scalar_t *src, const float *freqs, scalar_t *dst, + const scalar_t *src, const float *freqs, scalar_t *dst, const int begin_offset, const int offset_block, const int offset_block_dst, const int h, const int d, const int d2, const int stride_h, const int stride_d, const int o_stride_h, const int o_stride_d) { - int s_id = blockIdx.x; + int s_id = blockIdx.x + begin_offset; #pragma unroll for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { float v_cos, v_sin; @@ -54,11 +54,11 @@ __device__ void fused_rope_block_forward( template __device__ void fused_rope_block_backward( - const scalar_t *src, const float *freqs, scalar_t *dst, + const scalar_t *src, const float *freqs, scalar_t *dst, const int begin_offset, const int offset_block, const int offset_block_dst, const int h, const int d, const int d2, const int stride_h, const int stride_d, const int o_stride_h, 
const int o_stride_d) { - int s_id = blockIdx.x; + int s_id = blockIdx.x + begin_offset; #pragma unroll for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { float v_cos = cosf(freqs[s_id * d2 + d_id]); @@ -100,10 +100,10 @@ __global__ void fused_rope_forward_kernel( const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; - int s_begin = 0; - int offset_block = (s_id + s_begin) * stride_s + b_id * stride_b; + int begin_offset = (begins == 0) ? 0 : begins[b_id]; + int offset_block = s_id * stride_s + b_id * stride_b; int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; - fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, + fused_rope_block_forward(src, freqs, dst, begin_offset, offset_block, offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } @@ -115,16 +115,16 @@ __global__ void fused_rope_backward_kernel( const int stride_h, const int stride_d, const int o_stride_s, const int o_stride_b, const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; - int s_begin = begins[b_id]; - int offset_block = (s_id + s_begin) * stride_s + b_id * stride_b; + int begin_offset = (begins == 0) ? 0 : begins[b_id]; + int offset_block = s_id * stride_s + b_id * stride_b; int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; - fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, + fused_rope_block_backward(src, freqs, dst, begin_offset, offset_block, offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } template __global__ void fused_rope_thd_forward_kernel( - const scalar_t *src, const int *cu_seqlens, const float *freqs, + const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *begins, scalar_t *dst, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d) { @@ -133,13 +133,14 @@ __global__ void fused_rope_thd_forward_kernel( if (t_id >= cu_seqlens[b_id + 1]) return; int offset_block = t_id * stride_t; int offset_block_dst = t_id * o_stride_t; - fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, + int begin_offset = (begins == 0) ? 0 : begins[b_id]; + fused_rope_block_forward(src, freqs, dst, begin_offset, offset_block, offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } template __global__ void fused_rope_thd_backward_kernel( - const scalar_t *src, const int *cu_seqlens, const float *freqs, + const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *begins, scalar_t *dst, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d) { @@ -148,7 +149,8 @@ __global__ void fused_rope_thd_backward_kernel( if (t_id >= cu_seqlens[b_id + 1]) return; int offset_block = t_id * stride_t; int offset_block_dst = t_id * o_stride_t; - fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, + int begin_offset = (begins == 0) ? 
0 : begins[b_id]; + fused_rope_block_backward(src, freqs, dst, begin_offset, offset_block, offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } @@ -193,7 +195,7 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const float *fre template void fused_rope_thd_forward_launcher( - const scalar_t *input, const int *cu_seqlens, const float *freqs, + const scalar_t *input, const int *cu_seqlens, const float *freqs, const int *begins, scalar_t *output, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d, @@ -203,14 +205,15 @@ void fused_rope_thd_forward_launcher( dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_thd_forward_kernel<<>>( - input, cu_seqlens, freqs, output, h, d, d2, stride_t, stride_h, stride_d, + input, cu_seqlens, freqs, begins, output, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } template void fused_rope_thd_backward_launcher( - const scalar_t *output_grads, const int *cu_seqlens, const float *freqs, + const scalar_t *output_grads, const int *cu_seqlens, + const float *freqs, const int *begins, scalar_t *input_grads, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, @@ -220,7 +223,7 @@ void fused_rope_thd_backward_launcher( dim3 threads(THREADS_PER_WARP, warps_per_block); fused_rope_thd_backward_kernel<<>>( - output_grads, cu_seqlens, freqs, input_grads, h, d, d2, stride_t, + output_grads, cu_seqlens, freqs, begins, input_grads, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -263,7 +266,7 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, const } void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, - const Tensor &freqs, Tensor *output, + const Tensor &freqs, const Tensor &begins, Tensor *output, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, @@ -275,13 +278,14 @@ void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, reinterpret_cast(input.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(begins.data.dptr), reinterpret_cast(output->data.dptr), max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);); } -void fused_rope_thd_backward(const Tensor &output_grads, - const Tensor &cu_seqlens, const Tensor &freqs, +void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens, + const Tensor &freqs, const Tensor &begins, Tensor *input_grads, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, @@ -294,6 +298,7 @@ void fused_rope_thd_backward(const Tensor &output_grads, reinterpret_cast(output_grads.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(begins.data.dptr), reinterpret_cast(input_grads->data.dptr), max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);); @@ -340,7 +345,8 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor fr void nvte_fused_rope_thd_forward(const NVTETensor input, const 
NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor output, + const NVTETensor freqs, + const NVTETensor begins, NVTETensor output, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, @@ -351,6 +357,7 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, fused_rope_thd_forward(*reinterpret_cast(input), *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), + *reinterpret_cast(begins), reinterpret_cast(output), max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); @@ -358,7 +365,7 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, void nvte_fused_rope_thd_backward( const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int max_s, + const NVTETensor freqs, const NVTETensor begins, NVTETensor input_grads, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d, cudaStream_t stream) { @@ -367,6 +374,7 @@ void nvte_fused_rope_thd_backward( fused_rope_thd_backward(*reinterpret_cast(output_grads), *reinterpret_cast(cu_seqlens), *reinterpret_cast(freqs), + *reinterpret_cast(begins), reinterpret_cast(input_grads), max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 48cebed28a..ac5f8fbc78 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -176,10 +176,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, * e.g. M, ZInv, rng_state. - * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. * \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. * \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. + * \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1]. * \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] max_seqlen Max sequence length used for computing, * it may be >= max(seqlen_i) for i=0,...batch_size-1. @@ -202,6 +203,7 @@ void nvte_fused_attn_fwd_qkvpacked( const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, const NVTETensor rng_state, size_t max_seqlen, bool is_training, float attn_scale, float dropout, @@ -229,10 +231,11 @@ void nvte_fused_attn_fwd_qkvpacked( * e.g. M, ZInv, rng_state. * \param[out] dQKV The gradient of the QKV tensor. * \param[out] dBias The gradient of the Bias tensor. - * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. + * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. * \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. * \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. 
+ * \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1]. * \param[in] max_seqlen Max sequence length used for computing, * it may be >= max(seqlen_i) for i=0,...batch_size-1. * \param[in] attn_scale Scaling factor for Q * K.T. @@ -256,6 +259,7 @@ void nvte_fused_attn_bwd_qkvpacked( const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -285,11 +289,12 @@ void nvte_fused_attn_bwd_qkvpacked( * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, * e.g. M, ZInv, rng_state. - * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. - * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. + * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. * \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. * \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. + * \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1]. * \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] max_seqlen_q Max sequence length used for computing for Q. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. @@ -316,6 +321,7 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, @@ -344,11 +350,12 @@ void nvte_fused_attn_fwd_kvpacked( * \param[out] dQ The gradient of the Q tensor. * \param[out] dKV The gradient of the KV tensor. * \param[out] dBias The gradient of the Bias tensor. - * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. - * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. + * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. + * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. * \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. * \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. + * \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1]. * \param[in] max_seqlen_q Max sequence length used for computing for Q. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. * \param[in] max_seqlen_kv Max sequence length used for computing for KV. @@ -377,6 +384,7 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -415,6 +423,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. 
* \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. + * \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1]. * \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] max_seqlen_q Max sequence length used for computing for Q. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. @@ -442,6 +451,7 @@ void nvte_fused_attn_fwd( const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, @@ -480,6 +490,7 @@ void nvte_fused_attn_fwd( * \param[in] seq_offsets_q Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] seq_offsets_k Cumulative sequence offsets for K, [batch_size + 1]. * \param[in] seq_offsets_v Cumulative sequence offsets for V, [batch_size + 1]. + * \param[in] seq_offsets_o Cumulative sequence offsets for O, [batch_size + 1]. * \param[in] max_seqlen_q Max sequence length used for computing for Q. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. * \param[in] max_seqlen_kv Max sequence length used for computing for K and V. @@ -510,6 +521,7 @@ void nvte_fused_attn_bwd( const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v, + const NVTETensor seq_offsets_o, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -520,4 +532,4 @@ void nvte_fused_attn_bwd( } // extern "C" #endif -#endif +#endif \ No newline at end of file diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index ed7474f881..d1f9f1a5bc 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -79,6 +79,7 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor fr * \param[in] input Input tensor for fused rope. * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. * \param[in] freqs The freqs tensor. + * \param[in] begins The beginning offsets. * \param[out] output Output tensor. * \param[in] max_s Max sequence length. * \param[in] b Batch size. @@ -95,7 +96,9 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor fr */ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor output, + const NVTETensor freqs, + NVTETensor begins, + NVTETensor output, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, @@ -107,6 +110,7 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, * \param[in] output_grads Incoming gradient tensor for backward. * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. * \param[in] freqs The freqs tensor. + * \param[in] begins The beginning offsets. * \param[out] input_grads Input gradient to calculate. * \param[in] max_s Max sequence length. * \param[in] b Batch size. 
@@ -123,7 +127,7 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, */ void nvte_fused_rope_thd_backward( const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, const int max_s, + const NVTETensor freqs, NVTETensor begins, NVTETensor input_grads, const int max_s, const int b, const int h, const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, const int o_stride_t, const int o_stride_h, const int o_stride_d, cudaStream_t stream); diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 661970c893..0d3a468d7a 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -234,29 +234,6 @@ def save_new_key_and_value_layer(self, layer_number, key_layer, value_layer): value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] return key_layer, value_layer - def pick_freqs(self, freq, pos_emb_buffer): - """ - Parameters - ---------- - freq: torch.Tensor [max_pos_emb, 1, 1, d] - Tensor with frequencies used in rotarty positional encoding application. - pos_emb_buffer: torch.Tensor [b, max_incoming_seq_len, 1, d] - Buffer for positional embedding frequencies for each sequence in batch. - - If self.incoming_seq_len contains numbers [s1, s2, ...], then - pos_emb_buffer[0, :] = freq[s1:(s1 + max_incoming_seq_len), 1, 1, d]. - """ - batch_size, _, _ , hidden_dim = pos_emb_buffer.shape - tex.get_values( - freq, - self.seq_len, - self.incoming_seq_len, - pos_emb_buffer, - self.max_incoming_seq_len, - batch_size, - hidden_dim - ) - @torch.no_grad() @@ -1470,18 +1447,21 @@ def forward( freqs: torch.Tensor, tensor_format: str = "sbhd", cu_seqlens: Union[torch.Tensor, None] = None, + begins: Union[torch.Tensor, None] = None, ) -> torch.Tensor: + if begins is None: + begins = torch.Tensor() if tensor_format == "sbhd": - output = tex.fused_rope_forward(t, freqs, False) + output = tex.fused_rope_forward(t, freqs, begins, False) elif tensor_format == "bshd": output = tex.fused_rope_forward( - t.transpose(0, 1), freqs, True + t.transpose(0, 1), freqs, begins, True ).transpose(0, 1) elif tensor_format == "thd": - output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs) + output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, begins) else: raise ValueError(f"Unsupported tensor_format: {tensor_format}.") - ctx.save_for_backward(freqs, cu_seqlens) + ctx.save_for_backward(freqs, cu_seqlens, begins) ctx.tensor_format = tensor_format return output @@ -1490,15 +1470,15 @@ def forward( def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: - freqs, cu_seqlens = ctx.saved_tensors + freqs, cu_seqlens, begins = ctx.saved_tensors if ctx.tensor_format == "sbhd": - grad_input = tex.fused_rope_backward(grad_output, freqs, False) + grad_input = tex.fused_rope_backward(grad_output, freqs, begins, False) elif ctx.tensor_format == "bshd": grad_input = tex.fused_rope_backward( grad_output.transpose(0, 1), freqs, True ).transpose(0, 1) elif ctx.tensor_format == "thd": - grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs) + grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, begins, freqs) else: raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.") @@ -1520,6 +1500,7 @@ def apply_rotary_pos_emb( tensor_format: str = "sbhd", fused: bool = False, cu_seqlens: Union[torch.Tensor, None] = None, + begins: Union[torch.Tensor, None] = None, ) -> torch.Tensor: 
""" Apply rotary positional embedding tensor to the input tensor. @@ -1540,12 +1521,17 @@ def apply_rotary_pos_emb( cu_seqlens: torch.Tensor, default = None. Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and dtype torch.int32. Only valid when `tensor_format` is 'thd'. + begins: torch.Tensor, default = None. + We may not want begin all the sequences from the 0 embedding. This tensor argument allows that. """ + assert not (begins is not None and not fused), \ + """begins != None and fused=False is not supported""" + if fused: assert ( tensor_format != "thd" or cu_seqlens is not None ), "cu_seqlens must not be None when tensor_format is 'thd'." - return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens) + return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, begins) assert tensor_format in ("sbhd", "bshd"), ( "Only formats `sbhd` or `bshd` are supported for input tensor `t` " @@ -2265,19 +2251,88 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function): @staticmethod def forward(ctx, is_training, max_seqlen, cu_seqlens, - seq_offsets_q, seq_offsets_k, seq_offsets_v, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, qkv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen, fused_attention_backend, use_FAv2_bwd): - out, aux_ctx_tensors = fused_attn_fwd_qkvpacked( - is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, - fused_attention_backend, - seq_offsets_q, seq_offsets_k, seq_offsets_v, - attn_bias, None, None, None, None, None, - attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen) - - ctx.save_for_backward(qkv, out, cu_seqlens, seq_offsets_q, seq_offsets_k, seq_offsets_v) + rng_gen, fused_attention_backend, use_FAv2_bwd, + fp8, fp8_meta): + if fp8: + if _NVTE_DEBUG: + print('[DotProductAttention]: using FP8 forward') + if fp8_meta["recipe"].fp8_mha: + assert (isinstance(qkv, Float8Tensor)), "qkv must be Float8Tensors for FP8 MHA." + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv + fused_attention_backend = FusedAttnBackend["FP8"] + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_group = len(qkv_layout.split('_')) + assert (qkv_group == 1 + ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, \ + but found {qkv_layout}." 
+ if fp8_meta["recipe"].fp8_mha: + qkv_fp8 = qkv._data + else: + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_fp8 = cast_to_fp8(qkv_c, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(qkv.shape) + out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked( + is_training, max_seqlen, cu_seqlens, + qkv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + fp8_meta["scaling_fwd"].scale_inv[META_QKV], + fp8_meta["scaling_fwd"].scale_inv[META_S], + fp8_meta["scaling_fwd"].scale[META_S], + fp8_meta["scaling_fwd"].scale[META_O], + fp8_meta["scaling_fwd"].amax_history[0][META_S], + fp8_meta["scaling_fwd"].amax_history[0][META_O], + attn_scale, dropout_p, fast_zero_fill, qkv_layout, + attn_bias_type, attn_mask_type, rng_gen) + if fp8_meta["recipe"].fp8_mha: + out_ret = Float8Tensor(data=out_fp8, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=qkv.dtype, + ) + else: + out_ret = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], META_O, + fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + out_save = out_ret + if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv = cast_from_fp8(qkv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[qkv.dtype]).view(qkv.shape) + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], META_O, + fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + fp8_tensors = (qkv_fp8, out_fp8, + fp8_meta["scaling_fwd"].scale.clone(), + fp8_meta["scaling_fwd"].scale_inv.clone()) + else: + if _NVTE_DEBUG: + print('[DotProductAttention]: using non-FP8 forward') + out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked( + is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, + fused_attention_backend, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + None, None, None, None, None, None, + attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, + rng_gen) + fp8_tensors = (None, None, None, None) + out_save = out_ret + + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) + ctx.save_for_backward(*qkvo_tensors, cu_seqlens, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + *fp8_tensors, *aux_ctx_tensors) + ctx.fp8_meta = fp8_meta ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen = max_seqlen ctx.qkv_dtype = qkv_dtype @@ -2302,7 +2357,10 @@ def backward(ctx, d_out): d_out = d_out._data d_out = d_out.contiguous() - qkv, out, cu_seqlens, seq_offsets_q, seq_offsets_k, seq_offsets_v = ctx.saved_tensors + (qkv, out, cu_seqlens, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + qkv_fp8, out_fp8, + fwd_scales, fwd_scale_invs) = ctx.saved_tensors if not ctx.aux_ctx_tensors[0].is_contiguous(): ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: @@ -2319,22 +2377,75 @@ def backward(ctx, d_out): ) dqkv = dqkv[..., :d_out.shape[-1]] else: - dqkv, *rest = fused_attn_bwd_qkvpacked( - ctx.max_seqlen, cu_seqlens, qkv, out, d_out, - ctx.qkv_dtype, ctx.aux_ctx_tensors, - ctx.fused_attention_backend, - seq_offsets_q, seq_offsets_k, seq_offsets_v, - None, None, None, None, None, None, None, None, None, - ctx.attn_scale, ctx.dropout_p, 
ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"): + if ctx.fp8: + if _NVTE_DEBUG: + print('[DotProductAttention]: using FP8 backward') + fp8_dtype_forward = get_fp8_te_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=False) + if ctx.fp8_meta["recipe"].fp8_mha: + d_out_fp8 = d_out + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv + else: + d_out_fp8 = cast_to_fp8( + d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ).view(d_out.shape) + dqkv_fp8, *rest = fused_attn_bwd_qkvpacked( + ctx.max_seqlen, cu_seqlens, + qkv_fp8, out_fp8, d_out_fp8, + fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + fwd_scale_invs[META_QKV], # d_scale_qkv, + fwd_scale_invs[META_S], # d_scale_s, + fwd_scale_invs[META_O], # d_scale_o, + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp + fwd_scales[META_S], # q_scale_s + ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp + ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + if ctx.fp8_meta["recipe"].fp8_mha: + dqkv = Float8Tensor(data=dqkv_fp8, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=d_out_f8tensor.dtype, + ) + else: + dqkv_c_fp8 = dqkv_fp8.view(-1, + dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]) + dqkv = cast_from_fp8(dqkv_c_fp8, + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape) + else: + if _NVTE_DEBUG: + print('[DotProductAttention]: using non-FP8 backward') + if d_out.dtype == torch.uint8: + d_out = d_out_f8tensor.from_float8(qkv.dtype) + dqkv, *rest = fused_attn_bwd_qkvpacked( + ctx.max_seqlen, cu_seqlens, qkv, out, d_out, + ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + None, None, None, None, None, None, None, None, None, None, + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: - return (None, None, None, None, None, None, dqkv, None, None, None, + return (None, None, None, None, None, None,None, None, None, None, dqkv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None) # else, return (dqkv, dbias) - return (None, None, None, None, None, None, dqkv, None, rest[0], None, + return (None, None, None, None, None, None, None, None, None, None, dqkv, None, rest[0], None, None, None, None, None, None, None, None, None, None, None, None, None) @@ -2344,20 +2455,94 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function): @staticmethod def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, q, kv, qkv_dtype, attn_bias, attn_scale, 
dropout_p, fast_zero_fill, - qkv_layout, attn_bias_type, attn_mask_type, - rng_gen, fused_attention_backend, use_FAv2_bwd): - out, aux_ctx_tensors = fused_attn_fwd_kvpacked( - is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, kv, qkv_dtype, fused_attention_backend, - seq_offsets_q, seq_offsets_k, seq_offsets_v, - attn_bias, None, None, None, None, None, - attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, - rng_gen) - - ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v) + qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, + use_FAv2_bwd, fp8, fp8_meta): + if fp8: + if _NVTE_DEBUG: + print('[DotProductAttention]: using FP8 forward') + if fp8_meta["recipe"].fp8_mha: + assert (isinstance(q, Float8Tensor) + and isinstance(kv, Float8Tensor)), "q/kv must be Float8Tensors for FP8 MHA." + fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + fused_attention_backend = FusedAttnBackend["FP8"] + fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + if fp8_meta["recipe"].fp8_mha: + q_fp8, kv_fp8 = q._data, kv._data + else: + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_group = len(qkv_layout.split('_')) + assert (qkv_group == 2 + ), f"qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, \ + but found {qkv_layout}." + q_fp8 = cast_to_fp8(q, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(q.shape) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv_fp8 = cast_to_fp8(kv_c, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward).view(kv.shape) + out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked( + is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q_fp8, kv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + fp8_meta["scaling_fwd"].scale_inv[META_QKV], + fp8_meta["scaling_fwd"].scale_inv[META_S], + fp8_meta["scaling_fwd"].scale[META_S], + fp8_meta["scaling_fwd"].scale[META_O], + fp8_meta["scaling_fwd"].amax_history[0][META_S], + fp8_meta["scaling_fwd"].amax_history[0][META_O], + attn_scale, dropout_p, fast_zero_fill, qkv_layout, + attn_bias_type, attn_mask_type, rng_gen) + if fp8_meta["recipe"].fp8_mha: + out_ret = Float8Tensor(data=out_fp8, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=META_O, + fp8_dtype=fp8_dtype_forward, + dtype=q.dtype, + ) + else: + out_ret = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], META_O, + fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + out_save = out_ret + if fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + q = cast_from_fp8(q._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[q.dtype]).view(q.shape) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv = cast_from_fp8(kv_c._data, + fp8_meta["scaling_fwd"], + META_QKV, fp8_dtype_forward, TE_DType[kv.dtype]).view(kv.shape) + out_save = cast_from_fp8( + out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), + fp8_meta["scaling_fwd"], META_O, + fp8_dtype_forward, qkv_dtype).view(out_fp8.shape) + fp8_tensors = (q_fp8, kv_fp8, out_fp8, + fp8_meta["scaling_fwd"].scale.clone(), + fp8_meta["scaling_fwd"].scale_inv.clone()) + else: + if _NVTE_DEBUG: + print('[DotProductAttention]: using non-FP8 forward') + out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked( + is_training, 
max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, kv, qkv_dtype, fused_attention_backend, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + None, None, None, None, None, None, + attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, + rng_gen) + out_save = out_ret + fp8_tensors = (None, None, None, None, None) + + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) + ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + *fp8_tensors, *aux_ctx_tensors) + ctx.fp8_meta = fp8_meta ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv @@ -2384,7 +2569,9 @@ def backward(ctx, d_out): d_out = d_out.contiguous() (q, kv, out, cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v) = ctx.saved_tensors + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + q_fp8, kv_fp8, out_fp8, + fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors if not ctx.aux_ctx_tensors[0].is_contiguous(): ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: @@ -2403,23 +2590,87 @@ def backward(ctx, d_out): dq = dq[..., :d_out.shape[-1]] dkv = dkv[..., :d_out.shape[-1]] else: - dq, dkv, *rest = fused_attn_bwd_kvpacked( - ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, kv, out, d_out, - ctx.qkv_dtype, ctx.aux_ctx_tensors, - ctx.fused_attention_backend, - seq_offsets_q, seq_offsets_k, seq_offsets_v, - None, None, None, None, None, None, None, None, None, - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + with torch.cuda.nvtx.range("_FusedAttn_kvpacked"): + if ctx.fp8: + if _NVTE_DEBUG: + print('[DotProductAttention]: using FP8 backward') + fp8_dtype_forward = get_fp8_te_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=False) + if ctx.fp8_meta["recipe"].fp8_mha: + d_out_fp8 = d_out + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv + else: + d_out_fp8 = cast_to_fp8( + d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ).view(d_out.shape) + dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked( + ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q_fp8, kv_fp8, out_fp8, d_out_fp8, + fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + fwd_scale_invs[META_QKV], # d_scale_qkv, + fwd_scale_invs[META_S], # d_scale_s, + fwd_scale_invs[META_O], # d_scale_o, + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp + fwd_scales[META_S], # q_scale_s + ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp + ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + if ctx.fp8_meta["recipe"].fp8_mha: + dq = Float8Tensor(data=dq_fp8, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + 
fp8_dtype=fp8_dtype_backward, + dtype=d_out_f8tensor.dtype, + ) + dkv = Float8Tensor(data=dkv_fp8, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=d_out_f8tensor.dtype, + ) + else: + dq = cast_from_fp8( + dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape) + dkv_c_fp8 = dkv_fp8.view(-1, + dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]) + dkv = cast_from_fp8(dkv_c_fp8, + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape) + else: + if _NVTE_DEBUG: + print('[DotProductAttention]: using non-FP8 backward') + if d_out.dtype == torch.uint8: + d_out = d_out_f8tensor.from_float8(q.dtype) + dq, dkv, *rest = fused_attn_bwd_kvpacked( + ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, kv, out, d_out, + ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + None, None, None, None, None, None, None, None, None, None, + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: - return (None, None, None, None, None, None, None, None, dq, dkv, None, None, None, + return (None, None, None, None, None, None, None, None, None, None, None, None, dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None) # else, return (dqkv, dbias) - return (None, None, None, None, None, None, None, None, dq, dkv, None, rest[0], None, + return (None, None, None, None, None, None, None, None, None, None, None, None, dq, dkv, None, rest[0], None, None, None, None, None, None, None, None, None, None, None, None, None) @@ -2428,7 +2679,7 @@ class FusedAttnFunc(torch.autograd.Function): @staticmethod def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, use_FAv2_bwd, fp8, fp8_meta): @@ -2481,6 +2732,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql out_fp8, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q_fp8, k_fp8, v_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale[META_S], @@ -2551,8 +2803,9 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql print('[DotProductAttention]: using non-FP8 forward') out_ret, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, k, v, qkv_dtype, fused_attention_backend, - seq_offsets_q, seq_offsets_k, seq_offsets_v, attn_bias, None, None, None, None, None, None, + q, k, v, qkv_dtype, fused_attention_backend, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + None, None, None, None, None, None, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen) out_save = out_ret @@ -2566,9 +2819,12 
@@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql if tensor is not None: tensor.activation_offloading = True - - ctx.save_for_backward(q, k, v, out, - cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v) + ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) + ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + *fp8_tensors, *aux_ctx_tensors) + ctx.fp8_meta = fp8_meta ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv @@ -2595,7 +2851,9 @@ def backward(ctx, d_out): d_out = d_out.contiguous() (q, k, v, out, cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v) = ctx.saved_tensors + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + q_fp8, k_fp8, v_fp8, out_fp8, + fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors if not ctx.aux_ctx_tensors[0].is_contiguous(): ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: @@ -2616,23 +2874,124 @@ def backward(ctx, d_out): dk = dk[..., :d_out.shape[-1]] dv = dv[..., :d_out.shape[-1]] else: - dq, dk, dv, *rest = fused_attn_bwd( - ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, k, v, out, d_out, - ctx.qkv_dtype, ctx.aux_ctx_tensors, - ctx.fused_attention_backend, - seq_offsets_q, seq_offsets_k, seq_offsets_v, - None, None, None, None, None, None, None, None, None, - ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, - ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + with torch.cuda.nvtx.range("_FusedAttn"): + if ctx.fp8: + if _NVTE_DEBUG: + print('[DotProductAttention]: using FP8 backward') + fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) + fp8_dtype_backward = get_fp8_te_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=False) + if ctx.fp8_meta["recipe"].fp8_mha: + d_out_fp8 = d_out + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO] = d_out_f8tensor._scale_inv + else: + d_out_fp8 = cast_to_fp8( + d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward + ).view(d_out.shape) + dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( + ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q_fp8, k_fp8, v_fp8, out_fp8, d_out_fp8, + fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + fwd_scale_invs[META_QKV], # d_scale_qkv, + fwd_scale_invs[META_S], # d_scale_s, + fwd_scale_invs[META_O], # d_scale_o, + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do + ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp + fwd_scales[META_S], # q_scale_s + ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp + ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp + ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) + if ctx.fp8_meta["recipe"].fp8_mha: + dq = Float8Tensor(data=dq_fp8, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=d_out_f8tensor.dtype, + ) + dk = Float8Tensor(data=dk_fp8, + fp8_meta=ctx.fp8_meta, + 
fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=d_out_f8tensor.dtype, + ) + dv = Float8Tensor(data=dv_fp8, + fp8_meta=ctx.fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=META_DQKV, + fp8_dtype=fp8_dtype_backward, + dtype=d_out_f8tensor.dtype, + ) + else: + qkv_group = len(ctx.qkv_layout.split('_')) + if qkv_group == 1: + dim = ctx.qkv_layout.find('3') + dqkv_fp8 = _combine_tensors([dq_fp8,dk_fp8,dv_fp8], dim) + dqkv_c_fp8 = dqkv_fp8.view(-1, + dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1]) + dqkv = cast_from_fp8(dqkv_c_fp8, + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dqkv_fp8.shape) + dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1,1,1]) + dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]] + if qkv_group == 2: + dq = cast_from_fp8( + dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape) + dim = ctx.qkv_layout.split('_')[1].find('2') + dkv_fp8 = _combine_tensors([dk_fp8,dv_fp8], dim) + dkv_c_fp8 = dkv_fp8.view(-1, + dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1]) + dkv = cast_from_fp8(dkv_c_fp8, + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dkv_fp8.shape) + dk, dv = _SplitAlongDim.apply(dkv, dim, [1,1]) + dk, dv = [x.squeeze(dim) for x in [dk, dv]] + if qkv_group == 3: + dq = cast_from_fp8( + dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dq_fp8.shape) + dk = cast_from_fp8( + dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dk_fp8.shape) + dv = cast_from_fp8( + dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]), + ctx.fp8_meta["scaling_bwd"], META_DQKV, + fp8_dtype_backward, ctx.qkv_dtype).view(dv_fp8.shape) + else: + if _NVTE_DEBUG: + print('[DotProductAttention]: using non-FP8 backward') + if d_out.dtype == torch.uint8: + d_out = d_out_f8tensor.from_float8(q.dtype) + dq, dk, dv, *rest = fused_attn_bwd( + ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + q, k, v, out, d_out, + ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors, + ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + None, None, None, None, None, None, None, None, None, None, + ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, + ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: - return (None, None, None, None, None, None, None, None, dq, dk, dv, None, None, None, + return (None, None, None, None, None, None, + None, None, None, dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None) # else, return (dqkv, dbias) - return (None, None, None, None, None, None, None, None, dq, dk, dv, None, rest[0], None, + return (None, None, None, None, None, None, + None, None, None, dq, dk, dv, None, rest[0], None, None, None, None, None, None, None, None, None, None, None, None, None) @@ -2728,6 +3087,7 @@ def forward( seq_offsets_q: Optional[torch.Tensor] = None, seq_offsets_k: Optional[torch.Tensor] = None, seq_offsets_v: Optional[torch.Tensor] = None, + seq_offsets_o: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_kv: Optional[int] = None, attn_mask_type: str = "causal", @@ -2803,22 +3163,26 
@@ def forward( and cu_seqlens_q is not None and cu_seqlens_kv is not None ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" - if (seq_offsets_q is None or seq_offsets_k is None or seq_offsets_v is None): + if (seq_offsets_q is None + or seq_offsets_k is None + or seq_offsets_v is None + or seq_offsets_o is None): qkv_group = ''.join([x for x in qkv_layout if x not in 'bst']) num_heads = query_layer.shape[-2] num_gqa_groups = key_layer.shape[-2] head_dim = query_layer.shape[-1] + seq_offsets_o = num_heads * head_dim * cu_seqlens_q if qkv_group == 'hd_hd_hd': seq_offsets_q = num_heads * head_dim * cu_seqlens_q seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv if qkv_group in ['3hd', 'h3d']: - seq_offsets_q = num_heads * head_dim * cu_seqlens_q - seq_offsets_k = num_heads * head_dim * 2 * cu_seqlens_q + seq_offsets_q = num_heads * head_dim * 3 * cu_seqlens_q + seq_offsets_k = num_heads * head_dim * 3 * cu_seqlens_q seq_offsets_v = num_heads * head_dim * 3 * cu_seqlens_q if qkv_group in ['hd_2hd', 'hd_h2d']: seq_offsets_q = num_heads * head_dim * cu_seqlens_q - seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv + seq_offsets_k = num_gqa_groups * head_dim * 2 * cu_seqlens_kv seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv qkv_dtype = TE_DType[query_layer.dtype] @@ -2874,7 +3238,7 @@ def forward( self.training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - seq_offsets_q, seq_offsets_k, seq_offsets_v, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, query_layer, key_layer, value_layer, qkv_dtype, core_attention_bias, @@ -3165,6 +3529,7 @@ def forward( seq_offsets_q: Optional[torch.Tensor] = None, seq_offsets_k: Optional[torch.Tensor] = None, seq_offsets_v: Optional[torch.Tensor] = None, + seq_offsets_o: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_kv: Optional[int] = None, attn_mask_type: Optional[str] = None, @@ -3242,15 +3607,18 @@ def forward( cu_seqlens_kv: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. - seqlen_offsets_q: Optional[torch.Tensor], default = `None` + seq_offsets_q: Optional[torch.Tensor], default = `None` Cumulative offset of different sequences in a batch for `query_layer`, with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts. - seqlen_offsets_k: Optional[torch.Tensor], default = `None` + seq_offsets_k: Optional[torch.Tensor], default = `None` Cumulative offset of different sequences in a batch for `key_layer`, with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts. - seqlen_offsets_v: Optional[torch.Tensor], default = `None` + seq_offsets_v: Optional[torch.Tensor], default = `None` Cumulative offset of different sequences in a batch for `value_layer`, with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts. + seq_offsets_o: Optional[torch.Tensor], default = `None` + Cumulative offset of different sequences in a batch for forward output, + with shape [batch_size + 1] and dtype torch.int32. Required for `thd` layouts. max_seqlen_q: Optional[int], default = `None` Maximum sequence length in `query_layer`. Calculated from `cu_seqlens_q` if not provided. 
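# A minimal reference sketch (not taken from the patch) of how the thd sequence offsets
# introduced above follow from the cumulative sequence lengths; names mirror the diff, and
# the qkv_group strings are assumed to match the layouts handled there.
import torch

def thd_seq_offsets(qkv_group, cu_seqlens_q, cu_seqlens_kv,
                    num_heads, num_gqa_groups, head_dim):
    # Offsets count elements, so each token contributes heads * head_dim entries.
    # The output O always uses the query head count, hence it follows cu_seqlens_q.
    seq_offsets_o = num_heads * head_dim * cu_seqlens_q
    if qkv_group == 'hd_hd_hd':              # q, k and v stored separately
        seq_offsets_q = num_heads * head_dim * cu_seqlens_q
        seq_offsets_k = num_gqa_groups * head_dim * cu_seqlens_kv
        seq_offsets_v = num_gqa_groups * head_dim * cu_seqlens_kv
    elif qkv_group in ('3hd', 'h3d'):        # q, k and v interleaved: stride spans all three
        seq_offsets_q = seq_offsets_k = seq_offsets_v = num_heads * head_dim * 3 * cu_seqlens_q
    elif qkv_group in ('hd_2hd', 'hd_h2d'):  # k and v packed together
        seq_offsets_q = num_heads * head_dim * cu_seqlens_q
        seq_offsets_k = seq_offsets_v = num_gqa_groups * head_dim * 2 * cu_seqlens_kv
    else:
        raise ValueError(f"Unsupported qkv_group: {qkv_group}")
    return seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o

# Example: two sequences of lengths 3 and 4, 8 query heads, 2 GQA groups, head_dim 64.
cu_q = torch.tensor([0, 3, 7], dtype=torch.int32)
cu_kv = torch.tensor([0, 3, 7], dtype=torch.int32)
offsets = thd_seq_offsets('hd_2hd', cu_q, cu_kv, num_heads=8, num_gqa_groups=2, head_dim=64)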
@@ -3371,11 +3739,13 @@ def forward( seq_offsets_q = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") seq_offsets_k = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") seq_offsets_v = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") + seq_offsets_o = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:].copy_(torch.cumsum(inference_params.incoming_seq_len, dim=0)) cu_seqlens_kv[1:].copy_(torch.cumsum(inference_params.seq_len + inference_params.incoming_seq_len, dim=0)) seq_offsets_q.copy_(torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_q) + seq_offsets_o.copy_(seq_offsets_q) seq_offsets_k.copy_(torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv) seq_offsets_v.copy_(seq_offsets_k) @@ -3669,6 +4039,7 @@ def forward( seq_offsets_q=seq_offsets_q, seq_offsets_k=seq_offsets_k, seq_offsets_v=seq_offsets_v, + seq_offsets_o=seq_offsets_o, attn_mask_type=attn_mask_type, attention_mask=attention_mask, fused_attention_backend=fused_attention_backend, @@ -3689,6 +4060,7 @@ def forward( seq_offsets_q=seq_offsets_q, seq_offsets_k=seq_offsets_k, seq_offsets_v=seq_offsets_v, + seq_offsets_o=seq_offsets_o, attn_mask_type=attn_mask_type, attention_mask=attention_mask, fused_attention_backend=fused_attention_backend, @@ -4416,6 +4788,8 @@ def forward( # duplicate the pos_emb for self attention if not isinstance(rotary_pos_emb, tuple): rotary_pos_emb = ((rotary_pos_emb,) * 2) + + q_pos_emb, k_pos_emb = rotary_pos_emb if self.qkv_format == "thd" and inference_params is not None: # For thd attention incoming tokens can be on different positions, @@ -4428,24 +4802,10 @@ def forward( # and key_layer [1, :] corresponds to the token with position 6 = 5 + 1. key_layer = key_layer.contiguous() query_layer = query_layer.contiguous() - batch_size, hidden_dim = query_layer.shape[0], query_layer.shape[-1] - - q_pos_emb = self.alloc((batch_size, inference_params.max_incoming_seq_len, 1, hidden_dim), torch.float32, "cuda") - k_pos_emb = self.alloc((batch_size, inference_params.max_incoming_seq_len, 1, hidden_dim), torch.float32, "cuda") - q_freq, k_freq = rotary_pos_emb - - # inference_params object is aware of the positions of incoming tokens. - inference_params.pick_freqs(q_freq, q_pos_emb) - inference_params.pick_freqs(k_freq, k_pos_emb) - - # We need to apply different positional encoding for each element of the batch. 
- for i in range(batch_size): - key_layer[i,].copy_(apply_rotary_pos_emb(key_layer[i,:].unsqueeze(0), k_pos_emb[i,:].unsqueeze(1), "bshd", fused=True)[0,:]) - query_layer[i,:].copy_(apply_rotary_pos_emb(query_layer[i,:].unsqueeze(0), q_pos_emb[i,:].unsqueeze(1), "bshd", fused=True)[0,:]) + key_layer.copy_(apply_rotary_pos_emb(key_layer, k_pos_emb, "bshd", fused=True, begins=inference_params.seq_len)) + query_layer.copy_(apply_rotary_pos_emb(query_layer, q_pos_emb, "bshd", fused=True, begins=inference_params.seq_len)) else: - q_pos_emb, k_pos_emb = rotary_pos_emb - # adjust key and value for inference if inference_params is not None: if self.qkv_format == "sbhd": diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 12ef702d9a..4a8aea13da 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -82,10 +82,11 @@ def fused_attn_fwd_qkvpacked( qkv: torch.Tensor, qkv_dtype: tex.DType, fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + attn_bias: torch.Tensor = None, seq_offsets_q: torch.Tensor = None, seq_offsets_k: torch.Tensor = None, seq_offsets_v: torch.Tensor = None, - attn_bias: torch.Tensor = None, + seq_offsets_o: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None, @@ -118,15 +119,17 @@ def fused_attn_fwd_qkvpacked( data type of QKV; in tex.DType, not torch.dtype fused_attention_backend: tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. + attn_bias: torch.Tensor, default = None + input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; + shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv seq_offsets_q: torch.Tensor, default = None cumulative sequence offsets for Q; shape [batch_size + 1] seq_offsets_k: torch.Tensor, default = None cumulative sequence offsets for K; shape [batch_size + 1] seq_offsets_v: torch.Tensor, default = None cumulative sequence offsets for V; shape [batch_size + 1] - attn_bias: torch.Tensor, default = None - input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; - shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv + seq_offsets_o: torch.Tensor, default = None + cumulative sequence offsets for O; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations d_scale_s: torch.Tensor, default = None @@ -234,8 +237,8 @@ def fused_attn_fwd_qkvpacked( max_seqlen, is_training, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], cu_seqlens, qkv, qkv_dtype, - seq_offsets_q, seq_offsets_k, seq_offsets_v, - d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, rng_gen, rng_elts_per_thread, ) @@ -256,6 +259,7 @@ def fused_attn_bwd_qkvpacked( seq_offsets_q: torch.Tensor = None, seq_offsets_k: torch.Tensor = None, seq_offsets_v: torch.Tensor = None, + seq_offsets_o: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, d_scale_o: torch.Tensor = None, @@ -305,6 +309,8 @@ def fused_attn_bwd_qkvpacked( cumulative sequence offsets for K; shape [batch_size + 1] seq_offsets_v: torch.Tensor, default = None cumulative sequence 
offsets for V; shape [batch_size + 1] + seq_offsets_o: torch.Tensor, default = None + cumulative sequence offsets for O; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations d_scale_s: torch.Tensor, default = None @@ -379,9 +385,9 @@ def fused_attn_bwd_qkvpacked( output_tensors = tex.fused_attn_bwd_qkvpacked( max_seqlen, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens, qkv, o, d_o, qkv_dtype, aux_ctx_tensors, - seq_offsets_q, seq_offsets_k, seq_offsets_v, - d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, + cu_seqlens, qkv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) @@ -398,10 +404,11 @@ def fused_attn_fwd_kvpacked( kv: torch.Tensor, qkv_dtype: tex.DType, fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + attn_bias: torch.Tensor = None, seq_offsets_q: torch.Tensor = None, seq_offsets_k: torch.Tensor = None, seq_offsets_v: torch.Tensor = None, - attn_bias: torch.Tensor = None, + seq_offsets_o: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None, @@ -441,15 +448,17 @@ def fused_attn_fwd_kvpacked( data type of Q and KV; in tex.DType, not torch.dtype fused_attention_backend: tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. + attn_bias: torch.Tensor, default = None + input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; + shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv seq_offsets_q: torch.Tensor, default = None cumulative sequence offsets for Q; shape [batch_size + 1] seq_offsets_k: torch.Tensor, default = None cumulative sequence offsets for K; shape [batch_size + 1] seq_offsets_v: torch.Tensor, default = None cumulative sequence offsets for V; shape [batch_size + 1] - attn_bias: torch.Tensor, default = None - input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; - shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv + seq_offsets_o: torch.Tensor, default = None + cumulative sequence offsets for O; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations d_scale_s: torch.Tensor, default = None @@ -558,8 +567,8 @@ def fused_attn_fwd_kvpacked( max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype, - seq_offsets_q, seq_offsets_k, seq_offsets_v, - d_scale_qkv, q_scale_s, q_scale_o, amax_s, amax_o, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, rng_gen, rng_elts_per_thread, ) @@ -583,6 +592,7 @@ def fused_attn_bwd_kvpacked( seq_offsets_q: torch.Tensor = None, seq_offsets_k: torch.Tensor = None, seq_offsets_v: torch.Tensor = None, + seq_offsets_o: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, d_scale_o: torch.Tensor = None, @@ -639,6 +649,8 @@ def fused_attn_bwd_kvpacked( cumulative sequence offsets for K; shape [batch_size + 1] seq_offsets_v: torch.Tensor, default = None cumulative 
sequence offsets for V; shape [batch_size + 1] + seq_offsets_o: torch.Tensor, default = None + cumulative sequence offsets for O; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of QKV in FP8 computations d_scale_s: torch.Tensor, default = None @@ -717,9 +729,9 @@ def fused_attn_bwd_kvpacked( output_tensors = tex.fused_attn_bwd_kvpacked( max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, aux_ctx_tensors, - seq_offsets_q, seq_offsets_k, seq_offsets_v, - d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, + cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) @@ -737,10 +749,11 @@ def fused_attn_fwd( v: torch.Tensor, qkv_dtype: tex.DType, fused_attention_backend: tex.NVTE_Fused_Attn_Backend, + attn_bias: torch.Tensor = None, seq_offsets_q: torch.Tensor = None, seq_offsets_k: torch.Tensor = None, seq_offsets_v: torch.Tensor = None, - attn_bias: torch.Tensor = None, + seq_offsets_o: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, q_scale_s: torch.Tensor = None, @@ -784,15 +797,17 @@ def fused_attn_fwd( data type of Q, K and V; in tex.DType, not torch.dtype fused_attention_backend: tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. + attn_bias: torch.Tensor, default = None + input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; + shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v seq_offsets_q: torch.Tensor, default = None cumulative sequence offsets for Q; shape [batch_size + 1] seq_offsets_k: torch.Tensor, default = None cumulative sequence offsets for K; shape [batch_size + 1] seq_offsets_v: torch.Tensor, default = None cumulative sequence offsets for V; shape [batch_size + 1] - attn_bias: torch.Tensor, default = None - input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; - shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q, k and v + seq_offsets_o: torch.Tensor, default = None + cumulative sequence offsets for O; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of Q, K and V in FP8 computations d_scale_s: torch.Tensor, default = None @@ -889,9 +904,8 @@ def fused_attn_fwd( output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens_q, cu_seqlens_kv, - q, k, v, qkv_dtype, - seq_offsets_q, seq_offsets_k, seq_offsets_v, + cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, rng_gen, rng_elts_per_thread, ) @@ -917,6 +931,7 @@ def fused_attn_bwd( seq_offsets_q: torch.Tensor = None, seq_offsets_k: torch.Tensor = None, seq_offsets_v: torch.Tensor = None, + seq_offsets_o: torch.Tensor = None, d_scale_qkv: torch.Tensor = None, d_scale_s: torch.Tensor = None, d_scale_o: torch.Tensor = None, @@ -976,6 +991,8 @@ def fused_attn_bwd( cumulative sequence offsets for K; shape [batch_size 
+ 1] seq_offsets_v: torch.Tensor, default = None cumulative sequence offsets for V; shape [batch_size + 1] + seq_offsets_o: torch.Tensor, default = None + cumulative sequence offsets for O; shape [batch_size + 1] d_scale_qkv: torch.Tensor, default = None input tensor for the dequantization of Q, K and V in FP8 computations d_scale_s: torch.Tensor, default = None @@ -1041,9 +1058,6 @@ def fused_attn_bwd( ), "aux_ctx_tensors must contain rng_state as its last element." if fused_attention_backend == FusedAttnBackend["FP8"]: - print("rr") - print(d_scale_qkv) - exit() assert (d_scale_qkv is not None), "d_scale_qkv is required for FP8 fused attention." assert (d_scale_s is not None), "d_scale_s is required for FP8 fused attention." assert (d_scale_o is not None), "d_scale_o is required for FP8 fused attention." @@ -1061,11 +1075,10 @@ def fused_attn_bwd( output_tensors = tex.fused_attn_bwd( max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], - cu_seqlens_q, cu_seqlens_kv, - q, k, v, o, d_o, qkv_dtype, aux_ctx_tensors, - seq_offsets_q, seq_offsets_k, seq_offsets_v, - d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, + cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) - return output_tensors + return output_tensors \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 31ef53106c..66826ace4b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -34,6 +34,7 @@ std::vector fused_attn_fwd_qkvpacked( const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional scale_S, @@ -60,6 +61,7 @@ std::vector fused_attn_bwd_qkvpacked( const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, @@ -85,6 +87,7 @@ std::vector fused_attn_fwd_kvpacked( const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional scale_S, @@ -113,6 +116,7 @@ std::vector fused_attn_bwd_kvpacked( const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, @@ -139,6 +143,7 @@ std::vector fused_attn_fwd( const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional scale_S, @@ -168,6 +173,7 @@ std::vector fused_attn_bwd( const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, @@ 
-183,7 +189,6 @@ at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len, torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b, int s); -void get_values(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len, torch::Tensor B, int max_incoming_seq_len, int b, int d); /*************************************************************************************************** * GEMM @@ -641,12 +646,14 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, - const at::Tensor &freqs + const at::Tensor &freqs, + const at::Tensor &begins ); at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, - const at::Tensor &freqs + const at::Tensor &freqs, + const at::Tensor &begins ); /*************************************************************************************************** diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu index f54597ff1d..3f2791a0d8 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cu @@ -127,7 +127,8 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, - const at::Tensor &freqs) { + const at::Tensor &freqs, + const at::Tensor &begins) { using namespace transformer_engine; TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); @@ -169,10 +170,12 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); auto freqs_cu = makeTransformerEngineTensor(freqs); auto output_cu = makeTransformerEngineTensor(output); + auto begins_cu = makeTransformerEngineTensor(begins); nvte_fused_rope_thd_forward( input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(), - max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, + begins_cu.data(), max_s, b, h, d, d2, + stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); return output; @@ -180,7 +183,8 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, - const at::Tensor &freqs) { + const at::Tensor &freqs, + const at::Tensor &begins) { using namespace transformer_engine; TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); @@ -220,9 +224,10 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); auto freqs_cu = makeTransformerEngineTensor(freqs); auto input_grads_cu = makeTransformerEngineTensor(input_grads); + auto begins_cu = makeTransformerEngineTensor(begins); nvte_fused_rope_thd_backward( - output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), + output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), begins_cu.data(), input_grads_cu.data(), max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream()); diff --git 
a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 9be4fd3d35..0e39070475 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -99,6 +99,7 @@ std::vector fused_attn_fwd_qkvpacked( const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional scale_S, @@ -126,7 +127,7 @@ std::vector fused_attn_fwd_qkvpacked( // construct NVTE tensors TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens; - TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v; + TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 auto h = q_shape[q_shape.size() - 2]; @@ -173,7 +174,10 @@ std::vector fused_attn_fwd_qkvpacked( te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape, DType::kInt32, nullptr, nullptr, nullptr); - if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value())) { + if ((seq_offsets_q.has_value()) + && (seq_offsets_k.has_value()) + && (seq_offsets_v.has_value()) + && (seq_offsets_o.has_value())) { auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); std::vector seq_offsets_q_shape{ seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; @@ -183,12 +187,17 @@ std::vector fused_attn_fwd_qkvpacked( auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); std::vector seq_offsets_v_shape{ seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); + std::vector seq_offsets_o_shape{ + seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), + seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr); } // extract random number generator seed and offset @@ -218,6 +227,7 @@ std::vector fused_attn_fwd_qkvpacked( te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), + te_seq_offsets_o.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, @@ -269,6 +279,7 @@ std::vector fused_attn_fwd_qkvpacked( te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), + te_seq_offsets_o.data(), te_rng_state.data(), max_seqlen, is_training, attn_scale, p_dropout, @@ -297,6 +308,7 @@ std::vector fused_attn_bwd_qkvpacked( const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, @@ -411,8 +423,11 @@ std::vector fused_attn_bwd_qkvpacked( TensorWrapper te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape, DType::kInt32, nullptr, nullptr, nullptr); - TensorWrapper 
te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v; - if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value())) { + TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; + if ((seq_offsets_q.has_value()) + && (seq_offsets_k.has_value()) + && (seq_offsets_v.has_value()) + && (seq_offsets_o.has_value())) { auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); std::vector seq_offsets_q_shape{ seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; @@ -422,12 +437,17 @@ std::vector fused_attn_bwd_qkvpacked( auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); std::vector seq_offsets_v_shape{ seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); + std::vector seq_offsets_o_shape{ + seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), + seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr); } // create workspace @@ -447,6 +467,7 @@ std::vector fused_attn_bwd_qkvpacked( te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), + te_seq_offsets_o.data(), max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, @@ -473,6 +494,7 @@ std::vector fused_attn_bwd_qkvpacked( te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), + te_seq_offsets_o.data(), max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, @@ -498,6 +520,7 @@ std::vector fused_attn_fwd_kvpacked( const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional scale_S, @@ -521,7 +544,7 @@ std::vector fused_attn_fwd_kvpacked( // construct NVTE tensors TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; - TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v; + TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 auto h = q_shape[q_shape.size() - 2]; @@ -576,7 +599,10 @@ std::vector fused_attn_fwd_kvpacked( te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32, nullptr, nullptr, nullptr); - if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value())) { + if ((seq_offsets_q.has_value()) + && (seq_offsets_k.has_value()) + && (seq_offsets_v.has_value()) + && (seq_offsets_o.has_value())) { auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); std::vector seq_offsets_q_shape{ seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; @@ -586,12 +612,17 @@ std::vector fused_attn_fwd_kvpacked( auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); std::vector seq_offsets_v_shape{ seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + auto seq_offsets_o_sizes = 
seq_offsets_o.value().sizes().vec(); + std::vector seq_offsets_o_shape{ + seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), + seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr); } // extract rng seed and offset @@ -623,6 +654,7 @@ std::vector fused_attn_fwd_kvpacked( te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), + te_seq_offsets_o.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, @@ -676,6 +708,7 @@ std::vector fused_attn_fwd_kvpacked( te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), + te_seq_offsets_o.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, @@ -707,6 +740,7 @@ std::vector fused_attn_bwd_kvpacked( const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, @@ -812,8 +846,11 @@ std::vector fused_attn_bwd_kvpacked( te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32, nullptr, nullptr, nullptr); - TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v; - if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value())) { + TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; + if ((seq_offsets_q.has_value()) + && (seq_offsets_k.has_value()) + && (seq_offsets_v.has_value()) + && (seq_offsets_o.has_value())) { auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); std::vector seq_offsets_q_shape{ seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; @@ -823,12 +860,17 @@ std::vector fused_attn_bwd_kvpacked( auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); std::vector seq_offsets_v_shape{ seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); + std::vector seq_offsets_o_shape{ + seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), + seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr); } // convert auxiliary tensors from forward to NVTETensors @@ -880,6 +922,7 @@ std::vector fused_attn_bwd_kvpacked( te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), + te_seq_offsets_o.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, @@ -909,6 
+952,7 @@ std::vector fused_attn_bwd_kvpacked( te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), + te_seq_offsets_o.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, @@ -935,6 +979,7 @@ std::vector fused_attn_fwd( const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional scale_S, @@ -959,7 +1004,7 @@ std::vector fused_attn_fwd( // construct NVTE tensors TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v; + TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 auto h = q_shape[q_shape.size() - 2]; @@ -1018,7 +1063,10 @@ std::vector fused_attn_fwd( te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32, nullptr, nullptr, nullptr); - if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value())) { + if ((seq_offsets_q.has_value()) + && (seq_offsets_k.has_value()) + && (seq_offsets_v.has_value()) + && (seq_offsets_o.has_value())) { auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); std::vector seq_offsets_q_shape{ seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; @@ -1028,12 +1076,17 @@ std::vector fused_attn_fwd( auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); std::vector seq_offsets_v_shape{ seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); + std::vector seq_offsets_o_shape{ + seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), + seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr); } // extract rng seed and offset @@ -1067,6 +1120,7 @@ std::vector fused_attn_fwd( te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), + te_seq_offsets_o.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, @@ -1121,6 +1175,7 @@ std::vector fused_attn_fwd( te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), + te_seq_offsets_o.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, @@ -1153,6 +1208,7 @@ std::vector fused_attn_bwd( const c10::optional seq_offsets_q, const c10::optional seq_offsets_k, const c10::optional seq_offsets_v, + const c10::optional seq_offsets_o, const c10::optional descale_QKV, const c10::optional descale_S, const c10::optional descale_O, @@ -1326,8 +1382,11 @@ std::vector fused_attn_bwd( te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32, nullptr, nullptr, nullptr); - TensorWrapper te_seq_offsets_q, te_seq_offsets_k, 
te_seq_offsets_v; - if ((seq_offsets_q.has_value()) && (seq_offsets_k.has_value()) && (seq_offsets_v.has_value())) { + TensorWrapper te_seq_offsets_q, te_seq_offsets_k, te_seq_offsets_v, te_seq_offsets_o; + if ((seq_offsets_q.has_value()) + && (seq_offsets_k.has_value()) + && (seq_offsets_v.has_value()) + && (seq_offsets_o.has_value())) { auto seq_offsets_q_sizes = seq_offsets_q.value().sizes().vec(); std::vector seq_offsets_q_shape{ seq_offsets_q_sizes.begin(), seq_offsets_q_sizes.end()}; @@ -1337,12 +1396,17 @@ std::vector fused_attn_bwd( auto seq_offsets_v_sizes = seq_offsets_v.value().sizes().vec(); std::vector seq_offsets_v_shape{ seq_offsets_v_sizes.begin(), seq_offsets_v_sizes.end()}; + auto seq_offsets_o_sizes = seq_offsets_o.value().sizes().vec(); + std::vector seq_offsets_o_shape{ + seq_offsets_o_sizes.begin(), seq_offsets_o_sizes.end()}; te_seq_offsets_q = makeTransformerEngineTensor(seq_offsets_q.value().data_ptr(), seq_offsets_q_shape, DType::kInt32, nullptr, nullptr, nullptr); te_seq_offsets_k = makeTransformerEngineTensor(seq_offsets_k.value().data_ptr(), seq_offsets_k_shape, DType::kInt32, nullptr, nullptr, nullptr); te_seq_offsets_v = makeTransformerEngineTensor(seq_offsets_v.value().data_ptr(), seq_offsets_v_shape, DType::kInt32, nullptr, nullptr, nullptr); + te_seq_offsets_o = makeTransformerEngineTensor(seq_offsets_o.value().data_ptr(), + seq_offsets_o_shape, DType::kInt32, nullptr, nullptr, nullptr); } // convert auxiliary tensors from forward to NVTETensors @@ -1396,6 +1460,7 @@ std::vector fused_attn_bwd( te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), + te_seq_offsets_o.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, @@ -1427,6 +1492,7 @@ std::vector fused_attn_bwd( te_seq_offsets_q.data(), te_seq_offsets_k.data(), te_seq_offsets_v.data(), + te_seq_offsets_o.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, @@ -1609,6 +1675,625 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { return qkv; } +/*************************************************************************************************** + * Support THD format for Context Parallel: Binary search + **************************************************************************************************/ + +__forceinline__ +__device__ int binary_search(int target, int *array, int len) { + int left = 1, right = len - 1; + while (left < right) { + int mid = (left + right) / 2; + if (array[mid] <= target) { + left = mid + 1; + } else { + right = mid; + } + } + return left - 1; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Read the half of a THD tensor + **************************************************************************************************/ + +__global__ void thd_read_half_tensor_kernel(void *half, + void *tensor, + int *cu_seqlens, + int batch, + int hidden_size_in_bytes, + int half_idx, + int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + int laneid = threadIdx.x % 32; + int num_warps = (blockDim.x * gridDim.x) / 32; + int num_total_tokens = cu_seqlens_s[batch]; + int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); + + size_t offset = 
static_cast(dim_size_of_token) * hidden_size_in_bytes; + half = reinterpret_cast(reinterpret_cast(half) + offset/2 * blockIdx.y); + tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); + + for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { + int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); + + size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; + float4* cur_half_token = reinterpret_cast(reinterpret_cast(half) + \ + offset_in_bytes); + + offset_in_bytes = (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * \ + hidden_size_in_bytes; + float4* cur_token = reinterpret_cast(reinterpret_cast(tensor) + \ + offset_in_bytes); + + for (int idx = laneid; idx < num_float4s_per_token; idx += 32) { + cur_half_token[idx] = cur_token[idx]; + } + } +} + +at::Tensor thd_read_half_tensor(const at::Tensor &tensor, + const at::Tensor &cu_seqlens, + int half_idx) { + NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + NVTE_CHECK(cu_seqlens.size(0) >= 2); + + // Shapes of q and dq are [t, h, d], so the dimension of "t" is 0 + // Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1 + int seq_dim = tensor.dim() == 3 ? 0 : 1; + + int batch = cu_seqlens.size(0) - 1; + int num_heads = tensor.size(seq_dim + 1); + int dim_per_head = tensor.size(seq_dim + 2); + int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type()); + + // For 128-bits load/store + NVTE_CHECK(hidden_size_in_bytes % 16 == 0); + + // Generate output + std::vector shape(tensor.dim()); + for (size_t i = 0; i < shape.size(); i++) { + shape[i] = tensor.size(i); + } + shape[seq_dim] /= 2; + at::Tensor half = at::empty(shape, at::CUDA(tensor.scalar_type())); + + // Launch Kernel + constexpr unsigned int block = 256; + unsigned int grid_x = (tensor.size(seq_dim) / 2 * 32 + block - 1) / block; + unsigned int grid_y = 1; + for (int i = 0; i < seq_dim; i++) { + grid_y *= tensor.size(i); + } + dim3 grid = {grid_x, grid_y}; + thd_read_half_tensor_kernel<<>>( + half.data_ptr(), + tensor.data_ptr(), + cu_seqlens.data_ptr(), + batch, + hidden_size_in_bytes, + half_idx, + tensor.size(seq_dim)); + + return half; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: softmax_lse related operations + **************************************************************************************************/ + +template +__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, + int batch, int num_heads, int max_seqlen) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + int num_total_tokens = cu_seqlens_s[batch]; + + for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + + size_t idx = row * max_seqlen + col + seq_len; + size_t half_idx = row * max_seqlen / 2 + col; + + 
Functor::run(lse, half_lse, idx, half_idx); + } + } +} + +struct LseCorrectionFunctor { + __forceinline__ + __device__ static void run(double *lse, float *half_lse, size_t idx, size_t half_idx) { + double val = lse[idx]; + float val_per_step = half_lse[half_idx]; + double max_scale = max(val, val_per_step); + double min_scale = min(val, val_per_step); + lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale)); + } +}; + +void thd_second_half_lse_correction(at::Tensor lse, + const at::Tensor &lse_per_step, + const at::Tensor &cu_seqlens, + int total_tokens) { + NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double); + NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + + NVTE_CHECK(lse.dim() == 3); + NVTE_CHECK(lse_per_step.dim() == 3); + NVTE_CHECK(cu_seqlens.dim() == 1); + + int batch = lse.size(0); + int num_heads = lse.size(1); + int max_seqlen = lse.size(2); + + NVTE_CHECK(lse_per_step.size(0) == batch); + NVTE_CHECK(lse_per_step.size(1) == num_heads); + NVTE_CHECK(lse_per_step.size(2) == max_seqlen / 2); + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + + constexpr unsigned int block = 256; + unsigned int grid_x = (total_tokens / 2 + block - 1) / block; + unsigned int grid_y = num_heads; + dim3 grid = {grid_x, grid_y}; + thd_lse_kernel<<>>( + lse.data_ptr(), + lse_per_step.data_ptr(), + cu_seqlens.data_ptr(), + batch, + num_heads, + max_seqlen); +} + +struct ReadLseFunctor { + __forceinline__ + __device__ static void run(float *lse, float *half_lse, size_t idx, size_t half_idx) { + half_lse[half_idx] = lse[idx]; + } +}; + +at::Tensor thd_read_second_half_lse(const at::Tensor &lse, + const at::Tensor &cu_seqlens, + int total_tokens) { + NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); + NVTE_CHECK(lse.dim() == 3); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + + int batch = lse.size(0); + int num_heads = lse.size(1); + int max_seqlen = lse.size(2); + + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + + std::vector shape = {batch, num_heads, max_seqlen / 2}; + at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type())); + + constexpr unsigned int block = 256; + unsigned int grid_x = (total_tokens / 2 + block - 1) / block; + unsigned int grid_y = num_heads; + dim3 grid = {grid_x, grid_y}; + thd_lse_kernel<<>>( + lse.data_ptr(), + half_lse.data_ptr(), + cu_seqlens.data_ptr(), + batch, + num_heads, + max_seqlen); + + return half_lse; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Out correction in forward + **************************************************************************************************/ + +template +__global__ void thd_out_correction_kernel(dtype *out, + dtype *out_per_step, + float *lse, + float *lse_per_step, + int *cu_seqlens, + int batch, + int num_heads, + int dim_per_head, + int max_seqlen) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); + } + __syncthreads(); + + int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size; + int lane_id = threadIdx.x % tile_size; + int num_tiles = (blockDim.x * gridDim.x) / tile_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4); + + for (int token_id = tile_id; token_id < num_total_tokens; 
token_id += num_tiles) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t idx, idx_per_step; + + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + idx = row * max_seqlen + col + seq_len * only_second_half; + idx_per_step = row * max_seqlen / (only_second_half + 1) + col; + float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); + + idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx = (idx * num_heads + head_id) * dim_per_head; + idx_per_step = (static_cast(token_id) * num_heads + head_id) * dim_per_head; + dtype *cur_out = out + idx; + dtype *cur_out_per_step = out_per_step + idx_per_step; + + for (int j = lane_id; j < num_loops_per_head; j += tile_size) { + float4 data_per_step = reinterpret_cast(cur_out_per_step)[j]; + float4 data = reinterpret_cast(cur_out)[j]; + dtype *p_per_step = reinterpret_cast(&data_per_step); + dtype *p = reinterpret_cast(&data); + for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { + p[k] += p_per_step[k] * lse_corrected_exp; + } + reinterpret_cast(cur_out)[j] = data; + } + } + } +} + +template +static void thd_out_correction_helper(at::Tensor out, + const at::Tensor &out_per_step, + const at::Tensor &lse, + const at::Tensor &lse_per_step, + const at::Tensor &cu_seqlens) { + NVTE_CHECK(out.scalar_type() == out_per_step.scalar_type()); + NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); + NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + + int total_tokens = out.size(0); + int num_heads = out.size(1); + int dim_per_head = out.size(2); + int batch = lse.size(0); + int max_seqlen = lse.size(2); + + NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1)); + NVTE_CHECK(out_per_step.size(1) == num_heads); + NVTE_CHECK(out_per_step.size(2) == dim_per_head); + NVTE_CHECK(lse.size(1) == num_heads); + NVTE_CHECK(lse_per_step.size(0) == batch); + NVTE_CHECK(lse_per_step.size(1) == num_heads); + NVTE_CHECK(lse_per_step.size(2) == max_seqlen / (only_second_half + 1)); + NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + + constexpr int tile = 16; + constexpr int block = 512; + unsigned int grid_x = (static_cast(total_tokens) / (only_second_half + 1) * \ + tile + block - 1) / block; + dim3 grid = {grid_x, (unsigned int)num_heads}; + + thd_out_correction_kernel<<>>( + out.data_ptr(), + out_per_step.data_ptr(), + lse.data_ptr(), + lse_per_step.data_ptr(), + cu_seqlens.data_ptr(), + batch, + num_heads, + dim_per_head, + max_seqlen); +} + +void thd_out_correction(at::Tensor out, + const at::Tensor &out_per_step, + const at::Tensor &lse, + const at::Tensor &lse_per_step, + const at::Tensor &cu_seqlens, + bool only_second_half) { + if (only_second_half) { + if (out.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + } else if (out.scalar_type() == at::ScalarType::Float) { + using dtype = float; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + } else { + NVTE_ERROR("Unsupported dtype of out\n"); + } + } else { + if (out.scalar_type() == 
at::ScalarType::Half) { + using dtype = at::Half; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + } else if (out.scalar_type() == at::ScalarType::Float) { + using dtype = float; + thd_out_correction_helper(out, out_per_step, lse, lse_per_step, cu_seqlens); + } else { + NVTE_ERROR("Unsupported dtype of out\n"); + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Gradients correction in backward + **************************************************************************************************/ + +template +__global__ void thd_grad_correction_kernel(dtype *grad, + dtype *grad_per_step, + int *cu_seqlens, + int batch, + int hidden_size, + int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + if constexpr (functor_idx < 2) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } else { + cu_seqlens_s[i] = cu_seqlens[i]; + } + } + __syncthreads(); + + int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size; + int lane_id = threadIdx.x % group_size; + int num_groups = (blockDim.x * gridDim.x) / group_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4); + + size_t offset = static_cast(dim_size_of_token) * hidden_size; + if constexpr (functor_idx < 2) { + grad_per_step = grad_per_step + offset / 2 * blockIdx.y; + } else { + grad_per_step = grad_per_step + offset * blockIdx.y; + } + grad = grad + offset * blockIdx.y; + + for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + + int token_offset; + bool is_first_half; + if constexpr (functor_idx < 2) { + token_offset = cu_seqlens_s[seq_id + functor_idx]; + is_first_half = (functor_idx == 0); + } else { + token_offset = 0; + int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2); + } + + dtype *token = &grad[(token_id + token_offset) * static_cast(hidden_size)]; + dtype *token_per_step = &grad_per_step[token_id * static_cast(hidden_size)]; + for (int idx = lane_id; idx < num_inner_loops; idx += group_size) { + if (is_first_half) { + Functor_0::run(token, token_per_step, idx); + } else { + Functor_1::run(token, token_per_step, idx); + } + } + } +} + +struct EmptyFunctor { + __forceinline__ + __device__ static void run(void *token, void *token_per_step, int idx) {} +}; + +struct CopyFunctor { + __forceinline__ + __device__ static void run(void *token, void *token_per_step, int idx) { + reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; + } +}; + +template +struct AddFunctor { + __forceinline__ + __device__ static void run(dtype *token, dtype *token_per_step, int idx) { + float4 d_ = reinterpret_cast(token)[idx]; + dtype *p_ = reinterpret_cast(&d_); + + float4 d = reinterpret_cast(token_per_step)[idx]; + dtype *p = reinterpret_cast(&d); + + #pragma unroll + for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { + p_[i] += p[i]; + } + + reinterpret_cast(token)[idx] = d_; + } +}; + +template +static void thd_grad_correction_helper(at::Tensor grad, + const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens) { + 
NVTE_CHECK(grad.dim() == 3 || grad.dim() == 4); + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + + // Shape of dq is [t, h, d], so the dimension of "t" is 0 + // Shape of dkv is [2, t, h, d], so the dimension of "t" is 1 + int seq_dim = grad.dim() == 3 ? 0 : 1; + + int total_tokens = grad.size(seq_dim); + int num_heads = grad.size(seq_dim + 1); + int dim_per_head = grad.size(seq_dim + 2); + int batch = cu_seqlens.size(0) - 1; + + if constexpr (functor_idx < 2) { + NVTE_CHECK(grad_per_step.size(seq_dim) == total_tokens / 2); + } else { + NVTE_CHECK(grad_per_step.size(seq_dim) == total_tokens); + } + NVTE_CHECK(grad_per_step.size(seq_dim + 1) == num_heads); + NVTE_CHECK(grad_per_step.size(seq_dim + 2) == dim_per_head); + + size_t hidden_size = num_heads * dim_per_head; + NVTE_CHECK((hidden_size * c10::elementSize(grad.scalar_type())) % 16 == 0); + + constexpr unsigned int block = 256; + unsigned int grid_x; + if constexpr (functor_idx < 2) { + grid_x = (total_tokens / 2 * 32 + block - 1) / block; + } else { + grid_x = (total_tokens * 32 + block - 1) / block; + } + unsigned int grid_y = 1; + for (int i = 0; i < seq_dim; i++) { + grid_y *= grad.size(i); + } + dim3 grid = {grid_x, grid_y}; + + thd_grad_correction_kernel + <<>>( + grad.data_ptr(), + grad_per_step.data_ptr(), + cu_seqlens.data_ptr(), + batch, + hidden_size, + total_tokens); +} + +template +static void thd_grad_dispatcher(at::Tensor grad, + const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens, + const std::string &first_half, + const std::string &second_half) { + if (first_half == "add" && second_half == "none") { + thd_grad_correction_helper, EmptyFunctor, 0>( + grad, grad_per_step, cu_seqlens); + } else if (first_half == "copy" && second_half == "none") { + thd_grad_correction_helper( + grad, grad_per_step, cu_seqlens); + } else if (first_half == "none" && second_half == "add") { + thd_grad_correction_helper, 1>( + grad, grad_per_step, cu_seqlens); + } else if (first_half == "none" && second_half == "copy") { + thd_grad_correction_helper( + grad, grad_per_step, cu_seqlens); + } else if (first_half == "add" && second_half == "copy") { + thd_grad_correction_helper, CopyFunctor, 2>( + grad, grad_per_step, cu_seqlens); + } else if (first_half == "copy" && second_half == "add") { + thd_grad_correction_helper, 2>( + grad, grad_per_step, cu_seqlens); + } else { + NVTE_ERROR("Unsupported Functor of first half and second_half\n"); + } +} + +void thd_grad_correction(at::Tensor grad, + const at::Tensor &grad_per_step, + const at::Tensor &cu_seqlens, + const std::string &first_half, + const std::string &second_half) { + if (grad.scalar_type() == at::ScalarType::Half) { + thd_grad_dispatcher(grad, grad_per_step, cu_seqlens, first_half, second_half); + } else if (grad.scalar_type() == at::ScalarType::BFloat16) { + thd_grad_dispatcher(grad, grad_per_step, cu_seqlens, first_half, second_half); + } else if (grad.scalar_type() == at::ScalarType::Float) { + thd_grad_dispatcher(grad, grad_per_step, cu_seqlens, first_half, second_half); + } else { + NVTE_ERROR("Unsupported dtype of grad\n"); + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Generate partitioned indices for input tokens + **************************************************************************************************/ + +__global__ void thd_partition_indices_kernel(int *output, + int *cu_seqlens, + int batch, + int 
total_tokens, + int world_size, + int rank) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + int seqlen = cu_seqlens[i]; + // Currently we assume that each sequence length is divisible by (world_size*2) since we have + // to distribute each sequence evenly to different GPUs. + assert(seqlen % (world_size*2) == 0); + cu_seqlens_s[i] = seqlen / world_size; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + + for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + int index = token_id - cu_seqlens_s[seq_id]; + int offset = index < seq_len/2 ? rank : (world_size-1) * 2 - rank; + index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; + output[token_id] = index; + } +} + +at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, + int total_tokens, + int world_size, + int rank) { + NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); + NVTE_CHECK(cu_seqlens.dim() == 1); + NVTE_CHECK(cu_seqlens.size(0) >= 2); + NVTE_CHECK(rank >= 0 && rank < world_size); + NVTE_CHECK(world_size > 0); + NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0); + + int batch = cu_seqlens.size(0) - 1; + + std::vector shape = {total_tokens / world_size}; + at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int)); + + constexpr unsigned int block = 256; + unsigned int grid = (output.size(0) + block - 1) / block; + thd_partition_indices_kernel<<>>( + output.data_ptr(), + cu_seqlens.data_ptr(), + batch, + total_tokens, + world_size, + rank); + + return output; +} + // Kernel used to update KV chache when attention layout is "thd". extern "C" @@ -1635,43 +2320,9 @@ __global__ void attention_copy_kernel( } } -// Kernel used in positional encoding application. 
-extern "C" -__global__ void get_values_kernel( - float* src, - int* seq_len, - int* incoming_seq_len, - float* dst, - int max_incoming_seq_len, - int b, - int d - ) - { - // src [s, 1, 1, d] - // dst [b] - for(int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { - int to_copy = d * incoming_seq_len[batch_idx]; - int offset = seq_len[batch_idx]; - - float* begin_src_copy = src + d * offset; - float* begin_dst_copy = dst + d * max_incoming_seq_len * batch_idx; - - for(int i = threadIdx.x; i < to_copy; i += blockDim.x) { - *(begin_dst_copy + i) = *(begin_src_copy + i); - } - } -} - void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len, torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b, int s) { attention_copy_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(reinterpret_cast<__nv_bfloat16*>(A.data_ptr()), seq_len.data_ptr(), incoming_seq_len.data_ptr(), reinterpret_cast<__nv_bfloat16*>(B.data_ptr()), max_incoming_seq_len, max_seq_len, b, s); } - -void get_values(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len, torch::Tensor B, int max_incoming_seq_len, int b, int d) { - get_values_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(A.data_ptr(), - seq_len.data_ptr(), - incoming_seq_len.data_ptr(), - B.data_ptr(), max_incoming_seq_len, b, d); -} diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 246724130f..3171c3b0f6 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -101,10 +101,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version"); m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available"); - - m.def("attention_copy", &attention_copy, "attention_copy"); - m.def("get_values", &get_values, "get_values"); // Data structures py::class_(m, "FP8TensorMeta") From 65e6b576d5f11b92d293dacc1b29eb293c391835 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 21 May 2024 13:43:23 -0700 Subject: [PATCH 135/244] Times for finetuning Signed-off-by: Pawel Gadzinski --- ...utorial_accelerate_hf_gemma_with_te.ip1ynb | 299 ++++++++++++++++++ 1 file changed, 299 insertions(+) create mode 100644 docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ip1ynb diff --git a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ip1ynb b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ip1ynb new file mode 100644 index 0000000000..dcdd28c30a --- /dev/null +++ b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ip1ynb @@ -0,0 +1,299 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Accelerating a Hugging Face Gemma model with Transformer Engine" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the previous [tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), we demonstrated how to accelerate HF Llama models using the Transformer Engine library. We replaced `GemmaDecoderLayer` with `TransformerLayer` from the Transformer Engine, achieving a **25%** speedup. 
Furthermore, we conducted the finetuning in FP8 precision, which yielded an additional **39%** speedup from the baseline.\n", + "\n", + "Now, we will undertake a similar enhancement for Google's [Gemma](https://blog.google/technology/developers/gemma-open-models/) model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Dependencies for this tutorial\n", + "\n", + "The following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_gemma.py`\n", + " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n", + "2. `utils.py`\n", + " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", + "3. `media/`\n", + " - This directory contains the images used in the following tutorial." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Differences between Llama and Gemma" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Llama and Gemma models are very similar - both are based on the Transformer decoder architecture. The most important architectural differences between them are the following:\n", + "\n", + "\n", + "| Feature | Llama | Gemma |\n", + "|----------------------------------------------|------------------------------------|--------------------------------------------|\n", + "| **Norm Layer** | Standard RMSNorm <br>
$ y = \\frac{x}{\\sqrt{\\mathrm{E}[x^{2}] + \\varepsilon}} * \\gamma $ | RMSNorm with zero-centered gamma parameter <br>
$ y = \\frac{x}{\\sqrt{\\mathrm{E}[x^{2}] + \\varepsilon}} * (\\textcolor{red}{1 +} \\gamma) $ |\n", + "| **Embedding Dimension/Head Dimension** | 4096/4096 | 3072/4096 |\n", + "| **Activation Function** | SwiGLU | GeGLU |\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# [Baseline] Running HF `GemmaModel` (Precision: `BF16`)\n", + "\n", + "Similarly to the Llama tutorial, we begin the experiments by running baseline Hugging Face Gemma model finetuning in BF16 precision.\n", + "\n", + "<div class=\"alert alert-info\">
\n", + "\n", + "<b>Note</b>\n", + " \n", + "This tutorial loads and trains a Gemma 7B model, which takes up most of the GPU memory; therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint, in order to avoid running into OOM (Out Of Memory) errors.\n", + "\n", + "If the utility doesn't work, comment out the line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for the other sections in this tutorial.\n", + "\n", + "</div>
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "298 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_baseline_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# [Improvement 1] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", + "\n", + "We replace *GemmaDecoderLayer* with the highly tuned *TransformerLayer*, similarly to our approach in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb). Let's observe the impact this change has on the model's speed." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "257 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `GemmaDecoderLayer` gives a speedup of **16%** even when using only BF16 precision!\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 257 | 1.16 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# [Improvement 2] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", + "\n", + "The last improvement is about enabling FP8 precision. Let's see how it works." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "214 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "#restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"fp8\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 257 | 1.16 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 214 | 1.39 |\n", + "\n", + "\n", + "After turning on FP8 precision, we get even more speedup of almost **39%**!" 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Conclusion\n", + "\n", + "As shown in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), using the `TransformerLayer` module from Transformer Engine to replace Hugging Face's `GemmaDecoderLayer` results in a speedup compared to Hugging Face's native Gemma implementation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## See more\n", + "\n", + "We also prepared [tutorial](./tutorial_generation_gemma_with_te.ipynb) in which we will show how to speedup the Gemma model generation using Transformer Engine." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From d82cb9f1814418ecc8f00b43a437d70fea147e7b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 21 May 2024 13:45:14 -0700 Subject: [PATCH 136/244] Times for finetuning Signed-off-by: Pawel Gadzinski --- ...utorial_accelerate_hf_gemma_with_te.ip1ynb | 299 ------------------ ...tutorial_accelerate_hf_gemma_with_te.ipynb | 249 --------------- 2 files changed, 548 deletions(-) delete mode 100644 docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ip1ynb delete mode 100644 docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb diff --git a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ip1ynb b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ip1ynb deleted file mode 100644 index dcdd28c30a..0000000000 --- a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ip1ynb +++ /dev/null @@ -1,299 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Accelerating a Hugging Face Gemma model with Transformer Engine" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In the previous [tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), we demonstrated how to accelerate HF Llama models using the Transformer Engine library. We replaced `GemmaDecoderLayer` with `TransformerLayer` from the Transformer Engine, achieving a **25%** speedup. Furthermore, we conducted the finetuning in FP8 precision, which yielded an additional **39%** speedup from the baseline.\n", - "\n", - "Now, we will undertake a similar enhancement for the Google's [Gemma](https://blog.google/technology/developers/gemma-open-models/) model." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Dependencies for this tutorial\n", - "\n", - "Following files and media are necessary to effectively run this tutorial:\n", - "\n", - "1. `te_gemma.py`\n", - " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n", - "2. `utils.py`\n", - " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", - "3. 
`media/`\n", - " - This directory contains the images used in the following tutorial." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Differences between Llama and Gemma" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Thr Llama and the Gemma are very similar models - both are based on Transformer Decoder architecture. The most important architectural differences between them are the following:\n", - "\n", - "\n", - "| Feature | Llama | Gemma |\n", - "|----------------------------------------------|------------------------------------|--------------------------------------------|\n", - "| **Norm Layer** | Standard RMSNorm
$ y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\varepsilon}} * \\gamma + \\beta $ | RMSNorm with zero centered gamma parameter
$ y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\varepsilon}} * (\\textcolor{red}{1 +} \\gamma) + \\beta $ |\n", - "| **Embedding Dimension/Head Dimension** | 4096/4096 | 3072/4096 |\n", - "| **Activation Function** | SwiGlu | GeGlu |\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# [Baseline] Running HF `GemmaModel` (Precision: `BF16`)\n", - "\n", - "Similarly to the Llama tutorial, we begin the experiments by running baseline Hugging Face Gemma model finetuning in BF16 precision.\n", - "\n", - "
\n", - "\n", - "Note\n", - " \n", - "This tutorial loads and trains a Gemma 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", - "\n", - "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n", - "\n", - "
\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "10 finetuning steps complete!\n", - "\n", - "Average time taken per step: \n", - "298 \n", - "milliseconds\n" - ] - } - ], - "source": [ - "# Restart the notebook (to flush the GPU memory)\n", - "from utils import restart_jupyter_notebook\n", - "restart_jupyter_notebook()\n", - "\n", - "\n", - "# Import necessary packages and methods\n", - "from utils import *\n", - "\n", - "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", - "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", - "hyperparams.mixed_precision = \"bf16\"\n", - "\n", - "\n", - "# Init the model and accelerator wrapper\n", - "model = init_baseline_model(hyperparams).cuda()\n", - "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", - "\n", - "\n", - "# Finetune the model\n", - "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", - "\n", - "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", - "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 298 | 1 |" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# [Improvement 1] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", - "\n", - "We replace *GemmaDecoderLayer* with the highly tuned *TransformerLayer*, similarly to our approach in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb). Let's observe the impact this change has on the model's speed." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "10 finetuning steps complete!\n", - "\n", - "Average time taken per step: \n", - "257 \n", - "milliseconds\n" - ] - } - ], - "source": [ - "# Restart the notebook (to flush the GPU memory)\n", - "from utils import restart_jupyter_notebook\n", - "restart_jupyter_notebook()\n", - "\n", - "\n", - "# Import necessary packages and methods\n", - "from utils import *\n", - "\n", - "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", - "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/gemma/weights\"\n", - "hyperparams.mixed_precision = \"bf16\"\n", - "\n", - "\n", - "# Init the model and accelerator wrapper\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", - "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", - "\n", - "\n", - "# Finetune the model\n", - "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `GemmaDecoderLayer` gives a speedup of **16%** even when using only BF16 precision!\n", - "\n", - "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", - "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 298 | 1 |\n", - "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 257 | 1.16 |" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# [Improvement 2] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", - "\n", - "The last improvement is about enabling FP8 precision. Let's see how it works." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "10 finetuning steps complete!\n", - "\n", - "Average time taken per step: \n", - "214 \n", - "milliseconds\n" - ] - } - ], - "source": [ - "# Restart the notebook (to flush the GPU memory)\n", - "from utils import restart_jupyter_notebook\n", - "#restart_jupyter_notebook()\n", - "\n", - "\n", - "# Import necessary packages and methods\n", - "from utils import *\n", - "\n", - "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", - "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", - "hyperparams.mixed_precision = \"fp8\"\n", - "\n", - "\n", - "# Init the model and accelerator wrapper\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", - "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", - "\n", - "\n", - "# Finetune the model\n", - "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", - "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 298 | 1 |\n", - "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 257 | 1.16 |\n", - "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 214 | 1.39 |\n", - "\n", - "\n", - "After turning on FP8 precision, we get even more speedup of almost **39%**!" 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Conclusion\n", - "\n", - "As shown in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), using the `TransformerLayer` module from Transformer Engine to replace Hugging Face's `GemmaDecoderLayer` results in a speedup compared to Hugging Face's native Gemma implementation." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## See more\n", - "\n", - "We also prepared [tutorial](./tutorial_generation_gemma_with_te.ipynb) in which we will show how to speedup the Gemma model generation using Transformer Engine." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb deleted file mode 100644 index 3dca60e093..0000000000 --- a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb +++ /dev/null @@ -1,249 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Accelerating a Hugging Face Gemma model with Transformer Engine" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In the previous [tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), we demonstrated how to accelerate HF Llama models using the Transformer Engine library. We replaced `GemmaDecoderLayer` with `TransformerLayer` from the Transformer Engine, achieving a **25%** speedup. Furthermore, we conducted the finetuning in FP8 precision, which yielded an additional **39%** speedup from the baseline.\n", - "\n", - "Now, we will undertake a similar enhancement for the Google's [Gemma](https://blog.google/technology/developers/gemma-open-models/) model." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Dependencies for this tutorial\n", - "\n", - "Following files and media are necessary to effectively run this tutorial:\n", - "\n", - "1. `te_gemma.py`\n", - " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n", - "2. `utils.py`\n", - " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", - "3. `media/`\n", - " - This directory contains the images used in the following tutorial." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Differences between Llama and Gemma" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Thr Llama and the Gemma are very similar models - both are based on Transformer Decoder architecture. 
The most important architectural differences between them are the following:\n", - "\n", - "\n", - "| Feature | Llama | Gemma |\n", - "|----------------------------------------------|------------------------------------|--------------------------------------------|\n", - "| **Norm Layer** | Standard RMSNorm
$ y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\varepsilon}} * \\gamma + \\beta $ | RMSNorm with zero centered gamma parameter
$ y = \\frac{x - \\mathrm{E}[x]}{ \\sqrt{\\mathrm{Var}[x] + \\varepsilon}} * (\\textcolor{red}{1 +} \\gamma) + \\beta $ |\n", - "| **Embedding Dimension/Head Dimension** | 4096/4096 | 3072/4096 |\n", - "| **Activation Function** | SwiGlu | GeGlu |\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# [Baseline] Running HF `GemmaModel` (Precision: `BF16`)\n", - "\n", - "Similarly to the Llama tutorial, we begin the experiments by running baseline Hugging Face Gemma model finetuning in BF16 precision.\n", - "\n", - "
\n", - "\n", - "Note\n", - " \n", - "This tutorial loads and trains a Gemma 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", - "\n", - "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n", - "\n", - "
\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Restart the notebook (to flush the GPU memory)\n", - "from utils import restart_jupyter_notebook\n", - "restart_jupyter_notebook()\n", - "\n", - "\n", - "# Import necessary packages and methods\n", - "from utils import *\n", - "\n", - "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", - "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", - "hyperparams.mixed_precision = \"bf16\"\n", - "\n", - "\n", - "# Init the model and accelerator wrapper\n", - "model = init_baseline_model(hyperparams).cuda()\n", - "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", - "\n", - "\n", - "# Finetune the model\n", - "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", - "\n", - "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", - "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | - | 1 |" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# [Improvement 1] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", - "\n", - "We replace *GemmaDecoderLayer* with the highly tuned *TransformerLayer*, similarly to our approach in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb). Let's observe the impact this change has on the model's speed." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Restart the notebook (to flush the GPU memory)\n", - "from utils import restart_jupyter_notebook\n", - "#restart_jupyter_notebook()\n", - "\n", - "\n", - "# Import necessary packages and methods\n", - "from utils import *\n", - "\n", - "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", - "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/gemma/weights\"\n", - "hyperparams.mixed_precision = \"bf16\"\n", - "\n", - "\n", - "# Init the model and accelerator wrapper\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", - "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", - "\n", - "\n", - "# Finetune the model\n", - "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `GemmaDecoderLayer` gives a speedup of **??%** even when using only BF16 precision!\n", - "\n", - "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", - "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 315 | 1 |\n", - "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 252 | 1.25 |" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# [Improvement 2] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", - "\n", - "The last improvement is about enabling FP8 precision. Let's see how it works." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Restart the notebook (to flush the GPU memory)\n", - "from utils import restart_jupyter_notebook\n", - "#restart_jupyter_notebook()\n", - "\n", - "\n", - "# Import necessary packages and methods\n", - "from utils import *\n", - "\n", - "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", - "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", - "hyperparams.mixed_precision = \"fp8\"\n", - "\n", - "\n", - "# Init the model and accelerator wrapper\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", - "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", - "\n", - "\n", - "# Finetune the model\n", - "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", - "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | - | 1 |\n", - "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | - | - |\n", - "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | FP8 | - | - |\n", - "\n", - "\n", - "After turning on FP8 precision, we get even more speedup of almost **??%**!" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Conclusion\n", - "\n", - "As shown in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), using the `TransformerLayer` module from Transformer Engine to replace Hugging Face's `GemmaDecoderLayer` results in a speedup compared to Hugging Face's native Gemma implementation." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## See more\n", - "\n", - "We also prepared [tutorial](./tutorial_generation_gemma_with_te.ipynb) in which we will show how to speedup the Gemma model generation using Transformer Engine." - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From bc26c4d18bf15613b474a6cc36808d52d397cc98 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 21 May 2024 13:45:45 -0700 Subject: [PATCH 137/244] Times for finetuning Signed-off-by: Pawel Gadzinski --- ...tutorial_accelerate_hf_gemma_with_te.ipynb | 299 ++++++++++++++++++ 1 file changed, 299 insertions(+) create mode 100644 docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb diff --git a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb new file mode 100644 index 0000000000..dcdd28c30a --- /dev/null +++ b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_with_te.ipynb @@ -0,0 +1,299 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Accelerating a Hugging Face Gemma model with Transformer Engine" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the previous [tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), we demonstrated how to accelerate HF Llama models using the Transformer Engine library. We replaced `GemmaDecoderLayer` with `TransformerLayer` from the Transformer Engine, achieving a **25%** speedup. Furthermore, we conducted the finetuning in FP8 precision, which yielded an additional **39%** speedup from the baseline.\n", + "\n", + "Now, we will undertake a similar enhancement for the Google's [Gemma](https://blog.google/technology/developers/gemma-open-models/) model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Dependencies for this tutorial\n", + "\n", + "Following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_gemma.py`\n", + " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n", + "2. `utils.py`\n", + " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", + "3. `media/`\n", + " - This directory contains the images used in the following tutorial." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Differences between Llama and Gemma" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Thr Llama and the Gemma are very similar models - both are based on Transformer Decoder architecture. The most important architectural differences between them are the following:\n", + "\n", + "\n", + "| Feature | Llama | Gemma |\n", + "|----------------------------------------------|------------------------------------|--------------------------------------------|\n", + "| **Norm Layer** | Standard RMSNorm
$ y = \\frac{x}{\\sqrt{\\mathrm{E}[x^2] + \\varepsilon}} * \\gamma $ | RMSNorm with zero centered gamma parameter
$ y = \\frac{x}{\\sqrt{\\mathrm{E}[x^2] + \\varepsilon}} * (\\textcolor{red}{1 +} \\gamma) $ |\n", + "| **Embedding Dimension/Head Dimension** | 4096/4096 | 3072/4096 |\n", + "| **Activation Function** | SwiGLU | GeGLU |\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# [Baseline] Running HF `GemmaModel` (Precision: `BF16`)\n", + "\n", + "Similarly to the Llama tutorial, we begin the experiments by running baseline Hugging Face Gemma model finetuning in BF16 precision.\n", + "\n", + "
\n", + "\n", + "Note\n", + " \n", + "This tutorial loads and trains a Gemma 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", + "\n", + "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n", + "\n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "298 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_baseline_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# [Improvement 1] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", + "\n", + "We replace *GemmaDecoderLayer* with the highly tuned *TransformerLayer*, similarly to our approach in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb). Let's observe the impact this change has on the model's speed." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "257 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `GemmaDecoderLayer` gives a speedup of **16%** even when using only BF16 precision!\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 257 | 1.16 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# [Improvement 2] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", + "\n", + "The last improvement is about enabling FP8 precision. Let's see how it works." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "214 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "#restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"fp8\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 257 | 1.16 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 214 | 1.39 |\n", + "\n", + "\n", + "After turning on FP8 precision, we get even more speedup of almost **39%**!" 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Conclusion\n", + "\n", + "As shown in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), using the `TransformerLayer` module from Transformer Engine to replace Hugging Face's `GemmaDecoderLayer` results in a speedup compared to Hugging Face's native Gemma implementation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## See more\n", + "\n", + "We also prepared [tutorial](./tutorial_generation_gemma_with_te.ipynb) in which we will show how to speedup the Gemma model generation using Transformer Engine." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 183f1f1de3a6b09b075968a2ad281f23999651ae Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 09:44:50 -0700 Subject: [PATCH 138/244] fixes Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 56 ++++++++++++++----------- docs/examples/te_gemma/utils.py | 12 ++++-- transformer_engine/pytorch/attention.py | 3 ++ 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 52e85cea10..54de549ce2 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -20,6 +20,7 @@ import torch.nn.functional as F + class TEGemmaDecoderLayer(te.pytorch.TransformerLayer): """ Wrapper class over TE's `TransformerLayer`. This makes the wrapper very @@ -79,14 +80,15 @@ def set_inference_params(self, inference_params): self.inference_params = inference_params def forward(self, hidden_states : torch.Tensor): - hidden_states.data[:] = hidden_states.data[:] * self.normalizer # static operation - for CUDA graphs - for decoder_layer in self.model.layers: - hidden_states.data[:] = decoder_layer( - hidden_states, - attention_mask=None, - self_attn_mask_type=self.mask, - inference_params=self.inference_params - )[0] # static copy - for CUDA graphs + with torch.no_grad(): + hidden_states.data[:] = hidden_states.data[:] * self.normalizer # static operation - for CUDA graphs + for decoder_layer in self.model.layers: + hidden_states.data[:] = decoder_layer( + hidden_states, + attention_mask=None, + self_attn_mask_type=self.mask, + inference_params=self.inference_params + )[0] # static copy - for CUDA graphs hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs logits = self.lm_head(hidden_states) @@ -154,6 +156,7 @@ class is monkey-patched with `TEGemmaDecoderLayer` class before def __init__(self, config: GemmaConfig): with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): super().__init__(config) + self.to(torch.bfloat16).cuda() self.hidden_size = config.hidden_size self._model_generation_phase = GemmaGenerator( lm_head=self.lm_head, @@ -222,11 +225,9 @@ def _generate_context_phase( # We need to update offsets before every forward pass to make cache work properly. 
inference_params.thd_setup_before_new_input(input_ids, pad_token_id=0, reset=True) - #self._model_context_phase = self.record_graph(self._model_context_phase, hidden_states) + hidden_states.data[:] = self.model.embed_tokens(input_ids) logits = self._model_context_phase(hidden_states) - #import pdb - #pdb.set_trace() # We choose logits coresponding with last token in each sequence, # which have various lengths - they are stored in (inference_params.incoming_seq_len - 1) Tensor. @@ -249,10 +250,10 @@ def generate( max_new_tokens: int = 0, *args, **kwargs ): - + self.eval() assert self.config.qkv_format == "thd", "Generation using other qkv_layouts than thd is not provided in this tutorial" - with te.pytorch.fp8_autocast(enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None), \ - autocast(dtype=torch.bfloat16, cache_enabled=False): + print(f"self.config.fp8 = {self.config.fp8}") + with autocast(dtype=torch.bfloat16, cache_enabled=False): batch_size, max_input_sequence_len = input_ids.shape[0], self._get_max_input_seq_len(input_ids) lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # [s] input_ids = F.pad(input_ids, (max_input_sequence_len - input_ids.shape[1], 0), "constant", 0) @@ -269,8 +270,8 @@ def generate( # Context phase TEGemmaForCausalLM._padding_to_end(input_ids, lengths) - hidden_states, next_tokens = TEGemmaForCausalLM._generate_context_phase( - self, + + hidden_states, next_tokens = self._generate_context_phase( input_ids, inference_params ) @@ -278,9 +279,11 @@ def generate( # Generation phase. inference_params.thd_setup_before_new_input(next_tokens.unsqueeze(1)) output_tokens = [next_tokens] - for i in range(max_new_tokens): - next_tokens = self._model_generation_phase(hidden_states) - output_tokens.append(next_tokens.clone()) + + with te.pytorch.fp8_autocast(enabled=False, fp8_recipe=self.fp8_recipe if self.config.fp8 else None): + for _ in range(max_new_tokens): + next_tokens = self._model_generation_phase(hidden_states) + output_tokens.append(next_tokens.clone()) result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) return result @@ -293,7 +296,7 @@ class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM): """ def __init__(self, config : GemmaConfig): super().__init__(config) - self.to("cuda") + # Przekonwertuj siebie na bf16 chatgpt... # Preparation of the static buffers. self.config = config self.hidden_states_buffer = torch.empty( @@ -306,18 +309,21 @@ def __init__(self, config : GemmaConfig): self._model_generation_phase.set_inference_params(self.inference_params) self._model_context_phase.set_inference_params(self.inference_params) + def record(self): + self.eval() # Here "the trick" happens. We override methods from TEGemmaForCausalLM # with their recorded version. After invocation of each of them, # captured graph will be replayed with minimal usage of CPU, # what will lead to huge speedup. 
+ - input_shape = (config.cuda_graphs_static_batch_size, config.cuda_graphs_static_max_context_len) - self.inference_params.thd_setup_before_new_input(torch.ones(input_shape), reset=True) + input_shape = (self.config.cuda_graphs_static_batch_size, self.config.cuda_graphs_static_max_context_len) + self.inference_params.thd_setup_before_new_input(torch.randn(input_shape), reset=True) self._model_context_phase = self.record_graph(self._model_context_phase, self.hidden_states_buffer) # CUDA Graphs recording - input_shape = torch.ones((config.cuda_graphs_static_batch_size, 1)) + input_shape = torch.randn((self.config.cuda_graphs_static_batch_size, 1)) self.inference_params.thd_setup_before_new_input(input_shape, reset=True) - #self._model_generation_phase = self.record_graph(self._model_generation_phase, self.generation_buffer) # CUDA Graphs recording + self._model_generation_phase = self.record_graph(self._model_generation_phase, self.generation_buffer) # CUDA Graphs recording """ Functions _create_hidden_states_buffer and _create_inference_params from base class are overriden @@ -338,7 +344,7 @@ def record_graph(self, function, input_tensor): # function is invoked on argument (self.hidden_states,) and all kernels are recorded. # record_graph() returns captured function, which can be run later with minimal use of th CPU. fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=1024, amax_compute_algo="max") with autocast(dtype=torch.bfloat16, cache_enabled=False): graphed_function = te.pytorch.make_graphed_callables( function, diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index 7fe4ba3b5a..b9ce4b78b3 100644 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -112,6 +112,8 @@ def init_te_gemma_model(hyperparams): for key, value in hyperparams.__dict__.items(): setattr(config, key, value) model = load_te_model(cls, config) + if hyperparams.generation_cuda_graphs: + model.record() return model @@ -245,7 +247,11 @@ def print_sample_of_generated_texts(model): -def benchmark_generation(model, batch_size, context_length, max_new_tokens): +def benchmark_generation(model): + batch_size = 64 + context_length = 128 + max_new_tokens = 1024 - 128 + print(f"Benchmarking for batch_size={batch_size} and total tokens = {context_length + max_new_tokens}") tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) inputs = tokenizer(["a" * context_length] * batch_size, return_tensors="pt", padding=True) @@ -253,7 +259,7 @@ def benchmark_generation(model, batch_size, context_length, max_new_tokens): end = torch.cuda.Event(enable_timing=True) torch.cuda.synchronize() start.record() - + model.generate( inputs['input_ids'].cuda(), max_new_tokens=max_new_tokens @@ -262,4 +268,4 @@ def benchmark_generation(model, batch_size, context_length, max_new_tokens): end.record() print(f"Benchmark with context_length={context_length} and max_new_tokens={max_new_tokens} took {start.elapsed_time(end)} ms.") - print(f"Peak GPU memoty usage: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB") + print(f"Peak GPU memory usage: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB") diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 0d3a468d7a..1a380b88b2 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3940,6 +3940,9 @@ def 
forward( # max512 backend will only support [1, h, s, s] os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" + if self.qkv_format != "thd": # added by me #TODO - i need that in case d=256 fused attention is not run + use_fused_attention = False + if use_fused_attention: fused_attention_backend = tex.get_fused_attn_backend( TE_DType[query_layer.dtype] From d23e2b36afeca91f8ce7c0b2796058112620a524 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 15:32:01 -0700 Subject: [PATCH 139/244] Minor fixes Signed-off-by: Pawel Gadzinski --- .../tutorial_generation_gemma_with_te.ipynb | 314 ++++++++++++------ docs/examples/te_gemma/utils.py | 38 ++- 2 files changed, 239 insertions(+), 113 deletions(-) diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index 9600f9cf5f..5595d86a22 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -71,11 +71,7 @@ "\n", "#### Benchmarking\n", "\n", - "We'll evaluate the generation time across three benchmarks:\n", - "- Long input sequences (up to 256 tokens) with short generation (up to 128 tokens),\n", - "- Short input sequences (up to 64 tokens) with long generation (up to 1000 tokens).\n", - "\n", - "All benchmarks are conducted with a batch size of 64 using the dataset \"timdettmers/openassistant-guanaco\".\n", + "We'll evaluate the generation time across one benchmark: generation with context phase max sequence length = 128, batch size = 64 and number of generated tokens = 1024 - 128.\n", "\n", "
\n", "Note\n", @@ -127,30 +123,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "7477e469", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Another string ... \n", + "\n", + "I have a new 2019 15\" MBP with 16GB RAM and 1TB SSD. I have a 2015 15\" MBP with 16GB RAM and 1TB SSD. I have a 2011 15\" MBP with 16GB RAM and 1TB SSD. I have a 2011 13\" MBP with 1\n", + "====================================================================================================\n", + "I love a good DIY project. I love the challenge of creating something from scratch, and I love the sense of accomplishment that comes with finishing a project.\n", + "\n", + "I also love the fact that I can make something that is unique and special to me.\n", + "\n", + "There is something so satisfying about taking a blank canvas and turning it into something beautiful and functional.\n", + "\n", + "I also love the fact that I can save money by doing things myself.\n", + "\n", + "When I make something myself, I know exactly\n", + "====================================================================================================\n", + "Benchmarking for batch_size=64 and total tokens = 1024\n", + "Benchmark with context_length=128 and max_new_tokens=896 took 42079.8125 ms.\n", + "Peak GPU memory usage: 65.96 GB\n" + ] + } + ], "source": [ - "# Restart the notebook (to flush the GPU memory)\n", - "from utils import restart_jupyter_notebook\n", - "restart_jupyter_notebook()\n", - "\n", "# Import necessary packages and methods\n", "from utils import *\n", "\n", "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", - "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.model_name = \"../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "model = init_baseline_model(hyperparams).cuda()\n", "\n", "print_sample_of_generated_texts(model)\n", - "\n", - "benchmark_generation(model, 64, 128, 1024)\n", - "benchmark_generation(model, 64, 256, 128)" + "benchmark_generation(model)" ] }, { @@ -160,9 +174,9 @@ "source": [ "We put these times into the table for later comparison.\n", "\n", - "| Models | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 | \n", + "| Models | Time | Memory | \n", "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", - "| HF (baseline) | - | - | " + "| HF (baseline) | 42,0 sec | - | " ] }, { @@ -211,10 +225,45 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "4fc5e1cd", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The device memory hasn't been flushed, try manually restarting the Jupyter kernel!\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "self.config.fp8 = False\n", + "Another string ... 
\n", + "\n", + "I have a 2007 1.9 TDI 105bhp and the engine management light came on.\n", + "\n", + "I have a code reader and it came up with the following:\n", + "\n", + "16885 - P0341 - Camshaft Position Sensor (G40) - No Signal\n", + "\n", + "I have replaced the camshaft sensor and the light is still on.\n", + "\n", + "I have checked the wiring to the sensor and it is fine.\n", + "\n", + "I have checked the\n", + "====================================================================================================\n", + "I love the new Star Wars series The Mandalorian. I’ve been a fan of the franchise since I was a kid, and I’ve been a fan of The Mandalorian since it was first announced. I’ve been a fan of The Mandalorian since the first trailer was released. I’ve been a fan of The Mandalorian since the first episode of the first season was\n", + "====================================================================================================\n", + "Benchmarking for batch_size=64 and total tokens = 1024\n", + "self.config.fp8 = False\n", + "Benchmark with context_length=128 and max_new_tokens=896 took 27791.4375 ms.\n", + "Peak GPU memory usage: 65.96 GB\n" + ] + } + ], "source": [ "# Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", @@ -226,17 +275,14 @@ "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", - "hyperparams.model_name = \"../../../../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", - "hyperparams.mixed_precision = \"bf16\"\n", - "hyperparams.fuse_qkv_params = False\n", + "hyperparams.model_name = \"../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.qkv_format = \"thd\"\n", "\n", "# Init the model and accelerator wrapper\n", - "model = init_te_gemma_model(hyperparams).to(torch.bfloat16).cuda()\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", "print_sample_of_generated_texts(model)\n", - "benchmark_generation(model, 64, 128, 1024)\n", - "benchmark_generation(model, 64, 256, 128)" + "benchmark_generation(model)" ] }, { @@ -244,12 +290,12 @@ "id": "8e397a65", "metadata": {}, "source": [ - "By using THD attention we obtained following speedups:\n", + "By using THD attention we obtained following speedup:\n", "\n", - "| Models | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 | \n", + "| Models | Time | Speedup | \n", "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", - "| HF (baseline) | - | - |\n", - "| THD attention with TE | - | - | " + "| HF (baseline) | 42,0 sec | 1 |\n", + "| THD attention with TE | 27,8 sec | 1.51 | " ] }, { @@ -303,10 +349,38 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "31a3a8a3", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "self.config.fp8 = False\n", + "Another string ... 
\n", + "\n", + "I have a 2007 1.9 TDI 105bhp and the engine management light came on.\n", + "\n", + "I have a code reader and it came up with the following:\n", + "\n", + "16885 - P0341 - Camshaft Position Sensor (G40) - No Signal\n", + "\n", + "I have replaced the camshaft sensor and the light is still on.\n", + "\n", + "I have checked the wiring to the sensor and it is fine.\n", + "\n", + "I have checked the\n", + "====================================================================================================\n", + "I love the new Star Wars series The Mandalorian. I’ve been a fan of the franchise since I was a kid, and I’ve been a fan of The Mandalorian since it was first announced. I’ve been a fan of The Mandalorian since the first trailer was released. I’ve been a fan of The Mandalorian since the first episode of the first season was\n", + "====================================================================================================\n", + "Benchmarking for batch_size=64 and total tokens = 1024\n", + "self.config.fp8 = False\n", + "Benchmark with context_length=128 and max_new_tokens=896 took 16560.943359375 ms.\n", + "Peak GPU memory usage: 63.81 GB\n" + ] + } + ], "source": [ "#Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", @@ -314,29 +388,18 @@ "\n", "from utils import *\n", "\n", - "hyperparams.model_name = \"../../../../gemma-weights\"\n", - "hyperparams.fuse_qkv_params = True\n", + "hyperparams.model_name = \"../gemma-weights\"\n", "hyperparams.qkv_format = \"thd\"\n", "\n", "hyperparams.generation_cuda_graphs = True\n", "\n", - "# CUDA Graphs needs all kernels argument to be static - not to change between\n", - "# the time of recording and the time of generation.\n", - "# We need to allocate buffer large enough to fit all sequences.\n", "hyperparams.cuda_graphs_static_batch_size = 64\n", "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", "hyperparams.cuda_graphs_static_max_context_len = 128\n", "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", "print_sample_of_generated_texts(model)\n", - "benchmark_generation(model, batch_size=64, context_len=128, max_new_tokens=1024)\n", - "\n", - "hyperparams.cuda_graphs_static_batch_size = 64\n", - "hyperparams.cuda_graphs_static_max_seq_len = 128\n", - "hyperparams.cuda_graphs_static_max_context_len = 256\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", - "\n", - "benchmark_generation(model, batch_size=64, context_len=256, max_new_tokens=128)" + "benchmark_generation(model)" ] }, { @@ -344,14 +407,13 @@ "id": "53bb430f", "metadata": {}, "source": [ - "We finally obtained the **??%** speedup.\n", + "We obtained the **2.51x** speedup!\n", "\n", - "| Models | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 | \n", + "| Models | Time | Speedup | \n", "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", - "| HF (baseline) | - | - |\n", - "| THD attention with TE | - | - | \n", - "| THD attention + FP8 with TE | - | - | \n", - "| THD attention + FP8 + Cuda Graphs with TE | - | - | " + "| HF (baseline) | 42,0 sec | 1 |\n", + "| THD attention with TE | 27,8 sec | 1.51 | \n", + "| THD attention + Cuda Graphs with TE | 16,7 sec | 2.51 | " ] }, { @@ -380,34 +442,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "aecee0e1", "metadata": {}, "outputs": [], "source": [ - "#Restart the notebook (to flush the GPU 
memory)\n", - "from utils import restart_jupyter_notebook\n", - "restart_jupyter_notebook()\n", - "\n", "from utils import *\n", "import transformer_engine.pytorch as te\n", "\n", - "hyperparams.model_name = \"\"\n", + "hyperparams.model_name = \"../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", "hyperparams.fuse_qkv_params = True\n", - "hyperparams.qkv_format = \"thd\"\n", "\n", "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", "# Calibration\n", - "with te.fp8_autocast(enabled=False, calibrating=True):\n", + "with te.fp8_autocast(enabled=False, calibrating=True), \\\n", + " torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n", " model.train()\n", - " run_forward_pass(model, num_iters=100)\n", + " run_forward_pass(model, hyperparams, num_iters=512)\n", "\n", "# Compute scale_fwd with enabled fp8 autocast\n", - "with te.fp8_autocast(enabled=True):\n", + "with te.fp8_autocast(enabled=True), \\\n", + " torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n", " run_forward_pass(model, hyperparams, 10)\n", "\n", - "torch.save(model.state_dict(), 'model_calibrated_weights.pth') " + "# Some parameters are in pointing to the same tensors, we do not want to double save them.\n", + "dict_to_save = {k: v for k, v in model.state_dict().items() \\\n", + " if (\"_context_phase\" not in k and \"_generation_phase\" not in k)}\n", + "torch.save(dict_to_save, '/root/model_calibrated_weights.pth') " ] }, { @@ -422,37 +484,56 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "a913f54d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "self.config.fp8 = True\n", + "Another string ... \n", + "====================================================================================================\n", + "I love a good list.\n", + "\n", + "I love a good list of things to do, a good list of things to buy, a good list of things to read, a good list of things to watch.\n", + "\n", + "I love a good list of things to do in a city.\n", + "\n", + "I love a good list of things to do in a city that I’ve never been to before.\n", + "\n", + "I love a good list of things to do in a city that I’ve never been to before that I\n", + "====================================================================================================\n", + "Benchmarking for batch_size=64 and total tokens = 1024\n", + "self.config.fp8 = True\n", + "Benchmark with context_length=128 and max_new_tokens=896 took 19161.548828125 ms.\n", + "Peak GPU memory usage: 63.82 GB\n" + ] + } + ], "source": [ "#Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", - "#restart_jupyter_notebook()\n", + "restart_jupyter_notebook()\n", "\n", "from utils import *\n", - "hyperparams.model_name = \"../../../../gemma-weights\"\n", - "hyperparams.fuse_qkv_params = True\n", + "\n", + "hyperparams.model_name = \"../gemma-weights\"\n", "hyperparams.qkv_format = \"thd\"\n", "\n", "hyperparams.fp8 = True\n", "# We load calibrated fp8 weights directly from the file.\n", - "hyperparams.fp8_model_weights_filename = \"model_fp8_state_dict.pth\"\n", + "hyperparams.fp8_model_weights_filename = \"/root/model_calibrated_weights.pth\"\n", "\n", + "hyperparams.generation_cuda_graphs = True\n", "hyperparams.cuda_graphs_static_batch_size = 64\n", "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", - "hyperparams.cuda_graphs_static_max_context_len=128\n", + 
"hyperparams.cuda_graphs_static_max_context_len = 128\n", "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", "print_sample_of_generated_texts(model)\n", - "benchmark_generation(model, batch_size=64, context_len=128, max_new_tokens=1024)\n", - "\n", - "hyperparams.cuda_graphs_static_max_seq_len = 128\n", - "hyperparams.cuda_graphs_static_max_context_len=256\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", - "\n", - "benchmark_generation(model, batch_size=64, context_len=256, max_new_tokens=128)" + "benchmark_generation(model, measure_memory=True)" ] }, { @@ -460,13 +541,7 @@ "id": "8cdbb56c", "metadata": {}, "source": [ - "We add the speedups to the table:\n", - "\n", - "| Models | max_input_len=64, max_new_tokens=1000 | max_input_len=128, max_new_tokens=128 | \n", - "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", - "| HF (baseline) | - | - |\n", - "| THD attention with TE | - | - | \n", - "| THD attention + FP8 with TE | - | - | " + "We see that speedup is smaller than without fp8. It is because ... " ] }, { @@ -496,41 +571,58 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "96264b9c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "self.config.fp8 = True\n", + "Another string ... \n", + "====================================================================================================\n", + "I love a good list.\n", + "\n", + "I love a good list of things to do, a good list of things to buy, a good list of things to read, a good list of things to watch.\n", + "\n", + "I love a good list of things to do in a city.\n", + "\n", + "I love a good list of things to do in a city that I’ve never been to before.\n", + "\n", + "I love a good list of things to do in a city that I’ve never been to before that I\n", + "====================================================================================================\n", + "Benchmarking for batch_size=64 and total tokens = 1024\n", + "self.config.fp8 = True\n", + "Benchmark with context_length=128 and max_new_tokens=896 took 11993.3818359375 ms.\n", + "Peak GPU memory usage: 56.60 GB\n" + ] + } + ], "source": [ "#Restart the notebook (to flush the GPU memory)\n", "from utils import restart_jupyter_notebook\n", "restart_jupyter_notebook()\n", "\n", + "# Import necessary packages and methods\n", "from utils import *\n", "\n", - "hyperparams.model_name = \"../../../../gemma-weights\"\n", - "hyperparams.fuse_qkv_params = True\n", + "hyperparams.model_name = \"../gemma-weights\"\n", + "hyperparams.fuse_qkv_params = True # Needed for fp8_model_init().\n", "hyperparams.qkv_format = \"thd\"\n", "\n", - "hyperparams.generation_cuda_graphs = True\n", - "hyperparams.cuda_graphs_static_batch_size = 64\n", - "hyperparams.cuda_graphs_static_max_context_len=128\n", - "hyperparams.cuda_graphs_static_max_context_len=1024\n", - "\n", "hyperparams.fp8 = True\n", - "hyperparams.fp8_model_weights_filename = \"model_fp8_state_dict.pth\"\n", - "# It impacts the behaviour of the load_te_model() function in te_gemma_loading_weights.py file.\n", - "hyperparams.fp8_model_init = True \n", + "hyperparams.fp8_model_init = True # This will result in storing only fp8 weights.\n", + "hyperparams.fp8_model_weights_filename = \"/root/model_calibrated_weights.pth\"\n", "\n", + "hyperparams.generation_cuda_graphs = True\n", + "hyperparams.cuda_graphs_static_batch_size 
= 64\n", + "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", + "hyperparams.cuda_graphs_static_max_context_len = 128\n", "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", "print_sample_of_generated_texts(model)\n", - "benchmark_generation(model, batch_size=64, context_len=128, max_new_tokens=1024)\n", - "\n", - "hyperparams.cuda_graphs_static_max_seq_len = 128\n", - "hyperparams.cuda_graphs_static_max_context_len=256\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", - "\n", - "benchmark_generation(model, batch_size=64, context_len=256, max_new_tokens=128)" + "benchmark_generation(model)" ] }, { @@ -538,6 +630,15 @@ "id": "3e30ca5a", "metadata": {}, "source": [ + "We finally obtained the **??%** speedup.\n", + "\n", + "| Models | Time | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 42,0 sec | 1 |\n", + "| THD attention with TE | 27,8 sec | 1.51 | \n", + "| THD attention + Cuda Graphs with TE | 16,7 sec | 2.51 |\n", + "| THD attention + FP8 with TE + fp8_model_init() | 12,0 sec | 3.50 | \n", + "\n", "Total memory usage dropped by the **a%**! We can use it to increase batch size to obtain even larger speedup." ] }, @@ -549,6 +650,17 @@ "## Conclusions" ] }, + { + "cell_type": "markdown", + "id": "824129be", + "metadata": {}, + "source": [ + "
\n", + "\n", + "\"\"\n", + "
" + ] + }, { "cell_type": "markdown", "id": "7bb2452d", diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index b9ce4b78b3..810f4c6484 100644 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -5,6 +5,8 @@ import time import sys import IPython +import random +import string from te_gemma_loading_weights import load_te_model @@ -227,33 +229,44 @@ def run_forward_pass(model, hyperparams, num_iters): def print_sample_of_generated_texts(model): tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) - inputs = tokenizer(["Another string ... ", "I "] * 32, return_tensors="pt", padding=True) - + inputs = tokenizer(["Tell me something about GPUs:", "Tell me something about NVIDIA:"] * 32, return_tensors="pt", padding=True) max_length = inputs['input_ids'].size(1) new_length = ((max_length + 63) // 64) * 128 inputs['input_ids'] = torch.nn.functional.pad(inputs['input_ids'], (new_length - max_length, 0), value=tokenizer.pad_token_id) inputs['attention_mask'] = torch.nn.functional.pad(inputs['attention_mask'], (new_length - max_length, 0), value=0) - inputs['input_ids'] = inputs['input_ids'].cuda() inputs['attention_mask'] = inputs['attention_mask'].cuda() - outputs = model.generate(**inputs, max_new_tokens=100) + outputs = model.generate(**inputs, max_new_tokens=50) generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) - for text in generated_texts[:2]: - print(text) - print("=" * 100) + + print("=" * 30 + " Generation example 1 " + "=" * 30) + print(generated_texts[0]) + print("=" * 30 + " Generation example 2 " + "=" * 30) + print(generated_texts[1]) +def _generate_random_words(num_words, max_word_length): + words = [] + for _ in range(num_words): + word_length = random.randint(1, max_word_length) + word = ''.join(random.choices(string.ascii_lowercase, k=word_length)) + words.append(word) + return words -def benchmark_generation(model): +def benchmark_generation(model, measure_memory=False): batch_size = 64 context_length = 128 max_new_tokens = 1024 - 128 - print(f"Benchmarking for batch_size={batch_size} and total tokens = {context_length + max_new_tokens}") + print("=" * 30 + " Benchmarking " + "=" * 30) + print(f"Benchmarking for batch_size = {batch_size} and max total tokens = {context_length + max_new_tokens}") + + input_str = _generate_random_words(batch_size, context_length) + tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name) - inputs = tokenizer(["a" * context_length] * batch_size, return_tensors="pt", padding=True) + inputs = tokenizer(input_str, return_tensors="pt", padding=True) start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) @@ -267,5 +280,6 @@ def benchmark_generation(model): torch.cuda.synchronize() end.record() - print(f"Benchmark with context_length={context_length} and max_new_tokens={max_new_tokens} took {start.elapsed_time(end)} ms.") - print(f"Peak GPU memory usage: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB") + print(f"Time: {start.elapsed_time(end)/1000:.2f} s.") + if measure_memory: + print(f"Peak GPU memory usage: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB") From 967be16e83aaa8d889fa0f375dcaf360955d1fad Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 15:46:41 -0700 Subject: [PATCH 140/244] Minor fixes Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git 
a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 54de549ce2..937cd98780 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -252,8 +252,8 @@ def generate( ): self.eval() assert self.config.qkv_format == "thd", "Generation using other qkv_layouts than thd is not provided in this tutorial" - print(f"self.config.fp8 = {self.config.fp8}") - with autocast(dtype=torch.bfloat16, cache_enabled=False): + with autocast(dtype=torch.bfloat16, cache_enabled=False), \ + te.pytorch.fp8_autocast(enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None): batch_size, max_input_sequence_len = input_ids.shape[0], self._get_max_input_seq_len(input_ids) lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # [s] input_ids = F.pad(input_ids, (max_input_sequence_len - input_ids.shape[1], 0), "constant", 0) @@ -280,10 +280,10 @@ def generate( inference_params.thd_setup_before_new_input(next_tokens.unsqueeze(1)) output_tokens = [next_tokens] - with te.pytorch.fp8_autocast(enabled=False, fp8_recipe=self.fp8_recipe if self.config.fp8 else None): - for _ in range(max_new_tokens): - next_tokens = self._model_generation_phase(hidden_states) - output_tokens.append(next_tokens.clone()) + + for _ in range(max_new_tokens): + next_tokens = self._model_generation_phase(hidden_states) + output_tokens.append(next_tokens.clone()) result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) return result @@ -355,17 +355,3 @@ def record_graph(self, function, input_tensor): num_warmup_iters=3 ) return graphed_function - - @torch.no_grad() - def generate( - self, - input_ids: Optional[torch.Tensor] = None, - *args, - **kwargs, - ): - assert self.config.cuda_graphs_static_batch_size == input_ids.shape[0], \ - f"Input_ids shape {input_ids.shape} does not match batch_size={self.batch_size} of recorded graphs" - assert self.config.cuda_graphs_static_max_context_len >= input_ids.shape[1], \ - f"Input_ids shape {input_ids.shape} is greater than max_seq_len={self.max_seq_len} of recorded graphs" - - return super().generate(input_ids, *args, **kwargs) From 4bf081b9d03c7acb3eb0c2a7cc98156f587457ed Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 16:15:56 -0700 Subject: [PATCH 141/244] Minor fixes Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/media/graphs-1.png | Bin 0 -> 16100 bytes docs/examples/te_gemma/media/graphs_2.png | Bin 0 -> 15177 bytes 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 docs/examples/te_gemma/media/graphs-1.png create mode 100644 docs/examples/te_gemma/media/graphs_2.png diff --git a/docs/examples/te_gemma/media/graphs-1.png b/docs/examples/te_gemma/media/graphs-1.png new file mode 100644 index 0000000000000000000000000000000000000000..f42b50fe0d7804e638f5e719f90cd381cc565fcb GIT binary patch literal 16100 zcmeHubyQUE*Y6Mtf)XMjEsZotmnaRAO2d$XG)UJVARQ7Sjg)}I(A^yhQbP|V-NMi_ za}VF|@4d0!d)IyMT6evF-1|L?#mqT#&hzZOpU<=7v-XM5P*Zq-M~MdlfgUI+%4&f? 
[git binary patch data for docs/examples/te_gemma/media/graphs-1.png omitted]
literal 0
HcmV?d00001

diff --git a/docs/examples/te_gemma/media/graphs_2.png b/docs/examples/te_gemma/media/graphs_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..35c34ede5559bd0c26ce807789ee6d3fdb2bb062
GIT binary patch
[git binary patch data for docs/examples/te_gemma/media/graphs_2.png omitted]
literal 15177 zcmdVB2UL?=*YE4rt?aEJd#gwjuz*MtDG})^A|kz)P(%nV6d?qpTTw(nT7VEhDWQY} zArLy*%BDjQ0s#^NBApl_B}78I5AXYpao_j5_uO&6Gsbt$IE;~zXC+zBTx+hi=9=sG z{E0U=HQ+lga{SPtLwtt!?pPi=bQpT*&@cMG9RohWA1>Dg-hPGMHnjc?_>20@<26tg z4$-#{u?qAI33CgE9P$bX^oJ)Sl{Mul?bbIgK>0f0s`S0BNv)xGlW5I7{Z}uFyQ*ij1Mfn+} zUwF@7`2BaOUx)C9q4@D!CM0wowV#OBq{cg_dZ%D`=POaz^}|n}ME}720`%@r zUTlrO;qh}8Ig3>fgw`gzGwz7dpZ%wwYY@HqO{S!ueS*z!KJO6U{`=SSs$b$s+K+gE4hs$oIQ)tmonrC;3u31YR{%}wp1Jh-RS#d%X>Ik4j^k)XBM zNBRfq!!^XaSdb3;Pk%*!8ru zqORtlGp4m#QhQL(P+so)PL2mVP3Hs1_*=xiA*2|S>KeE-+#nM#-t$E=e9sVXCh0>m z)Fh^!;U5$xgZ=DLY?+$ zj0&3n*x3|nV9seBD&4-?&04R+v$Wn*Lcuy_qt$bEXRc@}kaT(7LMB`*yFrei<}jN= z$)$3s!F4RODx_Txq>W(S_+!Z3Rc3hhbHD)qH)}mZWGw6@Wzw4JwN*NN?XEWUHX&x+ z^0~{@&}AB8-t-4Uv&2L0Qv2)b)KBrX(YePOU2Q0@;GCYb^xgKTZq=FqFY~bGvG?4w zMrAsBt`5jDDHW?%+LQL%Ht@pYfyq`^cUaSL<*F~3DVKL@KQcPIT))**%7`=6C+wc= zEjq+{UTvNG{;s=ASZ3G31jy_9Y8yRPK*T~vb~7`sD7Fpvm?&^N+~M49P_c3&GmL;0 z&*bSyHxBH>!3RU~&zi7gsi&bLHsXts6xLQ^KT_$ke%c$d2W?1A0yC&?09jUBR~bH6 z!ycKl5y_dtSgP#7?CDqP#b5lQrU#!U8=N8 zj(3{09l=TLX}|Ao_tC}))`@65PjcmRpM@f`OB+2SLrFKuY&D0OBEUz!ceQr(st!oX~d;IhrA-x80PfcNwT9}6Du@u9*% z5QCIpgyLJ8Z%D*8h6E*fa5sYV%@c1eeJ>1_GF&PfN-Pm9Lw9GZ+UeMT0PJA5Y3#)j zclH-3Ji;G_|BMKs{+WGHk+R}FqL`1hFibKSNpv$&SYnRQ1BNTnp6*HC?Jjl~EdI zMqb63>Z?ULw&%=k^z7A_&;c44^V27I)CUW+P795QAjP9=6iemA%PF`cOpX09zv!i( zW5gG$3u=szy@Cmy(@HL&jR@ZEYa!bSd=a}dXSZ77R_fo|OQ}!Z8u?K)F06-pWtDbT zJCto7-b`|8jq-VF>+g#&&R=nMz|_^%l>^2_%@ob`FvUEsWoZx^(rI=yIa8o#L208* zf;pV(utc>JP4yAsc2{2%@QUP)x3)fH+PN7;Kwvqo(T*t|A`z>*XBG0YKV;9Kjx4^o zogedK>*7%b>B_>`dc6d6v5cRgu1?84-RMR~8s=GQQDy99*G>~7a*G_aHpni=Z@j#m zpYi>u1JJqeS=$-LRy>gp#3S~qIPr5B#;a_iW-|17|J1p@=f;(&?5RR6U^e}pJsQy6 zV27U%N#9=njI>V|X^=PclE|e6wY)<kEzMl)E(^y6t3+{oN2Iu;ktPv~x@^ZA3f!duJ2lYK zyVNro^X%xrPur2tkUG~*luejeylT~#Dfeb_M<~s_?Mrj_5|=5}Kt1hk_hz8bpi;(a ztQtPgYzpakcCP!{h)6X%VZF191qz|(%We#WfmN>$V2&iZ?aqHwlI>etESO>qcC3Yp zDJj_~^E&T)XGDp6T}B1pL0@%EE!o$^*b^V)9KHH_Ee&yI(qag#^kjLal%82si!mtC z&3R)xSOJ+=mal#Ki`U74p#?z?Muv$=BfH_r%HB}cWDbqkHM>mR8^nGnf49x*4xd#kGakAhC@|k<&8wEHz5o%&&Ecar*o@h zpgHKAskR7M;b@zsVO;|%cRHwp9+aa>EF-$wT_#@(p}}}U_626rGklk6$+o_8udo_h zAB}5kXpL*+qq1^8mZtZImvNa;__}&~Rh{A+`|Xj<^7n75Th_L7+dfZf7(rgjXcK1Y zg`Ky3$5O9qUT}r=TC%yF{f@!j=m2dkE7~Rx8*3n&%Md{at2A@F8`^!nYd2m_F)6Ef z)ZE}E=J%sIARhnRU0O@ZcEfbdZ7~%P_sHKj*E3HN%y?B|1tmt0rS1QDnrHRYgF0X; z1!Yb3G3p1PYIEnEls7CSk{#AA82F?1m*S z+c}g6@M3Cc_}uElB`UevI{FBy5*VV!8)=o*t*>LF7WdF|vt$I`I6vE3sSPF!Qv+n= zM&3FRG>_5DNFkV+%)2sW&2l;cp8ov(;*Aak3hA`Kn*CKd91jllwkX;bL zPDW?p);1ZhvJnuD>6nN3-mlVKlMlGvYBrKM-DMxp^!|=Y#|8>zU}8w#7&^1N=lY9w zX#Zowq2**`Ej-qo5t?p*vvZIbLi#H-vYS)yAz+tz_jgYXEsj3=l&-p^6?a@&`~Ir1 zulMhA(k%(pE2KE$Of`2;d4%8Y`=}vI8Q!U1-1z;)M3$&X zbz|t;r;hv!utkm37P+=!(&&9`jU-_1`i=Pd86jlR?r8j1)}gKVA)iCf2~6bSnQ0kEUH4 zo_U~rZWZTN(+*h6bD4S%OInGvm34sLupak8TSlXMpeUYf9-I(nTTuejIN!>Xig;hl z0-}`~%C?lIC_nN|cu-ec0E@_d=v zzED=vnr^eBTrVPKz1?zmb|N)v=J?CPfsDNlEz4t#W^eE8RY~sDU%$VlLe%U)J=yAS zPBL2o{O5}&d|&k2oz8f&`726H@yi)os~mwPktR$zKQOD)Ax4<%Aog$H3y*uXQT$+& zUe~hj&YEGM@J|I}tmh(WjKq*WE2ACR?#<;-c860wr;pb3zP&yf_tA0}9VM0o;!xIEcN=DTd$9$Gj zR35iTf%%F*j1-Wo3-NFoe#Dpg_}Z8={i;fjriCb2p|+ORb5CjHi?2^Vfg;X>A|bV! 
zS}&~ehP>y}Dm&&PZAQMZ->%HIDYNGhp1Wbt1j!N~Y$>09bVF0QPElWy7im9+Sl0Qt z&iWXHX6hN}52c`1^|pInlyl$t(R>PWs3?(9-fs@;H#zZl<`ff2I!a?ZuQWQdr{c4L zKs1;;ntLi2R+rptpTZJ!xG`QSmpQwLe6=YfM400UHF{eh&}YMILJgY8<)Ka->cH&M zcCwu;kS$IAZ9Vc*gW^t$h=k(v9Kn@@P_SCA&H1)aguzNLA)*3frSE)CG?cJt;3c@R z1ul7MWyz)>F6YJ+tv9atF9*88gDq+s_ z`4)epSO;f}Yo06WhsfO(q5&WF*65?@3uQc&`fq{5&;Hf^LAoMO$flSk1R1wBWU~9% z!ELnlOHfXMv7W1aRxyM)UR)qwiAFpg89JL(!7{5ny76i9I;13e= zkYW*E{}YCU;zsnls|_ViAQLy|h^VBF1!Q?e?ZXwjet!h1Ioiv^Now4M9_CQi+8UJ~ z6JnYIBx$J-qgR>5%fPm&9_YPt_3qv-^RyzrT5O&8=&|fr$b0_>Fs=KfK9)ruFnwP1 zLF;Dyem(pPIMq_O*{m#wl_0WV?zHtv*11*{ye^$c666mk^BeUa7+uaYdUMw!8X8u! zG?~F}WDTV_xN6ms>Mc|6*2J6wohu(Ke}EM{yo;|!t8zT*}k2m#BjRQ_&mI1dUk@@=lJ+Vtixn##XAqc?%$>g zsV7Aj2H$|-8j*XyMO~^ENL86n={s-wo#gfOl6I)?4aq`Ve#2@5_%a zt1F5#QR@`@EB?6 zWNy}I8yN&PG5a=fk2m##&)%<6^m+rQ1x~A^+sXVck;_5ak}_@5{qnX$NxI|fX2|G^ zhA>m=R&G&v*h3e7vm%ERbA73wS47SoyPU4lyT#9j9bHbbpOui9J?ADT z+kg7Kq>H%TeJ6?M`5Dw>^lkUEn)@X&zZ?7~JTQr$|2Ri)vYbRF;xv!Gb}`Psy#*5x zX|g8$?itUO`r}_nQ?g_ikOd#m7`0VfTieSKyujt_rz-#8JM-+v5qkmdH?)+A=E*;o zD=x8NJpb0LzVMJ%q&QIH(zAcXRB!(~3VmUd;RYC6?EQ@!2YBUWc3N86q7ra<{usWx z+A=kSkItA;S&ECCDVa3fLQ7MeJ z(xuOTW$C_N}9^dXhzjeY?%c{SVE@e18 zaP;4-s4v{Z&(k^OJ5Bj%#;zATpWj+z#h8AIuDrWvwtFh!U!htKY{i$deBDr-!^?YQ z0ZW_RdBtrw{4ayUUpGI&RxT9?wGmYRBZG~wt$>c$A&R6UuEaxL=EkN~#0*K-#2GsA z8#>|(owpeZxf*IxgL=Y9$pKxT&3MUvb)SW>)OmV*xV9FMjk#?KP0T=%>@jfvaRs}?$4i1=5j{sYL-S7L3gg!z48YUs1ZxtZ{nkPzon+w z%0Z6V0K&LZcKI`R9M7f(otJETXR4`XlBq~KxA}N<<+tTYI}!Ml)K>If=w>btkiGUU zEIAT(-W-OMv~Rt-@HT0i0cyC5iR0jW(=Qxr-p52)D~r*WWu0d}`8 zpN(MCPOKekT@%_eIyb7p4s#Oy=%TJdYfqSM z{pwS5kSxcBGX1H?6{fL7wDsM(T@qnrX6>)6)TPAn+N;PFa>0_be&j=w!k#1116*xl zm4ix0k+)?K(htm?cL!-Klp61P#>4@~&DO@|`|j&Bc+UjxYRtVE5H?B4K)BLvG3mq! zP@(E4vHHQPnK03rQp&ZSU3g~Ng?S=N+3&(9R1byM#Mkd@nVQv0lJz9qlZPe}k_DFQ)_2q(@;B>^Y9o*f1DyG1c z*i?E?rSoS^)8(RGiHe~bB_EN@QTU?J;BF7EoOXMmbfuVPmq6(0OR-5%`YYN%kpKIL zxxSqhPt}8T`B0nKRybV-AWqu&%?-_8oOjBc<{;`N@tLGenJ*7ZrM?dWgkyK3%(wU_ zuR?Y>zM4~_X<^ag3>?l%A#zxs#n-hSE zVc%|9Mk?%ntvIZpC8@A{NC6zNf%T+^2gTzr(6#ESY%;HsQjGI;aKDu+UCE8M#!1R( z+cd6jTpUqVPn+-IkxO4zse#8fiMK>lh(};eEBym($5Kr-FIoW=2dQFM?i?N@+r3)a zs8mrV6%);Rd$^ky(`v(|xE`U_sFh~B?{M40+m#cOo+Re|Quop?dYYu&IET8}FPvYX zIElAAxJ~>k&&8w#=tc+PYu!>?5+nLw)c06-QuIV(diyfFamt+8e2E`GE~+h(-+j%H7=}E71TMt6jV~MtF!&-1j|dAQUGv+Y2pMeVP7am?2@bk8V!=W!5cq5N(&c89kD1<^geLu?>xP z_9H3{`FBfPPL$~D?rtrkFN;~G77R7FvJ@W2nwgV3(T3=<7C7SaS#IRpFzQ%Ew4)v&CTLFx$SiT&RxkIq0unmr!{P6h zDyE#Y?!I&o2l$FM)qWCp=`#9$+y{mPHg#a8?}b%T*YKRt+D5FG07G!VhVeGsLAIN{ zRMb|@{u|7$_E?Os%q&AIRV2i?;`GTXY`oWDZr`*6cjuX3OLduHosQrax~m8|61FXv z?c_|;Z3L;>O!+caj5^C(+lPdS)Zzre1q1p8Uy(LPe2a-%o<+r^-D^8Ov;0gcLyb_! 
zxYnTiDFM;kCwY@!Dta?#`O4$k$_(5w#SiRdm#$~xRkTR;ngMC^YHSXp$3rA8Aado& z3h^PT+R`L?d+wNM{FQ19VDmY;NlR2_;fp1gtFx zb}>Yl{TSG*tm0a6XXrV#LZ~YrcnBgr3k`56ciFFZFgn>kjD6+hOWS+j%Dz_&b8fQ^ zh_lmGue0aFB3gem(KTWacr9uS=)+$n^Y5|v*e11L-Gei%JNjaV-HqZKY zaDdl+`0c#Q;qM5QbPY`HKXGO;?~+d^n?;KRDnVP&x8(A!VV z8kR^;1PA)QwCv+7c4dtKIGYg~NwBj08iK}mmULCmhLBRvn!t{$G<20LcbNC}?bS4z zj-nQ4o0YJa@|g&epkIZP*UKRjB{=^ zb4mWtVQW+?HX}>glx<(o@{WxNjhj>m92sjiHj&vh$ZI$0;*>QiofGc$7fko=$jVDw z2e1Xw1z9k)FZaF^C#-1bxNCtYtnmBNlh(S}vJ5>(JXXIh&+b%rFzLlK(e<{-uE^D_ z)XVvu=C{xWGCF#E3T|^n9MnYW##jzq1LlH&g<^6ZQ$AuGV0pL9WY>Hap5heM&w2B` zY8!9GNu1bP+l>_PJ}&1m9d1X0v-!mabVAkZ!0L<}Tl0<6}uYrUu8M5OswtB5xr1V&94)doJ?7}HUj*r7)_ zi5l~|6+bLHxclyn)hL0(F zrM}cUGxwb+Oi#~tG)d`?11nfMWf~HlBS?O9T>QAEoV?r*AFz!%f_Wwk;2j}s!!2K5 zJa;0Lxp`K*1$V2?*}VLKmelVomK1J_vThOz)xEn_QEE=10Hcf24rV|8$BToF2N-4d z0_b98plDpgIperXlc2YRdcYc)AsRv&j2^fz1?77stJ!Z|K42Z5xPE^9F9CP>{+zqb z=mh8H(5adYIuNbm$j9N^FN+K%(r*(IGj0HXlD zdZq)q1zQ0UhUQf{O3TO1$Q(c)xzx68Zbe{h3mG!18V4D-w8|--e=rqK&lUCkxxqaqrm3JE ze5|t01M@+(#xcY(G=vZkNpQr9@gHl#&>(lirF{YW0MihbL5nylV80CLvZiRx`+3T2S#tos`m@;=yt0 zV>$cg3pds^m4Kk6-MHRpRRCMrOc9h@+r4-~dK*5Mdd_+Mi_XzD!j8C}phC~k@Wo`{ ziBXv2ND{+kJ6kJyfAv0)IrZmcl{cDqM)-jP1N|vhP{s%)fPuyz_ZTCsbw!WcCd1+7 zr6+vaNN0B!zdzH7Wz;;nUp`44Dk3L`xQwx4KIpxP=FTsyU2L**tW)qYDqs;z-7RAYg!pWfa@Bu0Om`zMq>zh22(!{?V|(4D=d z-2+UY+;`zj1S;|7`3FXp%da>x@bzem=i+bV_RWhi2v6u<(csp%PYM~#M&_0y9cDQA zx+0L7be-N*?ZR=^P-Y~QP=xL7551K4merNgQN*`;JLf}3p zza0_Stq+5=NG6VPk6p}OmspB^sJbqZ2t4+vQmSl{q&2UoqEJn_4kwa-DE5L=1yU;p zx1Noa=?}ThXgOA)O}z#ro(A3_+z#&0N)`S)p1vjjsHoGF@0G{!@&&`q(ts)Tb;}Uu zrC7oZ6K7!6!$45?vvN14&uoS1WOuDTc=zn|fVFqu z!-f4}#p(Zq3c4aG=LpS_c0FMHAiU(zV8W8sl~~+u_skykPx2PTUkk+J_4ih9kyiC5 zS5tmH$cO=U@%8d#PfyQ(2vZ*Y6Sdr3pqTf%fo5v(zpo%n@TdI; z6H{rW?&$wxfHJ>9%z*|8{*T9L1_Ow>06|(;)2`%yvC7WZt(;<_P|DF3#ZE^Z%R0@m z;JO{U!pyzlDY^!3e{`SJEKVn;1wPz2-Mh5TdG4$@I0M_BdexMdA|#>s{P6#`TU_aY z-@ZPa?9Wk?{x_J}5%~`Z-T$f){y)G@bB4p$6j(ynMn*qKlD&t z2}o7@I(kHK-$6L8hUCspk_1B0xY2=w1U`Ay-6iwXBESrgygmP>@KZhfoJV zZ`RkR$dMny(rpkyUlq)Xe)~0MNgNRd8GiwQot)=PdXL28X{#_fvn!xLkld@GW0tbUC*g#TgZ>d$3#jC&1 za}I=vfE-Z>i`rI&inG2TYd2t5hdQkG4PZpBfMM#uf>mT>HJeMGiW$6X61}PABkDDL*KqG z2lH#ZoPXVIjH4tpAirH_8u{&SH>|tl@*U?0k^9bf3r>ED3&nD&IT~C1ot(%PB)>2F zlSExi+!%%W>5pzxo%ck$XU=R% zOx&9@q3m7)Iu$^=6z18_NTFke*SB}Ik$`6^A6ye^*y)k0yhCbD95$+?3%lsWjqf+g z?SYG(_o;<5F&buj=ZLbC6+=|lR?KOZj_q=c3Rqi_e|FDA4_(LF6-BEM2FK=OI^B77 z7gB^&;7g#q3&Rk}4(|xQrSlxRt>|zgZHoLkC&`?YE)QZt2bWnmo0Vbt;>;qrbI=v| zkC3;8VPeOw>7If}cHY_JQL}VE&)7?WMe%a>_^B0GSE5S{(G~ACuoB+VFM&~4G$?Ct zCp~jfjO0eOif2ZuK?(kt+FH~O%icZP z?)0-!jBCYbi@L7yF~hNztbMnHXM9YJHO8y*#n(~7;+K?hD-RP_YXh-^C`B{7$-Z+V z{?IR6&B@W!DY3#4HJr_6k?}~BIflT+u?>3VVNQ)6Lfo&(;F?(fTI2k(m%2Bce8gzl z=DF_ZtE-rVeZesA(Px)zBEF^Sz(3wbT;GeqpY*Fp=?wIxv|HIUe!Mb2hLm%x+00~> zVRcHIm(-Pcy&|9UF{zJ6FN1)kEGZVTq?#drWw=#s6SI+Zajw5BMG}$$S?q%HXZRa| zBq)5PCozFem~$)SZSjRpy#e?$H)x?Ig<~A+!OD?H~WmY{QdDp9-Y>EOu92!q3#Q z`ej=p2Rs$HgIXn#tolSB5dZ;K7qP>l-n+*>E($JzFj}_NFyXNnwUfgIQtITL; z`CD{ki^v)*@;X$^ePVMqz~PpWwCK&zCWMDdfkHw<4jSWuornd!-WzBbFoRidAjKHO~+0v)}(}diu-Ar)x#u z!jB?1##fI=S(kuxr%KZ#Zn&V`N?T-#%y#{rQblh^tX7Yz?(XNv7t$U#B>a{=Uw2Bt0jcS!si%C(BlbRwI%KDR z0zs8VX(l6aB4TWMo0OsZ*|RdlLeXS&kQv@Gc&8^9pRx8q*FC^FIEt<&gJmtT;3$fLj%D&-nT6x%B=OgP=LR{lU;2p+Lq%H%!oa<|% z-eRm=RP)(7_wi_zvQ|}2$`RI4RhB3z4Y%Xq%o>-|&~nW|1P+Q~WF={~;9M0?S~Vo> z%)zmr7m5M8MqNQ03iP_!>3zdY2 zH;Br2ozD;(Hzo%nD#+i1&{w2JSxze@XYZ!-{}H`07CMdy`QC5f`-u;BsaK@~jrmY` za^`I2a?y6ORky{<>SxnB4552slHDgAsuo;yA#Xq(zeHusJfTb(=r)7}TkTa*2En)2 
zC7xuuL>S81;lATh6$Y1~`?(ySeJI9#FBfItZQ8k`@C}k%OF;yz3kWg}-Ja>budwPwBqK7fx>WE%#3GA1@!bmNx7;mxCkRG&d1YNGEChVgVKKdzN&jA^{G%zv22*{r57?(@#$1ggdW;pA7HQltGir_2~++~WFR;J4 zqHC?=2g;+qDghFS4A?SC1r}_mzq<)|c2_9Uwesz`uN{h4NA}nw2isigx=sS3+*YQ? zDE>&994-oAA7T{;9$`+0gT8mERt{C#ew(WlQ&q|{wIFT z#W&FTW`EA2cib!V?i=qNAvI*$C+m4=D?3E)ug5K4k3F;A+8zF4KE!dqp9)N%W?t%e zI=gj-l&NS{xre^4sBL;WGnXFp{L?G4Ge;%R_6FB3qOZZXk+RP{620+v=$3Lbl9Q-j z{Tbou9Q@j4n*v@#N2I$TFUR7{l#AOcvy0UNT@+&xW|OMP#`!XGem{Rs_wdOErmjDSe99N=@5UbTk- z<#hCJl`^AMfdJ_xe9vkCz*AV@;QW9*Gf1RvojWX*1YhN79~s@^#OI5ZBy#*HM4XdC zTe@d#c%yOtpD;f>QOkr85bH|Ir!%A0%4b^n)iK6_Q(RnVYmI@v!P&0NX`4t`Qor1lcIV%X|^t42*4vhl&_JzRmk|{RJA^N63pzRW{v5%Bj zwv-(#(!kxQ=bkRt^pC>U8gHA%=N`n0smPVT9SsWBD(#&?A5oEpW@dKI&d&MG-4QRU zO<6SFwDl{q=ZYOf14t?PIn-)MLMUqzOYD!jXAr1)VGU-L+iMp6<4^w^)Ar144In94 z#si{FK0I1qypK%Wa`w*xnY@Z+`Wgn-K)=DdYvNMyYKSmy}@2%F0+%ZQ8Cei+dok)QG4B$Eode+{ibl z`Uz=~J4z4@wo6l}|LpFe?wZ{EvBoP6gkFN5!-rWBn0$9VUpFF_!?`Zqo24nGPZ3jO z1&)2Z48VYiy>3;{MNiJv?_Vo2xpa0{%>F?KqCM12p;C#riTS!V| zl*u_YLm7Upis*dD(l0@scMvI@F651X)S7|tqwogSmxkYr)}K zd4G1R-^Wwdt(|9i?J?_tv&P`gwYtd&qv=QX`QqQQYYktfu_qDi9%Pm}iuu=XWmDs& zZ3**63W?eXtCF8@hClenCeQuTy#CtNN8)yljwrJyX`HS%vlHK=x}z6(cDlkn;My@7 zQ$^-BHpNsb6)*{Obwx6lp5FkgbHuAOkt#qM{F8iqBm+;)fr`0>1t@2Q9_?~WY%#l* zvA>|nU2798Bl((BeopV3tHt5m!qS?5A>xR$nnU@OX+>Sq_)@E#jnXy=6#w^T;9uu_ zeSAKgEWF~lVq;^YY?3x%a4Wevzr`?h7-*7=YIrhATwQS=iLaOm0z~Jg$*BFGHinkA zwg_J+)IidUWhQab`~D_Rk6B>Gz(U%{YF;e>gbMprO3JZ*sCD!}62HLk=xDEMe^)l4 zyj<4H#6NnVXlr+`tK`n}(;b`j6O^EaD;RrzryZ}|<@S6Z+BU5QIz_tR|65#%U*8=u z3oX}C*WkkA$|o+daTrXRo6xX~DR=k-Em9K<2A8#6iO-1i_s911_QIFTN>aCfTTbJ| zEEwN0GD>K6QCyFDvC-Wyt}a8~Qc0cFQs&N1d_(n>mp*DW@|O(0{(U+sLmMC|cQeA) z=llPIe}3)z2fbtO9<>F&sz3JM4)gwBIN{hnKUs{Ia{7`{4bX-|hWe&=aJL^k`@aD7 CQjcT+ literal 0 HcmV?d00001 From a541e6304485e82a237fa229b9acc6d8f18b61b6 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 16:44:22 -0700 Subject: [PATCH 142/244] Minor fixes Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/media/graphs_2.png | Bin 15177 -> 40595 bytes .../tutorial_generation_gemma_with_te.ipynb | 251 ++++++++++-------- 2 files changed, 133 insertions(+), 118 deletions(-) diff --git a/docs/examples/te_gemma/media/graphs_2.png b/docs/examples/te_gemma/media/graphs_2.png index 35c34ede5559bd0c26ce807789ee6d3fdb2bb062..6f6e0b16732a4e07afbe8eb553f0a24faeb07f14 100644 GIT binary patch literal 40595 zcmeFZbySyY)GdmI1u7;8f(RChibzO`AczWxbfZXjON%HNgi_L=pmZqRC<;>2AT81$ zUD9X1xcB#+Z;U(cIRD)H-E+tBkIfc;@xITqp0(zhbFTNeD{Mc-MRaAN#{@8E3{}g|xvbZL1A!V#(VSU$3lSJFd*ie(% zT*FLL)5u)M*kXKBAs-3JQ4&!h0qF-JzdLM9_78-vO;%S97LbT;Gkq#})Ra#6;Rg?9 z`!6RxkWaRX|N1fc%ADhwlqh>}YtQSd?DSl}`y8isQUKbJ zuH^6E#cg=PYQ+ziP=;4)tFEagTl-46d@z!J#F5#S7{2mbdw6j0C1YQ2uMWO6^3{H_ zt72jv-$O1({{C$s7Rr@gx|G@Fn9HbJ?t7XcH0`dGf5MEyKB7p?X?#c)x7iHv{EM<44>Y(aU;!Ux?f;? zt!UHkL04*{Z5q#AjXpo#-apH4Gu108PF|8DxxLeqjz8)B`xDoMh0FWQVk081%(t6+ z-Me>BnBg*I#BToO+e>q!QU(SFH9to+=#~N)%d}EGEb2awwP%~Re$a6H9?pM0y!s^8 zC?n#Acm0sGxp`__e7wx<+YysJr6Zko=eG-S2P%#kkByJk(;hl`!>S4WL8@E zv6OM?yJQWnkUZ7o1!?sZzULB~&#UCz?>xgd_rdElGqc#<1;?0;;djVOdl=&5;%?A-oxFd1$o4eD=fupOx@MLm;;dSg+b(-1Wn{=27c9JI)b0NK zMDfGDYc-=KUUd8hL)7ug=Uy*Zo0(k{XuBZ)@qXFE@V8s^EocqdgFVbBR@i@Ce)sX? 
zokQo8UotYQ6(}Plxa|L=>^i~~6>OCA?6aZ!v~7O>LT=cBLx=Q7ni9>jl|QzP<}8in zTGoqkoMvK*6NwL~9`Gscsa;)JO4qKY%~n3QGMnrqV{4mpk(E_s=nzR{bWVvjn$(i2JwIa3#Ka_xEgl}$!@^Ycf5xHl z)v@t3i=jKBb@PR54x(~OUuf+o1g2T%M%$u^W822QWy=;W$K|Zyk&!;LY~PnJBLZ2q zhntj)M=gcFr)pOZe|yD{W77R7Cnv`j3yVb=c^_i3T~d3iiIPs3{~g=-uS zQ%z>)^H&!~`p516bi1Wn4As3%ytm6u@YSnVIV*E*e)w#T&If|{ji#n%KB+NQiuh(6 zEaU#-$B#Ex1+WNWx5~eL+m_SSUllB)s2Dd?A3d^D?a!m2prADCaaBKm|G4~o6;o5w zAU=m2GjsDQBQE60)?!fsneF8Cn4vHYhiam`nCbLsNn~JzGxEYE(_cbFw)*<~z;jQ} zu@7Acs@Zn)YEkd72fwy!YH7tqMG4v1*u=%gzCt>%ut+1~s4-qnAIXv4sZ_mUo9JTq zZl_b>)zNLng(}(E*{zwz;hwa-asB;Tef|A+ZkxQbfaVXbHebUZ1LG!~PgD=;fcWrr^ z>HPWT%0SlH1}TXl{A^9t_CnCH8Ma>H#qT;Y=x!`b^w6!zIYwfB1b9>G&iq~#Vr#g&RtNDOiF+ECr zMh1IhbMpx5Cz0OxmRhQr&Yd&5Qu-cSc)sAzTYSHP)r)8D?$UO4xj1PSW9^rZa9La* zX-*!my;jt165vy4VQD$kD5oZ4YWe|1Y}wH%bIfCObhI^1*XNva*3snVwoyu)qIi5H za-3qC&MUszSc$+VZzZpgI{J4QadhoI%j1!@*@`NAE+2`|u;t@5?e8Hh^|}^<3Gwmu zu~N}9U(Y+T3>I~8blFQS&5v``8>hDPT5wq|OldkGq1g7m5@~2^dhhA!DX^t(VPSy) z1B@Le#pH?^1`36WVhqQQrRg;?i3PKZVTb8mcwbk%<6jw)KnlO;BzXURRJ&28qM{=5 z((wcP_vhv;+PUxzjSLRnR#FPa3ePNct*KU3RrQ-*n$i#zeO55@?!2fsyHR^g zNr|vIb#T>tAP6L(Ve{~%kdGff3NLd<#rt(1;kH6(SyD^rxF6@S9t)Y8o}T^2;ANrz zkn*WX?`yyzT!s_!_NC+vZpClr=B(H*U5ZIdlg0sRVDRFpi~OC^Hrg?@vNT?pOr&Kc zs{6_V$d9clcLi&bv;|KyM8#`(A_POkqt{b+r|nIg6UV|rrQ`ME^CPGwL6BMja};(NWmaH!7m{LP zt`nD{;l=wtDk{ooX={H6Zq%?nOY-%3HTKek8*&v(FAklL!|T)*%H#BNr zH`LX=1#$sO^lI&;cOF3fRwP0ZTZDp46>->AkuR&m2=I`|_ox@PL+@TUATy3@>>XCT zQ2r8^CnGC6T=PcIcBWROij<=#WJ_51?G3`XZB(YaR@JYgVq#8R4i#r(XOBY$*}99S z221@G<;|)s6Ng^EA_Dt12z1pyJ3hbEh$|v;FamhK%c)Soh}5Va(bn3kh>IMV4qvmc zS?2Jyt-CTM%=sj;(j;5CHP0rsyrQD2$2haC5ii984Fl!c{(5#YN9@7Dv*g~kmBNRU zE*hSiXK`m{IGlInvYC)krV{F*R^MkpiIXQq2tXws&X?HV-><6FGXJ|p;~*E2x~n+4 z5`2C4;7_qPZ{J?U8Lf4vV!dQNrYtJA_4DBjqmH<^I3@s{MEMU_c@ zaLax3OESo9)}1yYNNYqnQ!hQla?KuPc3y^Bpu6r;1BaP|bIM$yh+ zTV0AvPL>lAdQ5Bgn+4sE%(>S$=^X#OA$Ef3WfC(xhbiPezE~?~nFu&7RHzYaYuZzy z+mT}##G>H}TtVcpw{MT-Pk-muGcsyAlbrX~*lDS4lMw}(h~s5s4~3zd5)#e4^X-v{ z)gX?3sc7*b?3Rp*N+=Nl$xbUHH?Cj5Fu!}b5}$(@zD1y>va(w2L}P93oBC)8{pKVU zq_ev3q1*%xGm=PmRTSR*J~A>=5P0t}ZqT1uNsDimm9jYkr7A%o)uQx>^(izW2u?i% zgN8SPZk5pi#(9&jv`)*W=3S+9PRe}Y)%L$$g@z_dI4#zDFYX~YEI+Cp{3z_e~#h9T1HYet8qwl4PS*OT{qaj%O= zA!kvM8XBZK%&U2WcC` zy~D9xWYN`@KqdD+1vig0b{D%HS2@x^^W&Gdk56Pru64zz)fs+%H5`2Hnra;*BX%VV z0QYDSS{oq*(#IBq8{*>ZN*B5w^pGPB4+Det|K=5J$xKg|$+a5wZX0bPh>=(cha0F6 z{a+7JQ^%tOssWhSBYV8V3L~G)+LZ|OJTm2^SzI3APX91tHYRiczth~PgNI()2h8#7>`0V-%;(_Y$ z*Y!ict;Em@(OfqBVj$))@oB4{pPy!iY-wrf88$WjrUXf0VOK}Qb=-#!lH-MI1!MX1oe|*#3ql-!gZ&pg zwQt;be)#ZVi$4>0v937SaZ2-LZ~Q6hy2_sssR`{bU`h1;ef#Po+g{ZcVTa$9(%Rjc zoSf9k%>MJ|f#Hu70W@_)w*ulQhb>D*$;OJ4W)l-bl5+N2ho!l)o=Kg>nW4q@0lw(L zK^;~H;A_Q9qfp%X+;~?qO4)&f2la3q`n5Ort9X@0CfoO&A(CETCer^*V70w4KR|Hr ztgR(J#&N4L+8YH01tOJU!z!aMkYaD9rp#Il*gGEdc(i7lMy*JCnu7q z=|_~NAtT{ki@&w`^*w9Za%AxG{N<^NDcgKpdS%RJ|cP1bNkqIn^o_292n!m%a(p*_N2r5H#PpW#_?fLl$}!{>)~%9R#!0lHZ!ef1HtV&?kI3D$HR{0eQjC*Vm#Vw>#$K zEQj&qw%ZkC8dST{$0B(VfexB10ikwJLn9q+(z~cAo%;7;^e1hpzt4zv=X>3q@sOXe zTAa~MJ=2=*P~qe4eFf)F{=A z+)_T4rKa<>?hq74N)6We#< z)?jT;ct}VB2)ChMUtgeSpL$pq5E^Po+G1&G*&iS=q<6cir6u;+zLR?3Ue^y_MTMQc z>!ii^EYHtn1JgtM!v6?mnx4n01(!5EQ2n|}%cpCcZeo1bp&n%m3kw|}*`SM>H)`~c zhxEDN8cIP4ye$1M#4#l=OVaB5H6J5JNW5rD4*$DMDqZ_tx_ti~|2;8!Y6};U5#vPDQ&KjuZEM+J0lHG z&t5j~pW)JVp@=!GA>!h90OxXh%h#+-d%EuzkAWX<7L`095oTaTz3ccy9+~Az@|BWV zY>Z!M=n=Ers3=A-TcWt|xTK`9$<381>YpA(BvU6gI`4CyX%4w;-fO~%QXnHKSMPAL4vvkj_Ti5t* z8^@>#|J`L}Wj7=wb}yL4F%Ht%88sCdgKpDw9VN*Ep2MQqc1 zxfFNb*3x=^Spa#)u;yC39l{ITFqtjSrix*AclQV%RZ8-_Po6TIBccv{9#@Pq?XKu= zn@b;6FLy7Z$`Y(<`)R_NX3+AHev9#c#LV72#B5pKb-4t2inOU$etz$^?>iBo!Sra0 
zvGC1mGzoS`-1aBbP?Z-5pAo)!^M&I~B;AvCxkaUG?~%XMBoc89DmU7LBYx`no_kIJCLi!Xl>R}FQ z6|G0lyu9Q923q1S?{V&GmfMn%mX=n1(ms5#A-XcJy}jMwU>qX30g6_!>O5Q6emAb> zU2Y8b!)J=@tw$g4axJ={5i`vb(skG3$z@btEIs@GusSlmz+bq^q(*kZq86& z9hNH`o)t}yJUy|rqNk&?&9;6gjgxI+Q4iGvUe{drLBoqjhK6k>7wnmsnPrG`J~GnS z($Zq+?KU?z7twWD+PTldxJL&Ojk+r>B^BxExy#bRVn)Z{LyJLv60;V~xbu5{c6M2m zo4AAoX&aktkcYz4ypbFArC%bllUTLC$Gm;}nXD58#)8lf%HaLt&Jsl~&z?RVl~d=9|#t;c}Kw*{bMP&w#kHTHL@PH+&5->$JmB}V0is}eCm@8-*t{}abad%v`eNM zmYVzY{(L()hXRW2oxAAx48BjXYm~UtdbN=QuYj?T!u5mmYmdCbsu;eN)nn-P9Gl{d zyF&M3G15N%RR3h25S5UVlPgpb6im08XDjFhdDxm^7~I;{mRmLTz_8V=O|34x!cG($w=l%gfN0yDuZ5+Z+SuxFtNPbL6uf;w;zCsb!ha-?u(HG%VGt zF7dP)cb$~K@6?33Y?5-U&A2EvIZqj!3t^4-M~;vqi_f4$u0 z@^0I|UI2o|z-Db)+faO;@o>mSLM{RG@Q<)~M@kBu1v6A+z{b>L#+yxrIm(I(>+9-l z9roJgkBkISk4?Sd2bwq`78E5y=OFFj;Q;`C2jK~}kf1V^O%-H-$bkO>?qK(d2o=5_ZDN!sdX^+4j;47T6OoycREJmZ-}P?cYBb zW;Zq@yW!chXOi050}+awTt6BEZlS}e{OVAAdg2$-#@)5$E68}P)nBnc36Oe%Sw$7@ z-quqeQBIkCvvEttCGy88@iRm9pdH_X!Nm_(Wwo=%jvdq4sTs5Je9E=&BeT0kvS~-4 zALhzb5@bMdssd^MS7&$4ELqN}n?JKHpn98q7yjtZY4tlGTW6d(v;8f0E0DEhR=XCu zGc)T;?m`PSYJf!Og;0g$AkY~VENFopBm{z!;vqEvVvVh>am{K@Dn2JA#;oPaiMt%66)l8|cQ?eY7Q21UZ%L=JARL zpNGt{7efYnZRJ!{46RP*PnJ*%^1Y)iw*g=VnU>2Z!X2 zbDK)v%(K7OyF{+dtDCShulW)i+Y_3TT9tn|>qqtb9NjtNaX6qYr(^ZF_tB zXq1qNIw{>e>8l-PBMh2CVrE2bjqB{vVF^_{=UA0Xx*u6R1zMh>(yzx#mLy$GUef3MQjMj ze#N{8A8Z_4nU-9(pH+gu+cuH`(CnSWCorUM=2Lr-2-$NcR?VGbYz`fQt`+iGUi5jc zZ2#)|Xz`E_Taq1!wyT(&V3b|N|J=TF2jZF@2sMk~SGuQ` zh{{N2HTSz^pVl_)7lqv11^ ztkpAiA75N4NpL2XbbcIkI{m(9PGNv7LRP zw){CEN>yT`b_&_-{*$2Mh^FFeUUHIGGkBaph~zdJ2$)MJqp}DhBEz8NE%Y&6H2<4! zXsZcFy1#tH-#D>gzX*=DZd7rbm8!ehy$H=OQ+*XxW{U`roeItmK%-U7C#jIwMNxLp zLQs{Fmp=uD1RdI3Ky3B4eu6YvQb1FXoT^L8`a?tS8iOR?%sYrKg@3-DCm5}ml#~m- zRscNx*7NK)d(t;*$Mcx=K7|_h$jBN5{WN7ebiBK2ipP{)Q^PJbZ4nQ+mr!Pt2+s2q zKfe>BkLEwQXp1q{mSBdW;;C{oW4n_z&;A;nUdT~tM}yRFD<>l(^N%b;C^ArUiS{YE zU@;D=5S4;WL?YR6PQ6*QluF-##f4O@%I8y(%QGc{SXsWQWoAM_6gNgVhLCgf&P5NCD|hL2 z6*%+5q;`1zF@C>QYPe7I_qnEaW#o4}Pm5ID7E&LBV0 ztxjx(i+41I)Q{iC2lem&fhfj;5yCOy$l6N#pr}15D=SNc;j!s=!cc{d zsi~>S(K7uOKpQILv7zAw$38zfLbpx5*Z`Rkl1IRcDJ@M+!iltLHMIMsuWuAub<_>k zh+Q<}S^?-jfu^dCN*q>xUG|Aq>dfJ1Q*X|r|CPIU?=5Jg*#Z7Fh4F%=r0k|3ECWgz z`r=uFdeyIeqR7Y-@bu}n4!iMuFxYy=#?8~!{K@F0QzbtSS*%NrkjVO`7ku%x++$e` z28HL_U30_s-;k2>!7^-}QTr+AY_ct6@*|FbQl=4G?lARid{r=qKC(*jMg#-oyogD+K6tRyd`6joUU;6_dObPX$3wVJrxPaQkisL?!c;JHi zT3R~tN`!Xk;Due5T`&&CAqokdlW?k_NCnns96m4(%Lv4j;aU;8Be$K#p+yZ%meMu9 zIA0Kh_I9)*R}Q6p_!T?YJ>p(^A5}S&RaG@2M>Vv!C&1uygx|3}2OF1zrUKaGGWq!sgtNXUZvgZy@MFY2@h<&N#oFZ%MheH3MS*wvM~$-k4ifeq z#3`i7#ZhCY_?Va&qn*#4`dPzk5w=K#$oel1GT$*X`)D=R9t$J@(EvoZ1T?Gu`VH_a z!hn|y+O!#OJ=I}W!A!W|{6a$7BklAE{RJxfh?moHB-Jub#5XA0;hJZ;xQsh_c#*A} z`wV;l0SPNoy4}1H2mxK_48UA0)sYY)o*y`4F%I82N4ZU%ry(26voVCkGH>&(8HsTCg z6`Dx!^D8%iBDBJ07VH%Z9rNLKj6wnhCqa0^7RT3CjiGWAZ}(;-^jzrfvT|~-A^4p| z%MH*!NZs{EIJpa2vnA?QJa!tc50>QJk)L}eA5<>05H8Z?iQ>pP#?M-*r_Y?ZNl-O= z4j4Fhfn&+9HbJ4x+sCF(?#yLAc)2@SG_FSxTCw%$Kimv24ZNvm1N(gI2_n^+ExX>1F>n>~UY zGCgU6=DRRDgZ;;-6`?Idq#z_rz*fKD;4?#d;p4~^S5Ofe;B+7?DiFuPhaEq8QXd_Y zV!ECm-fz3vU9g1%feIsG{pTk;36>L52;u!1m_AKUFW_@RtPZpz;mzDo-K&w>;`{2= zn`DPkeJas_@_C1$o3<#1P+H!?c>|iyK7Cq~rdvPrr^jmqw?cGAaCi|^^uf-oEY5$7 zhW)=15!D_3s#*S$6#Rs?6vm_5FNmP|hxtOSmy(rjvg96YLtucI?J4z=Q&kNk9EmW@ z5qc7M2gL&WY+&01R8;F?Bj_mf94{;S4%f^c@Lc}2&N5v@7!9H!C0G;`z=g+_|5(Wg z9uSEC?)D6kIQI*WT6XjOxk7N1z)9dk_U_xqhBHX)DP#gE98^MG2h0&ekdK+F!p7HU z!6gHsAJy+`Ny#pqg9U^#Mx+7=d8i4Eot@`Ud~wr6hgtjf8qu!DROh`0w?&u(Mgmm1 zJF8v>2F8Le#I`_@Ak+K&@C}VrJXI_bKl+^}{^er+K^SMb?B)`PMYrHW2H$_=NMdjA 
zeQ+;KFk_%m-&8PnWS1Lk3h%J}r;vJ3stBJE41Tw6xf8C2Z1Vv0;1Zv@1C&DhbYtkE40L79)xbbY=~PN$o$Su3y2$QYPvwq=#QCxkSqao=ysnRJz(J|v>N-@DdUy4DC? zo;Wi6t3CX#JhUsCJQTTBALM)fP?ee-eW7EE@8#IsFWB?IHRltU_isIuGq#3X`ZFIt zejK4YJj&?>2WDx}H*M$UJJ7DP4b>fek~$wR=8B7{bGZhoxy!B<@$4C1+}_YNl3LhC z{@(rNs_Y37A)y(YW%p=m=Y{3g>iGqGtwR@whUIIEE?HYz28Q4LZS1@CP5IL+&V4iK znubnEyG8S5_blbA80V=PXMXT~ljKsQ|J5kdErj<&a!*|w%joz5TZTC&&0SrZ9XzB! z+YaZFcwYAKHB8Yfm-e>mi!Um*jGJ&?ZK4J=s49DDyOdb6f`S)d(;;7&kZF4@tHOU( z`8Db=yzKe?;h)x(bROF3m$~0a9UZgE$$Dt-hKPvv8+$%~?hQXh;p#}?4R52Iq4!1| z)tizd0|Ek!N>7DowyJ~FS8`>3u*vmB+o&wMq;Fe{cxk-|29`jOi1@l5FMhtyr@zQw zb=*KgmI(qRTRG`vhY~2BXtkALIq#%7ql+c;8#l4BM>B7{`U6TW;T6Q`AVf}qXt8sr zPF+ySR=7D7>9M?#gi;*6do|CrJJJRLSCSoP;|0RcjkRS;)OD3qY;?I6Eh}FR5}Z&C zk<$q9$pzB^xRG9XqJRK!zVy&V!hsI20aeX!wXD^bY9u7Zu*d5{J=HTbl)*fJs{Q0O zVDCVc+4E`z3Cv0v0s0Bd3Lo!Rp3B&!AJN8m7RxypYTa22KyxOzO3U*rt5zyPy+do@ z@UZuW6!xi|Bv@@Wr77hbeR`begn&%(9evhx*Rz<+I6R{q-R#Z=;!y@E`uALo=_p>?71b!Yw4_9@%=gHxS(ts^Hd8 zRg#C6KMdGaxL;EydyMIO=j79ljb7cJu{K)}+A#lN`@GdAu^tgkZqKd>;{g-S#bx*P z&s0gZrRZUD2Xsjp%atJ6nO*S5(d^a{cf@AzYA-+ z{BtY+pHDg`@11Rd2ZCtV&_a#CfEJISAZ#&&4uoS5I_M6V5roDJ4Dv+HfIv?u_676# z(v%wK9dWp2W)JnY<)rxFRMGx+Lt6}n3QfOs3pPW8`wiJvJX|DNmr^Yi~#>+$eutvJb^@@#*eBe{Rx z4u^i7S`y)A8$2H@cEck$p87pcoO_f^WGD?nXvDQ@JX+EHLv;FxH#=>v#O(JwmbO76 z+?cKkcxsB5N1inPk@RJ7rdAv=RAlLFu~KPKNqh~ux}$qZjdHx<_QZbGJC+9 zBf4dv=yaow9NM;$FJH)%4imXWS#4V1 zkcc)3n3d)0x37P^z&{^<*<35L83uRa8wlP5p3c&$=HTF9zp$`Q4LiJ+gmT-X%nLRk zgQVrytWo;-ekl#-+M5{E`(X0!i6g1b=eI~OGhU)EIp1Z zBWlDD8QkK8bs9Q!ESQ^NV#vu6e&JzgZsXJ3z+ci*Qmo4ScW=RtOEb!NlnjgnGz-4g zGJHOzISda8a9eB`6`WufiWU#OgI*HYPR-J8_=G4LLfhrZQb7Ec(R{d(Mh+GJdD3Y= zUG+-*9ErMnYKsBjTpqits!zk{tcmfV8!0Qv^P^D(USe69)H;Xxnc7U3g^hBZk3OJ& zi*W=cWOEElN_fqqr$@N6-=71#FU zV><|Go%p0Atgs$bVf2sn(43)8y*GMcLnQwz#@jkwi!e3;`W4-oX3e~IRnWWewdDa~ zA_vEy3ofIw@%c005!Pna)*202FhEU0GDwbVCWazV1$Qro6?fM-%y&c+BXMZYF(rE& zH4vT;V(Q0y(rEo1JMj(_e1e&Q|L1QJ(}gpU!Iz2n_hyEJL&Q`z$%jdkG0id{MpH{+p@C$3*+k((evYSm=K29&lk3TDAgdb?nZgz zZlMe}Gdi>F&jVo}}%R*H_x+2W)Abo~lq+0Ju zp!#{<|45V!`G%M+#*h#(!j4HxOnfxx9M`CUE=&x!I;<^qjZdmVdSxWMf7<*@foOQ& ztSu(5)q$Y|HIyW$YIALEsc zB^kl*>rp=l11=^KZi3>Hg>M*%7`@(c_(Fq%4uevI@OB%cp^acsiNQ-xDZ=hNE>Wh zf+_{}3gOGWC)a%|=(6Ax-<|D%0=r}9&UE80J|e+`;%)+W?8f@cr4Bv2D8a5{n)n}j zfv_WB*spoF~$=lU_~<;0%VN`B!3LOJPI}Ak39)8o`~It_F?3lE;6~eP_sreN8|d zbFV6Lu^20ubYRJNlsOLKyCa7TCbw_kERG+)R@XS@2IazxE0mbCg8wGz1k(dAGJ{jWO_+?eWK#u3z0w|0eU{(KpEy2ku@wwqb+c7Rf`m_bBW!G2EGObYuCk zr`N&Z#qD=~TDMsE?w>C1y9gB?>}eXnJl;ZExV7$y1%m0VT^a4 z#zWY=i*HR%hGqxNSl-@_c^VoT+MJA=U>~Jv4~VrtWD@T*5^IN_{sO$+Lf5XnQ7d$; z8JP{>cjB`e@ASZ#d+g?>G1XhPjhs9^XJzF}Q`14X?T+{Bm|I#t2@j_O0NX}VK5dH; zh4ZhThH+c(!g$5m=1`7A*_}>KYvtQjk1JayU9lAt5tuJ;Y@B_gpNlUSl6luETrOW= zf0+;$c&4qE|Cm+=vF&trb_Th-724E?dtaXq?-&SM2HOz8=YVI9aNd;+d-cj0CX~~n zO2O=gKgY&Ak%<&6LNTmKv${Mj0atmXV$TVfd5#^Mjmxm7Aj;pz?EL-G*#>~l*C z6{@~SzU{2~#lRHEB6-k}8V}T|)Q`*(R|A{1mt>^yr|Rn6;G%wHo1esp%p`KhMv_O5 z9$io`A~9&sqQJLayMLb%VfO(|Y*tRruk3-a_-+RJ^>0p($><%$MDtTmPuIvu26tfw zVkLpEPj@iF?(c<%T*%0A&PPEBrsQ|zfj0aUZ1m@^UlCkIz_1DC_!AIHWQs+u5=$whCuvv~?BnBe(P9;0Z?d{ z3SmAzQopD^F8Sr>^Qt;7XbK4n|3cROR#o-H#bpz`B{hjk91Te-N(vkoFYd*%lzwvA zYC77=#>vTfl!u1~i^Rkd%K?~1K_RId$80n5;{yX$bPL5NC}uwg2BIzAizF|U6Q`VJ+# zfxL|C_4U<@+9A_wXa@ID0IwZ8{O_bLpKcG^W}Yx2b=d$ttQ3y6#~8Vg)qJr2a-lfM zch;T1e*H>;|C5H_kp>1AozA?=SWhg;RyeVAE*+*MBkdg+=)uz^vg{WwIxdZJ(C|Fi z1EVni`DvWs$%=%gzmO)&8AJ3QUk% zRFmEE^BpYR{)Twe;YR0JdYb4s|6d(acVeg95M62fbD5XOOXyJ$FX|7X0VA2$Yk za{HuS1a9&I>dXsG@aC?_atF`K|JT>S^ZqI2ZgC9)>ylo>w|fc?9&nP9AAU?sk>-5< zn+Wds>D}x^lAxfZlpi)bbO;{C{nB!r|Af&U!!jj#c}hs+&(KTntw76>0?WtaCr^+a z{^K=*+_3)A&^gRdboSu+ 
z9p`^T@eOPqg+<_@Z2S{sB3Cl1M`$+x%ZsNcCqF@*OhJ_oLGI=-YJUii9S;vrO_r&K z#YhwR*w|R(fnDoM`YKL(#Qg^?E$st58R<>)2>r3+{uovG^y$+ZF2TPWeCG*n8BO>k zCSB<-U#5EUe?o@XN<{IFaL2{dmkyJMgofVL*LR4xdBXVb3y#Yu zO|k#2M)W66kmCE^{rgYScd$huJ0D;oxHd^8cR0@bZ}d7#8fF#(;7%YX_$?+-uU{{P zUsCH>p2rdHgMbb~s;aabH*Wk1*kGBD8S3F3aCPGk4iB@9xs_E3Wc(D&i9+%!Lh2v3 zg##ld>E-%sa8z5pGhrL}jOr)BJG)<5S(#IP`}&tiZjn}+vcjY`QNA{jl0D4L<^Ct! z-|j#qsA_AYPE^WtjgHo~lo4IO!0;Ch`S67aNgVL*|Ch^dD{$b(R+&uxx~2@axM4OR@1Gl&HY{sHd|bZ% z_abXm;_7;>?RV6g~cUcJh>#p63-(_e)NAo zEcTAy31i65E}zt|-~~jgPcr<^Z+chO79rq*A=a%VMUZOZKQgU{!rQHicpD@qLHKbu zkl=8}|4(1ZwcLXg?IG^NV;;gbXF|=m83f zS6nhSQ&vBVHlPst*{whykh^>7Z_o@&_m#C_u=UJLJN3WI?;Xcu^Wxbv0WUnOCFk&e zT-8Zq3>FpRn*Z~ched)J*Z=*FGm{v6CcNE0f8O?!`CEs0_anU$a!L)%TmN~vx8GIg zt>8fuN#z~?l%Gp~D*p!Fs~s0?@E8mHxwGNl&}Ay>{hMj8Z~1raVoCn*^5Vm?|Hcn{ z`2PK-qsMIb9H5>yiM&8XDkUlAcJ)BS2g`?U8b{NtlGu*y04Uf_BylDMamkC~Dk^Th z?ctIJHH{XW(i;>#c5duYjLhyL$>QGJWI?2}gT!6Sd<>V8Qy<6{-E$qEXWe}$;i16u z8n*lQ@HZi$`_9$tuZa|W5ttPlaK+@3uBn)t0HEqdEY#Z@{t6}^-U~+9eVNCx925D? zZnK-nb3`tOWLBedX<}|p7*BQT9vX^GNQF7*{{8#%9R-k-P`HwlOrFIAJ)KAjZk^<1 zJoa{p!Y4D=0@ICWEkT z21drm|Llt2PP|?H0_6l|hr2NlP04L_77~n0d%N;kwE`Nc_qPP#ldFAq^AJf9>?JiH zHE-Su zLb)D;ag-;Ftv?kPZvrwqsCn!Wp1E)}vNl4a)RW{P@IB7bo2)EK6f^;0VNx>EP0+;< z>8zVKZ*CC!vwpgwuHNcoa(?>r#|I6#!$@1&vfX7@%%US3qN5qFM=6wD+-}Krk4TbNK9B`N9v<@=jG?u zW*VzKc6B{==FCp)*H=i{x8r4*UugH=(6|@~5G38`W~RnMNyc_$qs&H^PY;2{gLb(L zCZ&w-rI3^{Gck#o=#ggsTxKw)Xllx0HQM?F?iCUguRnNVmIr)hn%T8FLRwlUlvPx^ zdwZjN0n{S##DWySCQyIl@)8098!+k+-%B|rSIXM|OQ`82+Y!=^NOsjVp*BSgJuY{> z`NEyYqeUQy3c6e|>qjZO4w!C&TUUy~^6~ z5VLFT9UaK=I+9~UL(g2vsgIpL9m^&7w?{dkky(~YNZYr9E^P;8@x^IvNiE*Ri;!u0bRi%fz-aa=JlqT9L1awKHo}tT-`3W~aOxEC|0N^Ig`NT~ zq<4G2y&}-Hqaz=>FH^`}W#Cjmk3Rh=>S1kH1i|;RP;ZupxFI%ni~> zmK_}(EcSm)$msa;o=2_++uN;F|B=f7H%-0&i$y_4kBbDioWmbAvMMBg_u{P!4Hl=Q z1;RW(O;~PE9=h-}<~hrH5yHQJ@tij9ZlCSA5U@j!izY5&TUWtT-5I(QCPKS8*Xz?` zl|v_bEEfB5q{h4#@~Ke!(@F?%N*&eHww z2?YKZqw)V~yb}>SL_$jnT%P5)!iC34sO#(RH%-a-AdC;y`>(EsCk#%R>Z6l*|e2aq&Jaa!Z0`T0uJDsSJctX<#HOGbHtAuu_ z{`haPJeW5v`vlJieD3Z}_T}qW1>a1)#vQ-_C16eTPF-BTfIHMDjAaHfhDP!b8FOz1 zu3{(C4e}CXtQa zeBm`rHpQKt1(hpkbC0pHrB7|j^Llr(PvG^KeEr=T-W%56zG~)HupJRj*JyLNpj@zQ zlCF}`wzasp$!J3uE6*v~5_3`~CnruW*@V&7ZA+svfq1giAv{{;nw}oh^78Uo{}mA0 zMbO$L<>j|=bLU>1pGN7p2L?xmib>Khr8eLwY3I+A$Bw=I$m>2phKFG~K`ni;!?2j! zSSAMa7s*3Blx%OA(!G1f(1?u%WNqL`odW+6d*dE~$6k1Nh&>y4qad5ANpiOQj9q@6 zeA}vk+U~4reW{ianM);Zn!E*$GKFD8h;GJU>Om>Z7eqNq<9w2 zXGn;qXve?_z6D`V@O${9$93uNrf&bwB%u8-|M7o?^LJTZk9+0o`xsBoX-HP<8lDOX zp^19;j{d@hJwT=9#M3@C$1>sZDHrm6S5{pua^uDhqPgwueGC<#sG?#gA0MB!^#vIr zp|kf`18-%)Dd+b!$k%r>3D8O0j6?Hg;1#f(t z!OCC{Jx6rY6*q6)iYM`G0zBNdYnQ8RJoASz)#Oq&rxndf+rZ3hzr;l5G(~2R06^{A z5mx1wBbF)wa4_&Llhk!EE&v)2o;B0sITFbK+93CUPe8ynk{FC*pg?KvJ>~~A;6peB z8~VbXRyI*Gs*Z2v;4lq#`4wx}g3L2U`Gy43one8$F)*BD9o}n!3+JSn4eg1B2j< zZ9d6M=S^GvbE@HG-f$GVaqL(*I~vvUf{mLtfdIQ2WpD6yJX7BoriSB_ z%K`eXMc2y=o2uYRW(y&HGO{g3$r{jzehm%Tly>>Y-M>JomoOH03Byz`*oVi)&XIg6 zE-rq%jPa0`h0V;~o;~Y%A_v4E&;VsY!lC?-RH2Op&>{ zc`@;HvEeE1s`UAPsNDw`vZ?{uh?FF2^;AYUzVR0xad-vY7L!94V(Jf`Xg4t!>2+_o zAr{YUq5VBNdL}`r~vV_Au=o!wvvi&!KnBCsO}$pA~}*`vpX1+FKM(}yoOtIfvq40Q

i3E)x@`MK$hsqB-8|1i!!Hz;|j~wK2-viBN73#K4#vSK}?B zB4}bZ(3HfvFZVxiU<0#pnl6P3GBPp~dd*xSH)(?s0?Fv+R#d|F_?O6PdrrbsHYP}G z&Fk_rl~=AMZpR{;?9h>b=5Tz;#I>$n<;J&5Oq|f5zl0_YPC|rMdH)&RPNcqQ`wVmz zAMJdr^~;to7vsj~7%^=Wvr+c$kO#Os(x`^La4Ien#QzIKL>_i@4`~3d;Q+R((Y-4A~|$;N3eJz$irvP_Puc0 zZ4j{&aGa(Q+|Zv6mUgKao#g%l2Z+A?iz_Sr8Njd%Vq0eNWK@oI62F?0@f}r!o z!N-aYBrelPvpxi}n30>CTe3Oy=+RwR9+usFA~1NVc_v6hZ1A+Z>6+~EFK(-)< zi4!N%DwFqBml`+tdenanw@&lOZe1EAB6OnSw`!-Hif ztm$(nE9(XA*Z77br%w4(50uw?fClJBy}1L-NY5nkK{zSCkNAvysJgaxS5neezj&Z& z$<6m;VJ}S>rAe@284!R*x~^KeQfcbc0Z8FGDl2o?sIHER z4X_eW=RWyhA%i8JA%7YhGwve=PHWSqO*tmbe%Eeur}$y0jvc#qzjzL>UiF4G?tA6> zG{H+@S~Kd9PGT&Bh6Y^HlO?&3h$e$%*cB7ZJzZ?l=c5Uv{R=ee5gBUt^+CZTyEtUdyTE99zS_*>JOeqCk9 z=?V#Rjob*Al(aXbWBrI14}nR}1%(=`_#HghZ|c;k??JlmZI_dkRmHnXt>#uI85!~8 zja#+^u^&TqcHG{d|83&rHLq2fN$_KM%wm_Xukhs@zy3pl>561UCM1u)Z~xOe`Al%) z6SeeaWg@PFw~W5C_dcPn0H;t5l%F78a7F&D>q0qAF9Gn#|_=3z2ZY&oZFL%*~ZbJpvB=+aDzO%>~WAkTo&*n{a9W?FTn$Z3G=U+TlmJ>1;*9*!*o3FVsYJhpi z84;MnB$%lSiS_*UhN}4Wc@y1nJUq{nfl?SOUL;c^rj>nrihF?Nw@}D(JnN-$Xj<)l zKsO@_27dU6uvD-1tD_#MDi zUtV`yuyO0w5acqJcg-woSB(#C=R<~F4u3otwc&BNM7gPMS0#s0moB}eeVr%a$3P$% zs-S+YSNg~WwG(YKd`$Do5Y8<9_-Q@6Qc4##V$s(r#I=`b>BXYGYI(y5hENGixi=Bh z9olQwr_YDdCci()=p-ZHmTYpPAPmEESR?0w6HQ_0#M;E9Q6C$Ov0~{ktUltqE1ay9 zE=gXEI_}7F1UHCBhF~Y#o7`La;US3)_ zgb1@_{)L=8=egTkJ9L!|Yd#wiG8|U6m%RK)F_D-3j8BxTq;ValttB3UkB@|nb)3F3 zhX6Y3)2iBj_?n5?-*tF1f{_jrBSn(KF6n<0U!SPbDw~J@?RTxH_4&;omV9bwu792v z&AzbuTxUU-8wrD+1OB@bm@l4D6=Lflh~Hu1x7@y;6m7JWE|3n_fq>%3{WJ8Ma>B8r$q6)2@zhRzCE&&zsyliFYq_V` zq6hQkwpA}9 z{a;5fTT0*_qSIXMiyKNdo#NfXLhH#FRih8)9ZOoE^LH_Yssr^*EiC6F=var>mlmFT zv&Mk|Og4{ZQsN7-C2eiPX1({_wy@DUE^p{%BVcLa;^b;!5z+4W+q5UsvTauXsw(g& zX^gZw?~H)yAk_9qtMzoRy9oE5HTU|loz?(4>W^4l z@0))yi(Tf!A2XR{VM|Mkl9v`tqN}GTg>qzfuJ!hALRj|e4LMwM^~-jw=PXL*xzi6W zJCv)Cg|TeVm%(Dcd6JXyp})h8&%t;1#%dqofgtSHu@ft{%r0v<)UBLhw+H zl&y^fjH4r)i}%@|uI?V#OYp?k#NbrwD5GgZ3-{%nQbIw79>3#Ry{tk6Z@&b}4FM$ZV@Vf7^S9riv5+Ge4jm2--ow_G4K$FLczZIN-bNx6H0T;2k%Eq`hv?bz z{lgOX#L?Q?gA5{iQj|2}1Tm(>>3P_UP@98q>Gd;9X8fXamfW1YA+zAMqj$F*%!ltA z?>-fF;qv$RX<3#HpJX`Zb$$>+m)U zu&(XSMCj|NrO!&H$1F=4`|aDe4Ma9lxcP`f2M;E`Yxy;C@gQt+>8558O2=e_8+E44 zX(sW>PWMPK^}gt~b{_ms9H&$W)lWBh%?x(>A*4N`uLOM@i!c*XX+IT87Js;K8m zZLkE2YaBFg!KsBua%YSolRMVihcu-`pUZyr)G$zf50Hg2OdZca`P|USpMN|*@GgCa z4$%VjjpK`_r!-efu^b4_Lx_bVlHNX@r_vkw%G`P9o%<^)9zq!~%I}l~7Y4i6l?>O8 zN7=CW3o@9X^G}++QyDtV{)1!J;;%o7vvj|&*MhS`0Fbr6=MW>~39kq~o-IvZP`K?& zPuG!zi3D`CUNAZ3_V$D-s?}=d&MZ>hXpn>ouy9AK7rmY8E3JZgrsT3`yUOsd|=`sqar5Dg ztLvs&OjSoXGSR!a#!#ZQqmKyxkI(8T;(R!RM&4NE*wK|I7bKq2ayf`7uLODS%ysKV zgHQOPJ&M@T+JbYm29#d#%7%+DyaZ_^o)A6 zl`R-Xi;!LQdT;iSF=M)2yLN4uK^1e^XskL?eg3%hl&SxW%*-~dBB2hUM)4s<_5qok z$;S6?D(j!HHdltfmDO}IxwS$o#p|F}-JJQeY<5f#_8xC04`iRrO?$Wm|EN9=VW(aS z4jt>P48Z9iSQhNnLn1-p9wMm)~}W8+9lEgH^R`JTC>fwQ%Io zf;mHldK6_OM7pkf`k;h`!$%o?U!U82Diwv?IvT*cwNxT28O~3adt223d9~Ti0;$&A z`7W+&^Ck>;*}CXrOpMy?rq2V)T8Y@U4BPi}SaZto`V?4&B`?<87Vx|pI$TC zH{teK0D|BTW9O#4Gm-C^iNlHAiWQ4KZ+L_N?;l{qlC^{R7m_{D1x8o`sSpzMT&1*% zHgDg)9aqnLYqlr5Cfv`qcX5dci1k=|TKQbS8TS~$dS6*G3P@6#I;<<*(LvM(nFR#{ z=~c4eQx~SVtsknQat>giVR@17H#}nEK1`_oHMLC{J76FBsZOM?5v~c#A;h^;ywH0x z|2YWFdl#xoA|jpQ%{yF4-g2t+YNY)ZR@N4;Hs@Pky!EFf`f)+isEE@=ClyW>9ij@+ z-ViabaLkP)t2&q3!-Hp5m-wq$T3w94Xse!U5xhPZnU{Ug*PW3^{Iza{8 z@70W;`UvtBpqiSlsxn&M7$zM`Hg@P5gI&jYD5$ds=@&)kGrM85K2Zf$Wtz3m5hj zf?`6j9oo9cYa#FlG<+qD1l*?t3ky1nuh+Fo5`i|DTMwKWHkMeKn4ndVih8D16a|7n zIMbz(GX87lJwGM>!{jAtC)cFRm@$Jp0NgS}Wi{7$@JepJ;!7LoenD=~DvwU~9US45ORP@GxOhEDJfj!~U6HXRNEA#?D) ziuzgVW@ti1y$36@y_n`6A0spm)X>3m3;-#gUtTnP+xG4M>JpF7f04~4-g%ZQGhr02gBWIRJ1DNG^-1wi>MadlP;yKd-TfDs>ZrOExZ6EQgOB 
z^X1X;LT%X)Jc^ivsdnt!y9DOg)vr`Uzjy1Qz?Vh0RbgcD_R%4O2oc4%OtA`mcI)u4 z)AKJT)Bs;9QK9l`?xO0%ie*uN!sPhZ>z{0%wOUgiNoE+!566v0;zV#mRs4z+Fxi=s;NkLZ3z5 zcy+##w;Ee-5{r2(@1W!zvVHg?sR<>{7_C!nop_e61 z#}R=mvqaJAKVZdV=Nm>C1Pa1nv{iB+0#jvSWQv4DIX>}W=Xv=Mxt(ZKGPmc|S^ zKGw?ebHQ(#lp$0e%CqD`Yw9t_^R z<~-R)?L_;!MOsV?ZU)-}J1;gWmwP~$dznw`pW5z;qL&uK!Q zLv^W^cY4cQ78XH;HDOC0!8D)JyjVEfe%$4KqP;#QZzI=3sJ}^(T_uhxSj%k^8)6vm!FzK+C zKB}jJ!a8;N+|%H!xOy0ej&BbvY_xPdISTxXZx&KCEK>B~EKzY)^YNps;=w#Pce5X+ z-tA9luI}~qZM1*SAqGl_(fMt3XC3LSeQ7IY36Y|ARYxFy$*z(P{n&cJ^AAYfy>m+85Jq;$vcN6!?TCzi3Yku2vGSq7z z56l)izF*9C_q^x{f{vVZj5sl(j7$O$KZe*b#(MMGv!{T31#eHz?73HF-~q-f@)sF5 z^cJN4ieEo}@};q}LeiJ5);!rOCdB=eX!jj>o>0ekuv8yM!!%V!KkUsKG5unER(l}S zmo_xF=;uFj-qT$w+iGTJ?0IM9$omGlf9ly)UQ?QSboRV?_rwVKjmJerilW^YiZ^7n z9YFEHzqX;W9i5!%zHOV*q)9jL)}8&lsV+nF?N@k=_{PT_+qbvl^%Eg4jQ_L)+`am2 zYja1eB~15Jo^LIMEU*9b(74)R=h z{K()UNle3J9@ zsR8O21O`F=8f?dYmB7*&PoK&l+anM3b2VS#>`01mkomdjNa-uMLpQg*{{C|Vr-dwE zym&VR{%flCmtQ_4FC*;>1&9n)U|AQ;^6E;o0n;072`3C6$Fh}v+h%CXC@Wm&C^(wVjVN=o<3w7pba z!m1p2}Bh4HAIBSFsB zLJ3DS@6fsLHJnGhNpL`{jq`+W(~I*53Q~*8^${|PPb}19>=cK;{cSuKO3V}fE<-Za zN#YLL7#yFMo<~;_Nziy1x57{kpM&P#l`*E~=H_bxcC<8IeD~o4yl)_g;L^rMBu*oN z=MIy*_NS+(LzgmExSIg~h+=xEj!wGXpR1#pj_M^RXEDYVmItEfjcJJd#w!RH^%$1C zH_c1u@6X=ae}66#*n!{*O#%9y;|fT-Mrpm4y&YP*_!iyzUYem&j?B{oDFNSeVmXPE zL*KT-S{C~@(PTDuuh-M|CM1j`a<|xrc4PJeJfQxzR zZN6LI*tjdcf;PTK3Z{df0QuX42Uq|6d6!5Rh;1WiUn%^w`K?XEL>7a#T$)|lO-QSB zbtNQaK*SmlsTEYYjVvtDr-?C4vx`TX=)sChN~Gu`MezZ+JNs6Z^nk7%uG*vzBJXvE zdli@~IN%*b3N-5CnTUx(=vG0uM`~YDQof=n2fQQvc8GF=+n4o7s)oItAltPT zy3rp_rR#u+kH1>sR{Plx5|UVM06H${{Smn}LQMlhM&0v}zc9+TLW@*OFStJ?CFP!W zm(HDs;w;7XFB2=6gN5^NAW1li76u0OD54i?wqRIpf3KosFUHwX)68bvmkw|)ZV~Z? z(OYA|;exy_L`DYRoEh<|Vkjm>@0a4URK6wI^w4yS|G*~nX}E%h4T@5p_AJ-BdZ8k!ixubPBWtugfBTk=8 zNZ)>|Zq@gc?>zCB^za)uZ$1X@dEDIE#6$_;JD4gthgYh0=FAz0E~{&$;s`iG&TNEc zpYSFfcHD*oX>fKLO96N+Awe^ep%W_+9j2U~4^T9e!DEXH+z>w_eS&P1*19YT(FvRd!G#(H!e26 zQfHl*!w{1juS=Ox1(Hys99O%MXY1ycC=PtyM5)>Ha`UGr-+qkB9f2{{!1_7*GZyb? zO?mhJ{UbP^G1>JVX$kA|%qra8F!V!2ygqxEk8DBDW% zkUCXRge8@gqjXxEukYVa!K&74z<^Oy(?`IDR)0viOjBkG0vPmN3$IkkO7=o69_a`I-A< zARqHhQ|@>*m4!&@dbj@F8aWz=Z8F_G^rXr)Q>m#XhdH%Wk3%s>N>HuTzMA$F5qZeK_}LEoB@Qo(9@~zG%(_Q?}$7JvLTfbDy z!}7%2ON3lW%zk-#HV&1En#NIugQ}R{jineBXt{tfVLi}` z?1qLJjHXEv3n9)lgUliA->SD8OS*_yDtHcL3UgDfSCLYEf$qdn_Cc8c^x+K);Dz2b zxR&^?%diXxa=$;KpJ1`Imk1?}7)7PT+VsFH9Xa1{u6?dTQS=%2&vMi19&3Q-Ea_*+ z3wx35x|aOh#UNtg2(3_K^W-ff#C*-g1&Y9`z3DW}6W4{_^LV>oS`3MkUe3X}MqXpA zau4nk1l|!2jN8wiGbeMYjn^Kk)e;)N_7V~QX?XPchB&mHkvUtIFIS*maESZM7L9{B ze*E9$-lCtQDyimI1Vq`doA%2HYFUomkANd^w=&mNv=a@UYUltNg@G}N!>f7 z4dRi`#jW;I&Sy@HrD+w7$WLWPBQpfp$~HJh;qRawv7#`7_0RtG>Ei$=mF zt^RntO`XR4RcPmgunyyZogkD$V1MKTI|v@sVDz`bw5ehc>+8c)#NU~W;Zk#{qbrg8 zBU&gTBoc9k?X)6SrVSa%9Xx50Z2I5yUscntJR0@4ZKIZ5C+T*ob4UA!NlUGX8#{Ve zxNmsI!-bZW|0Vl&u~c|nF?wv@1rJt?`cN|CQgn){!}2dTpD$b(rF2PYV5QsZ%a-?) zrL%j@-u6hZsyXvegQHTLEy??*o!#76>Q(foFh8|0!F|&Yud4b=&w=1iAAQ~Y^mnl> zD$@*>e)pC`Y8Mn7EM8=~>mo_QuX`K<`Z_l^3qn9*vT)%}7KSa1U2J^3l*IV-l)j?T zV2kEfbRW*4*xP=V_oORMlgaJceyr3^pJ&@vuUT^d&_uK|9y@KUtOjC&yPva+ZeLAG z(h@_>fSmS?)gL{r7Mz)u39_NSrQAf40bh6+7@GWyGwXktWLW+%> zQGNxbHMW{m8%PiSv`Q%kFD#@r)SX=BECEj#7&*4CD8h~4M| zQVR;i+B7l~aJIR-*F4~1p{{;A>ezq(b={-46VsR$!-I*x^Ag89m6n%3?C886lzRgn8~N!w z&Mr_WDK5_F=&ZZ8_{-O?*@s?QE?Xv;EkTVyEcfiwr(KsWU4ViIG}lDVb8r}T=gytR zQZ90X!o{sse;bM(5R&Z@4i@g-rKKl5SZy?7KnCqP#N24zbGCUrSeGPW)yi*G2@ew& zMkcC-?oF_&93ATv6CdA8@^ifQAovbU&3PdA!15aFXm}=VD3%BgiyBfcdXe%GiPMuy zP3wN$O-`PIB(Cnk$l%T#t9#S6^0@W! 
zeO>>GQr)#W9M1SrXFDN=s^akMpUz}Zox9zJ^X&LP_y!~6Mz1ld*p#e*9W zuD{OloHCK<5Q5~KWeUA&oQKzs!s*CPN{U)rv~aj?nqGRwt7!d)hK&-EQb$Rex3jAM z)ytPh*qjXwl@nPAt&^BNhcsmk+$T+f=;h?TKclalNuq5kMGCU6==MMW>Z+eNwvXyu zC_Q+0YHUV^uF;p~0kEmQx=KP)O_k7X{*09YAHn2f6B7;2YJ1tPU*9djd zUDZB%qx)@58H3DbkFHX2?imc(ZAtSK1G<)$@(AXCW-qg}wC6zS6L{>u=(!xNp>#Vb zfI3DO4(qr$IzEUglk(v>0&4iJH`D)KxG$IrUXVJ6Xj0x6V_UJ0o1=9H0=+NOkaRZw zzO1iV)PxV#D~#vOi^C&hv}J-c&t4z$f(I+C{m}8`miL);d2wg+aOh*QD=J8%sAiMPj+g!VF;eK(sbDnfi zT-co+2TR3peVR`S96^#x5&y_*-}=t7D=tMv38cG6@7^n7=ZBj}1An6Fk`sqN57{7& zCaN5O?CQsc9e=b7+0D3+7n6hjC*82)u|m}_bm@At=($I{qn7lJLDX-$bLMms+F6or z38e{#7~LW0Z>%n!;d?5;UtbwO_@vHl_c|4XqxDaZv>JAr%d0#*h>0Pon6zK=j0_({ zXd&NqB2-Kdj4=q!lIT}gfB*U}6w$j%Z_2o`N=m5v=}y1>^7i6nU(qlUFMEK1ht=n{ z3S69X6OnkB>`NEuFIekMuRn<|JRHW$DGVGKM=8b8=e0kc<>w3Q6HIOlV7fsKQr| zS3G~v@I9|1CtUKOt|zyYVvyDLjJDSSXz?`HyDcp(o;ujBjFW9nXrN6HR3`dtLx!3t z-u9C6UqtpH+U^ud2pZ3!mdQ1*bK}(fazQ+>{&SQyx^L@ ztYPK=US|LM8x^2V16tCI7sXgFL8e22D98I0{!G(6eiAmjnXq7;oOCTLFE7>$HwjBj>?f2Hjn#uuH@^Bo`#Ff%+AHd z^<=ekDV#2_t-81k=M{7sFyPc!jc!*Kk34beR77NCdxD|DgOA|iVshB$?Tb^6y~9O| zu<4cpA)UYC5%g~-mAUy)XHQB}z!o39a37<{UHk{iil+QI=Mhw-p)+BE>74bhUJ_vS zneea}K0jZ*=}6*>V~IyDUyE0~9Uu}g=(&g?G!N`gUp`7NT$lVW_GWr3-v_e>X|W&I(oIg5z74W$`J z=6Etfe)!SKMmVmisc!)cYG^6iIGwhvFMlyaVX9kKMsDr_Rr=8NOPv1J;zs4m(_H}$ z)*)~D%7;AyWYwjFWe${CeXqc2qoDZkhZeEtZ2K3`iL3<#UbqMNn5f=Gn!LA7+YVcP zU&V&RUlgCv!6MMG;stLeuq}8!HtWYEO^K!92e{1u0KUabmkJC`mievB$3CakwX`09 z3snECSR}qaKwaQ)ER)kJ*PYgGW~xzNxhJq`Qyusph?)=*58$n<@)2o@=9^n8dN~)}4usWU{qx*_1w~&h8c%C$ z>)3<@;K$*tl#kHKgu{<$8k;qr;*4g83K8zg2MY3_!sK0$PFxhZgKRKk~k4EPyb-|El~oCH?EQO{_YP-SqzIe)@NyM7vqrTwm4g z*zBQJElj4iW{>$%0IhnTs_PDm_5_oJ114`8x(p>*G65v`g+uM3=Vv~%%jb&qkx)7w)!^?*X-#`l6^e>qwrDb8`IATpeKIvtOS+cLB;C5u-Es$L8pQ-=QaZ$KJP~ z0t<;j^8Uc!(4Z%ZfZSF;EV>Vv7y$d>Q;@bjDW_q6vSEhocMn4(dl^rTlmpc;XgvRV~0 z^!#va87q&xH{7maSw~Mv-87f{$S3D)>esIy74PSyxbgrKRn421H7agVA$;28e$ip{ zm@#$`Bnek_m3;jCI`PNBz5J%UezsA8sa4q(o>#oNCG9_D&!|3km6Bp6$wsF{pH1h) z)@Q-$rje%MadB}un^aU(sNR3P-Q3<%OK(SEw{$-{&D#IFib`}kci0<^D?mT!vLIW^ z*BkwZwurt@q(w0$DFCAgn>BD$8^QK$Qon!7U4*g-1a!Z?a>M|u)`|r??m)5y9PjI= zk9b%IJmbjW!#>(xza>uH@|V~ZRIpHfvbEJ_{*7T?otqL&bg-;L9x z9&i3E6rW`p@z znn-5m!b%q7Ih_Tf1*aLnknFVF@})}=!cciN>CTHZ&15=^=H*RgotiyCSDhKVh#bV$tq+ zXZ;{h=il}kAglu@?TSBtejDz+M~`G1(O;S!1Y&X`G<3%3esy1n9ep~tM?Y$O0G z&YvCriZT+1p@1Ti!Xgvp;v>|r8jGqr?_>~UHh^-oeZ{}r{{OPFUtm?!fAAM z*((f^D0@9R%@kgr^vFM@}L=^#(ObjETxfZzJ<;Wd?qX*VX<1zXP}YAOpg;B)26WS-Fz?ZtZBj{leG_Lfd1}+lNsRV;+vv zSL#t--|7Ad>tfjGFnSAImgq21%SfE87TiJ-CPq&=*0_gsJAf-8@DQ#531a}X$EmB& zpim|6ND@@cIKp!ZJ*_wwTgZB-0|396Sg0`BN9xrX_KJ~xx~I1(gXHB8Fosb3fUcks z?zQQ>woV9rijWs|h@ml(v$8QdqxipN=_)kHTMC|0J59AxGxWG7o)7hQhZr;)-Rbq3 zH*%^^;U#ld*9p+ahe(`*wYAT5pIud$CXiMJb$c@$E6{E@t|KQ;b|#vEvKHE&^yJu$ z#^1k%HWTeA)v6x3=J;p}G2|+o6p^XW4$Pz^L+)_)&)>u!4@is5wbGh2TKFbWa03g< zrB3905yp)^I&P0HJovx3FojrwnLtYB+V9t8jxK!w`xl2+!#7tPtQe!%)^bs;&eTmR zJ){L#$jh|{p>O7##y*;KuRLzg+9MPtKI(fJhT}+dIazi4-I=*Tk5&ryR(6Hzm@#`O ztxZgNfLv{$5D?NSYOHY@8kL%-BGi$xcsy9YNINi4`}QSq$k~tg({oJX`_-1HZX3ep zm>ttwLBaUF%Jy)e6?lWE#Wx(2{uA~ULjx3b&gkF`yjxU3noT6WEaKbykDk~+`n>%< zYD$+vA(G>T;r}^ANl9WKCkx*54x>Ku1@pD)aPuJ^rH#UIiGC~HRl}qGKSV}+Z87DU zI88CQzNV()ty73uRQm@TZ38NzLl=z{&91)nC#$OnGbsz|YnPJH!jkzvdu3P-s?!SUyB2!I|63eBSu z%iqir?K7Q^9Dpp?8^5Ko&1z^bg6Ik|=IOV+`}H%(%o}z0_3J*olh2<&>#Vg1bN%$4 z0#}HV$er}Xkb*VL&CBihlUJ0SYNK*YsNH0>J#Q9v@65xdgEHg@Ly=Al$QK&QT%`UaiiMfx7U|h-)~eFM&Wa2&DX!C IXS?hF0TGL}EdT%j literal 15177 zcmdVB2UL?=*YE4rt?aEJd#gwjuz*MtDG})^A|kz)P(%nV6d?qpTTw(nT7VEhDWQY} zArLy*%BDjQ0s#^NBApl_B}78I5AXYpao_j5_uO&6Gsbt$IE;~zXC+zBTx+hi=9=sG 
z{E0U=HQ+lga{SPtLwtt!?pPi=bQpT*&@cMG9RohWA1>Dg-hPGMHnjc?_>20@<26tg z4$-#{u?qAI33CgE9P$bX^oJ)Sl{Mul?bbIgK>0f0s`S0BNv)xGlW5I7{Z}uFyQ*ij1Mfn+} zUwF@7`2BaOUx)C9q4@D!CM0wowV#OBq{cg_dZ%D`=POaz^}|n}ME}720`%@r zUTlrO;qh}8Ig3>fgw`gzGwz7dpZ%wwYY@HqO{S!ueS*z!KJO6U{`=SSs$b$s+K+gE4hs$oIQ)tmonrC;3u31YR{%}wp1Jh-RS#d%X>Ik4j^k)XBM zNBRfq!!^XaSdb3;Pk%*!8ru zqORtlGp4m#QhQL(P+so)PL2mVP3Hs1_*=xiA*2|S>KeE-+#nM#-t$E=e9sVXCh0>m z)Fh^!;U5$xgZ=DLY?+$ zj0&3n*x3|nV9seBD&4-?&04R+v$Wn*Lcuy_qt$bEXRc@}kaT(7LMB`*yFrei<}jN= z$)$3s!F4RODx_Txq>W(S_+!Z3Rc3hhbHD)qH)}mZWGw6@Wzw4JwN*NN?XEWUHX&x+ z^0~{@&}AB8-t-4Uv&2L0Qv2)b)KBrX(YePOU2Q0@;GCYb^xgKTZq=FqFY~bGvG?4w zMrAsBt`5jDDHW?%+LQL%Ht@pYfyq`^cUaSL<*F~3DVKL@KQcPIT))**%7`=6C+wc= zEjq+{UTvNG{;s=ASZ3G31jy_9Y8yRPK*T~vb~7`sD7Fpvm?&^N+~M49P_c3&GmL;0 z&*bSyHxBH>!3RU~&zi7gsi&bLHsXts6xLQ^KT_$ke%c$d2W?1A0yC&?09jUBR~bH6 z!ycKl5y_dtSgP#7?CDqP#b5lQrU#!U8=N8 zj(3{09l=TLX}|Ao_tC}))`@65PjcmRpM@f`OB+2SLrFKuY&D0OBEUz!ceQr(st!oX~d;IhrA-x80PfcNwT9}6Du@u9*% z5QCIpgyLJ8Z%D*8h6E*fa5sYV%@c1eeJ>1_GF&PfN-Pm9Lw9GZ+UeMT0PJA5Y3#)j zclH-3Ji;G_|BMKs{+WGHk+R}FqL`1hFibKSNpv$&SYnRQ1BNTnp6*HC?Jjl~EdI zMqb63>Z?ULw&%=k^z7A_&;c44^V27I)CUW+P795QAjP9=6iemA%PF`cOpX09zv!i( zW5gG$3u=szy@Cmy(@HL&jR@ZEYa!bSd=a}dXSZ77R_fo|OQ}!Z8u?K)F06-pWtDbT zJCto7-b`|8jq-VF>+g#&&R=nMz|_^%l>^2_%@ob`FvUEsWoZx^(rI=yIa8o#L208* zf;pV(utc>JP4yAsc2{2%@QUP)x3)fH+PN7;Kwvqo(T*t|A`z>*XBG0YKV;9Kjx4^o zogedK>*7%b>B_>`dc6d6v5cRgu1?84-RMR~8s=GQQDy99*G>~7a*G_aHpni=Z@j#m zpYi>u1JJqeS=$-LRy>gp#3S~qIPr5B#;a_iW-|17|J1p@=f;(&?5RR6U^e}pJsQy6 zV27U%N#9=njI>V|X^=PclE|e6wY)<kEzMl)E(^y6t3+{oN2Iu;ktPv~x@^ZA3f!duJ2lYK zyVNro^X%xrPur2tkUG~*luejeylT~#Dfeb_M<~s_?Mrj_5|=5}Kt1hk_hz8bpi;(a ztQtPgYzpakcCP!{h)6X%VZF191qz|(%We#WfmN>$V2&iZ?aqHwlI>etESO>qcC3Yp zDJj_~^E&T)XGDp6T}B1pL0@%EE!o$^*b^V)9KHH_Ee&yI(qag#^kjLal%82si!mtC z&3R)xSOJ+=mal#Ki`U74p#?z?Muv$=BfH_r%HB}cWDbqkHM>mR8^nGnf49x*4xd#kGakAhC@|k<&8wEHz5o%&&Ecar*o@h zpgHKAskR7M;b@zsVO;|%cRHwp9+aa>EF-$wT_#@(p}}}U_626rGklk6$+o_8udo_h zAB}5kXpL*+qq1^8mZtZImvNa;__}&~Rh{A+`|Xj<^7n75Th_L7+dfZf7(rgjXcK1Y zg`Ky3$5O9qUT}r=TC%yF{f@!j=m2dkE7~Rx8*3n&%Md{at2A@F8`^!nYd2m_F)6Ef z)ZE}E=J%sIARhnRU0O@ZcEfbdZ7~%P_sHKj*E3HN%y?B|1tmt0rS1QDnrHRYgF0X; z1!Yb3G3p1PYIEnEls7CSk{#AA82F?1m*S z+c}g6@M3Cc_}uElB`UevI{FBy5*VV!8)=o*t*>LF7WdF|vt$I`I6vE3sSPF!Qv+n= zM&3FRG>_5DNFkV+%)2sW&2l;cp8ov(;*Aak3hA`Kn*CKd91jllwkX;bL zPDW?p);1ZhvJnuD>6nN3-mlVKlMlGvYBrKM-DMxp^!|=Y#|8>zU}8w#7&^1N=lY9w zX#Zowq2**`Ej-qo5t?p*vvZIbLi#H-vYS)yAz+tz_jgYXEsj3=l&-p^6?a@&`~Ir1 zulMhA(k%(pE2KE$Of`2;d4%8Y`=}vI8Q!U1-1z;)M3$&X zbz|t;r;hv!utkm37P+=!(&&9`jU-_1`i=Pd86jlR?r8j1)}gKVA)iCf2~6bSnQ0kEUH4 zo_U~rZWZTN(+*h6bD4S%OInGvm34sLupak8TSlXMpeUYf9-I(nTTuejIN!>Xig;hl z0-}`~%C?lIC_nN|cu-ec0E@_d=v zzED=vnr^eBTrVPKz1?zmb|N)v=J?CPfsDNlEz4t#W^eE8RY~sDU%$VlLe%U)J=yAS zPBL2o{O5}&d|&k2oz8f&`726H@yi)os~mwPktR$zKQOD)Ax4<%Aog$H3y*uXQT$+& zUe~hj&YEGM@J|I}tmh(WjKq*WE2ACR?#<;-c860wr;pb3zP&yf_tA0}9VM0o;!xIEcN=DTd$9$Gj zR35iTf%%F*j1-Wo3-NFoe#Dpg_}Z8={i;fjriCb2p|+ORb5CjHi?2^Vfg;X>A|bV! zS}&~ehP>y}Dm&&PZAQMZ->%HIDYNGhp1Wbt1j!N~Y$>09bVF0QPElWy7im9+Sl0Qt z&iWXHX6hN}52c`1^|pInlyl$t(R>PWs3?(9-fs@;H#zZl<`ff2I!a?ZuQWQdr{c4L zKs1;;ntLi2R+rptpTZJ!xG`QSmpQwLe6=YfM400UHF{eh&}YMILJgY8<)Ka->cH&M zcCwu;kS$IAZ9Vc*gW^t$h=k(v9Kn@@P_SCA&H1)aguzNLA)*3frSE)CG?cJt;3c@R z1ul7MWyz)>F6YJ+tv9atF9*88gDq+s_ z`4)epSO;f}Yo06WhsfO(q5&WF*65?@3uQc&`fq{5&;Hf^LAoMO$flSk1R1wBWU~9% z!ELnlOHfXMv7W1aRxyM)UR)qwiAFpg89JL(!7{5ny76i9I;13e= zkYW*E{}YCU;zsnls|_ViAQLy|h^VBF1!Q?e?ZXwjet!h1Ioiv^Now4M9_CQi+8UJ~ z6JnYIBx$J-qgR>5%fPm&9_YPt_3qv-^RyzrT5O&8=&|fr$b0_>Fs=KfK9)ruFnwP1 zLF;Dyem(pPIMq_O*{m#wl_0WV?zHtv*11*{ye^$c666mk^BeUa7+uaYdUMw!8X8u! 
zG?~F}WDTV_xN6ms>Mc|6*2J6wohu(Ke}EM{yo;|!t8zT*}k2m#BjRQ_&mI1dUk@@=lJ+Vtixn##XAqc?%$>g zsV7Aj2H$|-8j*XyMO~^ENL86n={s-wo#gfOl6I)?4aq`Ve#2@5_%a zt1F5#QR@`@EB?6 zWNy}I8yN&PG5a=fk2m##&)%<6^m+rQ1x~A^+sXVck;_5ak}_@5{qnX$NxI|fX2|G^ zhA>m=R&G&v*h3e7vm%ERbA73wS47SoyPU4lyT#9j9bHbbpOui9J?ADT z+kg7Kq>H%TeJ6?M`5Dw>^lkUEn)@X&zZ?7~JTQr$|2Ri)vYbRF;xv!Gb}`Psy#*5x zX|g8$?itUO`r}_nQ?g_ikOd#m7`0VfTieSKyujt_rz-#8JM-+v5qkmdH?)+A=E*;o zD=x8NJpb0LzVMJ%q&QIH(zAcXRB!(~3VmUd;RYC6?EQ@!2YBUWc3N86q7ra<{usWx z+A=kSkItA;S&ECCDVa3fLQ7MeJ z(xuOTW$C_N}9^dXhzjeY?%c{SVE@e18 zaP;4-s4v{Z&(k^OJ5Bj%#;zATpWj+z#h8AIuDrWvwtFh!U!htKY{i$deBDr-!^?YQ z0ZW_RdBtrw{4ayUUpGI&RxT9?wGmYRBZG~wt$>c$A&R6UuEaxL=EkN~#0*K-#2GsA z8#>|(owpeZxf*IxgL=Y9$pKxT&3MUvb)SW>)OmV*xV9FMjk#?KP0T=%>@jfvaRs}?$4i1=5j{sYL-S7L3gg!z48YUs1ZxtZ{nkPzon+w z%0Z6V0K&LZcKI`R9M7f(otJETXR4`XlBq~KxA}N<<+tTYI}!Ml)K>If=w>btkiGUU zEIAT(-W-OMv~Rt-@HT0i0cyC5iR0jW(=Qxr-p52)D~r*WWu0d}`8 zpN(MCPOKekT@%_eIyb7p4s#Oy=%TJdYfqSM z{pwS5kSxcBGX1H?6{fL7wDsM(T@qnrX6>)6)TPAn+N;PFa>0_be&j=w!k#1116*xl zm4ix0k+)?K(htm?cL!-Klp61P#>4@~&DO@|`|j&Bc+UjxYRtVE5H?B4K)BLvG3mq! zP@(E4vHHQPnK03rQp&ZSU3g~Ng?S=N+3&(9R1byM#Mkd@nVQv0lJz9qlZPe}k_DFQ)_2q(@;B>^Y9o*f1DyG1c z*i?E?rSoS^)8(RGiHe~bB_EN@QTU?J;BF7EoOXMmbfuVPmq6(0OR-5%`YYN%kpKIL zxxSqhPt}8T`B0nKRybV-AWqu&%?-_8oOjBc<{;`N@tLGenJ*7ZrM?dWgkyK3%(wU_ zuR?Y>zM4~_X<^ag3>?l%A#zxs#n-hSE zVc%|9Mk?%ntvIZpC8@A{NC6zNf%T+^2gTzr(6#ESY%;HsQjGI;aKDu+UCE8M#!1R( z+cd6jTpUqVPn+-IkxO4zse#8fiMK>lh(};eEBym($5Kr-FIoW=2dQFM?i?N@+r3)a zs8mrV6%);Rd$^ky(`v(|xE`U_sFh~B?{M40+m#cOo+Re|Quop?dYYu&IET8}FPvYX zIElAAxJ~>k&&8w#=tc+PYu!>?5+nLw)c06-QuIV(diyfFamt+8e2E`GE~+h(-+j%H7=}E71TMt6jV~MtF!&-1j|dAQUGv+Y2pMeVP7am?2@bk8V!=W!5cq5N(&c89kD1<^geLu?>xP z_9H3{`FBfPPL$~D?rtrkFN;~G77R7FvJ@W2nwgV3(T3=<7C7SaS#IRpFzQ%Ew4)v&CTLFx$SiT&RxkIq0unmr!{P6h zDyE#Y?!I&o2l$FM)qWCp=`#9$+y{mPHg#a8?}b%T*YKRt+D5FG07G!VhVeGsLAIN{ zRMb|@{u|7$_E?Os%q&AIRV2i?;`GTXY`oWDZr`*6cjuX3OLduHosQrax~m8|61FXv z?c_|;Z3L;>O!+caj5^C(+lPdS)Zzre1q1p8Uy(LPe2a-%o<+r^-D^8Ov;0gcLyb_! 
zxYnTiDFM;kCwY@!Dta?#`O4$k$_(5w#SiRdm#$~xRkTR;ngMC^YHSXp$3rA8Aado& z3h^PT+R`L?d+wNM{FQ19VDmY;NlR2_;fp1gtFx zb}>Yl{TSG*tm0a6XXrV#LZ~YrcnBgr3k`56ciFFZFgn>kjD6+hOWS+j%Dz_&b8fQ^ zh_lmGue0aFB3gem(KTWacr9uS=)+$n^Y5|v*e11L-Gei%JNjaV-HqZKY zaDdl+`0c#Q;qM5QbPY`HKXGO;?~+d^n?;KRDnVP&x8(A!VV z8kR^;1PA)QwCv+7c4dtKIGYg~NwBj08iK}mmULCmhLBRvn!t{$G<20LcbNC}?bS4z zj-nQ4o0YJa@|g&epkIZP*UKRjB{=^ zb4mWtVQW+?HX}>glx<(o@{WxNjhj>m92sjiHj&vh$ZI$0;*>QiofGc$7fko=$jVDw z2e1Xw1z9k)FZaF^C#-1bxNCtYtnmBNlh(S}vJ5>(JXXIh&+b%rFzLlK(e<{-uE^D_ z)XVvu=C{xWGCF#E3T|^n9MnYW##jzq1LlH&g<^6ZQ$AuGV0pL9WY>Hap5heM&w2B` zY8!9GNu1bP+l>_PJ}&1m9d1X0v-!mabVAkZ!0L<}Tl0<6}uYrUu8M5OswtB5xr1V&94)doJ?7}HUj*r7)_ zi5l~|6+bLHxclyn)hL0(F zrM}cUGxwb+Oi#~tG)d`?11nfMWf~HlBS?O9T>QAEoV?r*AFz!%f_Wwk;2j}s!!2K5 zJa;0Lxp`K*1$V2?*}VLKmelVomK1J_vThOz)xEn_QEE=10Hcf24rV|8$BToF2N-4d z0_b98plDpgIperXlc2YRdcYc)AsRv&j2^fz1?77stJ!Z|K42Z5xPE^9F9CP>{+zqb z=mh8H(5adYIuNbm$j9N^FN+K%(r*(IGj0HXlD zdZq)q1zQ0UhUQf{O3TO1$Q(c)xzx68Zbe{h3mG!18V4D-w8|--e=rqK&lUCkxxqaqrm3JE ze5|t01M@+(#xcY(G=vZkNpQr9@gHl#&>(lirF{YW0MihbL5nylV80CLvZiRx`+3T2S#tos`m@;=yt0 zV>$cg3pds^m4Kk6-MHRpRRCMrOc9h@+r4-~dK*5Mdd_+Mi_XzD!j8C}phC~k@Wo`{ ziBXv2ND{+kJ6kJyfAv0)IrZmcl{cDqM)-jP1N|vhP{s%)fPuyz_ZTCsbw!WcCd1+7 zr6+vaNN0B!zdzH7Wz;;nUp`44Dk3L`xQwx4KIpxP=FTsyU2L**tW)qYDqs;z-7RAYg!pWfa@Bu0Om`zMq>zh22(!{?V|(4D=d z-2+UY+;`zj1S;|7`3FXp%da>x@bzem=i+bV_RWhi2v6u<(csp%PYM~#M&_0y9cDQA zx+0L7be-N*?ZR=^P-Y~QP=xL7551K4merNgQN*`;JLf}3p zza0_Stq+5=NG6VPk6p}OmspB^sJbqZ2t4+vQmSl{q&2UoqEJn_4kwa-DE5L=1yU;p zx1Noa=?}ThXgOA)O}z#ro(A3_+z#&0N)`S)p1vjjsHoGF@0G{!@&&`q(ts)Tb;}Uu zrC7oZ6K7!6!$45?vvN14&uoS1WOuDTc=zn|fVFqu z!-f4}#p(Zq3c4aG=LpS_c0FMHAiU(zV8W8sl~~+u_skykPx2PTUkk+J_4ih9kyiC5 zS5tmH$cO=U@%8d#PfyQ(2vZ*Y6Sdr3pqTf%fo5v(zpo%n@TdI; z6H{rW?&$wxfHJ>9%z*|8{*T9L1_Ow>06|(;)2`%yvC7WZt(;<_P|DF3#ZE^Z%R0@m z;JO{U!pyzlDY^!3e{`SJEKVn;1wPz2-Mh5TdG4$@I0M_BdexMdA|#>s{P6#`TU_aY z-@ZPa?9Wk?{x_J}5%~`Z-T$f){y)G@bB4p$6j(ynMn*qKlD&t z2}o7@I(kHK-$6L8hUCspk_1B0xY2=w1U`Ay-6iwXBESrgygmP>@KZhfoJV zZ`RkR$dMny(rpkyUlq)Xe)~0MNgNRd8GiwQot)=PdXL28X{#_fvn!xLkld@GW0tbUC*g#TgZ>d$3#jC&1 za}I=vfE-Z>i`rI&inG2TYd2t5hdQkG4PZpBfMM#uf>mT>HJeMGiW$6X61}PABkDDL*KqG z2lH#ZoPXVIjH4tpAirH_8u{&SH>|tl@*U?0k^9bf3r>ED3&nD&IT~C1ot(%PB)>2F zlSExi+!%%W>5pzxo%ck$XU=R% zOx&9@q3m7)Iu$^=6z18_NTFke*SB}Ik$`6^A6ye^*y)k0yhCbD95$+?3%lsWjqf+g z?SYG(_o;<5F&buj=ZLbC6+=|lR?KOZj_q=c3Rqi_e|FDA4_(LF6-BEM2FK=OI^B77 z7gB^&;7g#q3&Rk}4(|xQrSlxRt>|zgZHoLkC&`?YE)QZt2bWnmo0Vbt;>;qrbI=v| zkC3;8VPeOw>7If}cHY_JQL}VE&)7?WMe%a>_^B0GSE5S{(G~ACuoB+VFM&~4G$?Ct zCp~jfjO0eOif2ZuK?(kt+FH~O%icZP z?)0-!jBCYbi@L7yF~hNztbMnHXM9YJHO8y*#n(~7;+K?hD-RP_YXh-^C`B{7$-Z+V z{?IR6&B@W!DY3#4HJr_6k?}~BIflT+u?>3VVNQ)6Lfo&(;F?(fTI2k(m%2Bce8gzl z=DF_ZtE-rVeZesA(Px)zBEF^Sz(3wbT;GeqpY*Fp=?wIxv|HIUe!Mb2hLm%x+00~> zVRcHIm(-Pcy&|9UF{zJ6FN1)kEGZVTq?#drWw=#s6SI+Zajw5BMG}$$S?q%HXZRa| zBq)5PCozFem~$)SZSjRpy#e?$H)x?Ig<~A+!OD?H~WmY{QdDp9-Y>EOu92!q3#Q z`ej=p2Rs$HgIXn#tolSB5dZ;K7qP>l-n+*>E($JzFj}_NFyXNnwUfgIQtITL; z`CD{ki^v)*@;X$^ePVMqz~PpWwCK&zCWMDdfkHw<4jSWuornd!-WzBbFoRidAjKHO~+0v)}(}diu-Ar)x#u z!jB?1##fI=S(kuxr%KZ#Zn&V`N?T-#%y#{rQblh^tX7Yz?(XNv7t$U#B>a{=Uw2Bt0jcS!si%C(BlbRwI%KDR z0zs8VX(l6aB4TWMo0OsZ*|RdlLeXS&kQv@Gc&8^9pRx8q*FC^FIEt<&gJmtT;3$fLj%D&-nT6x%B=OgP=LR{lU;2p+Lq%H%!oa<|% z-eRm=RP)(7_wi_zvQ|}2$`RI4RhB3z4Y%Xq%o>-|&~nW|1P+Q~WF={~;9M0?S~Vo> z%)zmr7m5M8MqNQ03iP_!>3zdY2 zH;Br2ozD;(Hzo%nD#+i1&{w2JSxze@XYZ!-{}H`07CMdy`QC5f`-u;BsaK@~jrmY` za^`I2a?y6ORky{<>SxnB4552slHDgAsuo;yA#Xq(zeHusJfTb(=r)7}TkTa*2En)2 
zC7xuuL>S81;lATh6$Y1~`?(ySeJI9#FBfItZQ8k`@C}k%OF;yz3kWg}-Ja>budwPwBqK7fx>WE%#3GA1@!bmNx7;mxCkRG&d1YNGEChVgVKKdzN&jA^{G%zv22*{r57?(@#$1ggdW;pA7HQltGir_2~++~WFR;J4 zqHC?=2g;+qDghFS4A?SC1r}_mzq<)|c2_9Uwesz`uN{h4NA}nw2isigx=sS3+*YQ? zDE>&994-oAA7T{;9$`+0gT8mERt{C#ew(WlQ&q|{wIFT z#W&FTW`EA2cib!V?i=qNAvI*$C+m4=D?3E)ug5K4k3F;A+8zF4KE!dqp9)N%W?t%e zI=gj-l&NS{xre^4sBL;WGnXFp{L?G4Ge;%R_6FB3qOZZXk+RP{620+v=$3Lbl9Q-j z{Tbou9Q@j4n*v@#N2I$TFUR7{l#AOcvy0UNT@+&xW|OMP#`!XGem{Rs_wdOErmjDSe99N=@5UbTk- z<#hCJl`^AMfdJ_xe9vkCz*AV@;QW9*Gf1RvojWX*1YhN79~s@^#OI5ZBy#*HM4XdC zTe@d#c%yOtpD;f>QOkr85bH|Ir!%A0%4b^n)iK6_Q(RnVYmI@v!P&0NX`4t`Qor1lcIV%X|^t42*4vhl&_JzRmk|{RJA^N63pzRW{v5%Bj zwv-(#(!kxQ=bkRt^pC>U8gHA%=N`n0smPVT9SsWBD(#&?A5oEpW@dKI&d&MG-4QRU zO<6SFwDl{q=ZYOf14t?PIn-)MLMUqzOYD!jXAr1)VGU-L+iMp6<4^w^)Ar144In94 z#si{FK0I1qypK%Wa`w*xnY@Z+`Wgn-K)=DdYvNMyYKSmy}@2%F0+%ZQ8Cei+dok)QG4B$Eode+{ibl z`Uz=~J4z4@wo6l}|LpFe?wZ{EvBoP6gkFN5!-rWBn0$9VUpFF_!?`Zqo24nGPZ3jO z1&)2Z48VYiy>3;{MNiJv?_Vo2xpa0{%>F?KqCM12p;C#riTS!V| zl*u_YLm7Upis*dD(l0@scMvI@F651X)S7|tqwogSmxkYr)}K zd4G1R-^Wwdt(|9i?J?_tv&P`gwYtd&qv=QX`QqQQYYktfu_qDi9%Pm}iuu=XWmDs& zZ3**63W?eXtCF8@hClenCeQuTy#CtNN8)yljwrJyX`HS%vlHK=x}z6(cDlkn;My@7 zQ$^-BHpNsb6)*{Obwx6lp5FkgbHuAOkt#qM{F8iqBm+;)fr`0>1t@2Q9_?~WY%#l* zvA>|nU2798Bl((BeopV3tHt5m!qS?5A>xR$nnU@OX+>Sq_)@E#jnXy=6#w^T;9uu_ zeSAKgEWF~lVq;^YY?3x%a4Wevzr`?h7-*7=YIrhATwQS=iLaOm0z~Jg$*BFGHinkA zwg_J+)IidUWhQab`~D_Rk6B>Gz(U%{YF;e>gbMprO3JZ*sCD!}62HLk=xDEMe^)l4 zyj<4H#6NnVXlr+`tK`n}(;b`j6O^EaD;RrzryZ}|<@S6Z+BU5QIz_tR|65#%U*8=u z3oX}C*WkkA$|o+daTrXRo6xX~DR=k-Em9K<2A8#6iO-1i_s911_QIFTN>aCfTTbJ| zEEwN0GD>K6QCyFDvC-Wyt}a8~Qc0cFQs&N1d_(n>mp*DW@|O(0{(U+sLmMC|cQeA) z=llPIe}3)z2fbtO9<>F&sz3JM4)gwBIN{hnKUs{Ia{7`{4bX-|hWe&=aJL^k`@aD7 CQjcT+ diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index 5595d86a22..ce8f301ddc 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -41,7 +41,7 @@ "\n", "

-    "Fig. CUDA Graphs speedup.\n",
+    "Fig. 4. CUDA Graphs speedup.\n",
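The speedup in Fig. 4 comes from replaying a captured CUDA graph instead of re-launching every kernel during the generation phase. The tutorial records the graph through its `record_graph` helper (see the `te_gemma.py` diff above); a minimal stand-alone sketch of the same idea, using `torch.cuda.make_graphed_callables` with an illustrative layer and shapes rather than the tutorial's model, could look like this:

```python
import torch

# Illustrative module and shapes; the tutorial graphs the generation phase of the full model.
layer = torch.nn.Linear(4096, 4096).cuda().half()
sample_input = torch.randn(64, 4096, device="cuda", dtype=torch.half)

# Warm up and capture the callable into a CUDA graph; later calls replay the graph,
# which removes per-kernel launch overhead.
graphed_layer = torch.cuda.make_graphed_callables(layer, (sample_input,), num_warmup_iters=3)

output = graphed_layer(sample_input)
```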
\n", "\n", "\n", @@ -54,19 +54,19 @@ "\n", "
-    "Fig. The weights calibration.\n",
+    "Fig. 5. The weights calibration.\n",
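The calibration pass illustrated in Fig. 5 runs the model in high precision while the FP8 scaling statistics are being recorded. A minimal sketch, assuming a single `te.Linear` layer and illustrative shapes instead of the full Gemma model:

```python
import torch
import transformer_engine.pytorch as te

# Illustrative layer and batch; in the tutorial the whole Gemma model is calibrated.
layer = te.Linear(4096, 4096, params_dtype=torch.bfloat16, device="cuda")
inp = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)

# Forward in high precision while collecting the FP8 scaling factors.
with torch.no_grad(), te.fp8_autocast(enabled=False, calibrating=True):
    layer(inp)

# The recorded FP8 metadata travels with the checkpoint and is reused
# when the model is later run with fp8_autocast(enabled=True).
torch.save(layer.state_dict(), "calibrated_layer.pt")
```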
\n", "\n", "##### 4. FP8 Model Weights.\n", "\n", - "The typical approach is to store weights in higher precision and then cast them to fp8 before operations. This is especially useful during training, as it allows us to store some values in high precision to avoid performance drops. However, for inference, this level of precision is not necessary.\n", + "The typical approach is to store weights in higher precision and then cast them to fp8 before operations. This is critical during training, as it allows us to store some values in high precision to avoid performance drops. However, for inference, this level of precision is not necessary.\n", "\n", - "The TransformerEngine offers a feature called `fp8_model_init`, which enables the creation of models that store only the fp8 copy of the weights. This helps reduce memory consumption, which can then be utilized to increase the batch size, leading to a speedup in generation.\n", + "The TransformerEngine includes a feature called `fp8_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast from higher precision to BF16, saving time on this casting process. Additionally, it helps reduce memory consumption, which can be used to increase the batch size, resulting in even greater speedup.\n", "\n", "\n", "
-    "Fig. Saving memory with fp8_model_init().\n",
+    "Fig. 6. Saving memory with fp8_model_init().\n",
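As a rough sketch of the memory saving shown in Fig. 6 (layer sizes and dtypes are illustrative), modules constructed inside `fp8_model_init` allocate only the FP8 copy of their weights:

```python
import torch
import transformer_engine.pytorch as te

# Weights created inside fp8_model_init are stored directly in FP8,
# roughly halving weight memory compared to keeping a BF16 master copy.
with te.fp8_model_init(enabled=True):
    layer = te.Linear(4096, 4096, params_dtype=torch.bfloat16, device="cuda")

inp = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
with torch.no_grad(), te.fp8_autocast(enabled=True):
    out = layer(inp)
```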
\n", "\n", "#### Benchmarking\n", @@ -123,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "7477e469", "metadata": {}, "outputs": [ @@ -131,23 +131,22 @@ "name": "stdout", "output_type": "stream", "text": [ - "Another string ... \n", - "\n", - "I have a new 2019 15\" MBP with 16GB RAM and 1TB SSD. I have a 2015 15\" MBP with 16GB RAM and 1TB SSD. I have a 2011 15\" MBP with 16GB RAM and 1TB SSD. I have a 2011 13\" MBP with 1\n", - "====================================================================================================\n", - "I love a good DIY project. I love the challenge of creating something from scratch, and I love the sense of accomplishment that comes with finishing a project.\n", + "============================== Generation example 1 ==============================\n", + "Tell me something about GPUs:\n", "\n", - "I also love the fact that I can make something that is unique and special to me.\n", + "1. What is the difference between a GPU and a CPU?\n", + "2. What is a GPU used for?\n", + "3. What is a GPU used for in a computer?\n", + "4. What is a GPU used for in a computer game\n", + "============================== Generation example 2 ==============================\n", + "Tell me something about NVIDIA:\n", "\n", - "There is something so satisfying about taking a blank canvas and turning it into something beautiful and functional.\n", + "NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming, professional visualization, and data center markets. The company was founded in 1993 and is headquartered in Santa Clara, California.\n", "\n", - "I also love the fact that I can save money by doing things myself.\n", "\n", - "When I make something myself, I know exactly\n", - "====================================================================================================\n", - "Benchmarking for batch_size=64 and total tokens = 1024\n", - "Benchmark with context_length=128 and max_new_tokens=896 took 42079.8125 ms.\n", - "Peak GPU memory usage: 65.96 GB\n" + "============================== Benchmarking ==============================\n", + "Benchmarking for batch_size = 64 and max total tokens = 1024\n", + "Time: 82.04 s.\n" ] } ], @@ -158,7 +157,7 @@ "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", "## !!! `model_name` attr must point to the location of the model weights !!!\n", "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", - "hyperparams.model_name = \"../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", "hyperparams.mixed_precision = \"bf16\"\n", "\n", "model = init_baseline_model(hyperparams).cuda()\n", @@ -174,9 +173,9 @@ "source": [ "We put these times into the table for later comparison.\n", "\n", - "| Models | Time | Memory | \n", + "| Models | Time | Speedup | \n", "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", - "| HF (baseline) | 42,0 sec | - | " + "| HF (baseline) | 82,04 sec | 1 | " ] }, { @@ -209,10 +208,10 @@ "\n", "\n", "The class `transformer_engine.DotProductAttention` supports this format. 
One need to pass the following things as the arguments to the forward:\n", - "- `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` - which represents the offsets of the beginnings of the next sequences,\n", - "- `cu_seqlens_q`, `cu_seqlens_kv` - cumulative sum of the lengths of the sequences of query and values,\n", - "- `max_seqlen_q` - maximum sequence length in query layer,\n", - "- `max_seqlen_kv` - maximum sequence length in key-value layer.\n", + "- `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` – which represents the offsets of the beginnings of the next sequences,\n", + "- `cu_seqlens_q`, `cu_seqlens_kv` – cumulative sum of the lengths of the sequences of query and values,\n", + "- `max_seqlen_q` – maximum sequence length in query layer,\n", + "- `max_seqlen_kv` – maximum sequence length in key-value layer.\n", "\n", "
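To make this bookkeeping concrete, the sketch below builds the cumulative-length and maximum-length values from per-sequence lengths. The lengths are made up; in the tutorial these tensors live on the GPU and are passed to the forward of `DotProductAttention` together with the offsets listed above.

```python
import torch

# Made-up sequence lengths: number of new (query) tokens and of cached + new
# (key/value) tokens for a batch of 3 sequences.
seq_lens_q = torch.tensor([3, 7, 5], dtype=torch.int32)
seq_lens_kv = torch.tensor([10, 12, 9], dtype=torch.int32)

# Cumulative sums with a leading zero, as expected for the cu_seqlens_* tensors.
cu_seqlens_q = torch.zeros(len(seq_lens_q) + 1, dtype=torch.int32)
cu_seqlens_kv = torch.zeros(len(seq_lens_kv) + 1, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum(seq_lens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seq_lens_kv, dim=0)

max_seqlen_q = int(seq_lens_q.max())    # maximum sequence length in the query layer
max_seqlen_kv = int(seq_lens_kv.max())  # maximum sequence length in the key-value layer

print(cu_seqlens_q)   # tensor([ 0,  3, 10, 15], dtype=torch.int32)
print(cu_seqlens_kv)  # tensor([ 0, 10, 22, 31], dtype=torch.int32)
```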
\n", "\n", @@ -225,7 +224,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "4fc5e1cd", "metadata": {}, "outputs": [ @@ -240,27 +239,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "self.config.fp8 = False\n", - "Another string ... \n", + "============================== Generation example 1 ==============================\n", + "Tell me something about GPUs:\n", "\n", - "I have a 2007 1.9 TDI 105bhp and the engine management light came on.\n", + "1. What is the difference between a GPU and a CPU?\n", + "2. What is the difference between a GPU and a graphics card?\n", + "3. What is the difference between a graphics card and a video card?\n", + "4. What is the\n", + "============================== Generation example 2 ==============================\n", + "Tell me something about NVIDIA:\n", "\n", - "I have a code reader and it came up with the following:\n", + "NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.\n", "\n", - "16885 - P0341 - Camshaft Position Sensor (G40) - No Signal\n", + "What is the difference between a CPU and a GPU?\n", "\n", - "I have replaced the camshaft sensor and the light is still on.\n", - "\n", - "I have checked the wiring to the sensor and it is fine.\n", - "\n", - "I have checked the\n", - "====================================================================================================\n", - "I love the new Star Wars series The Mandalorian. I’ve been a fan of the franchise since I was a kid, and I’ve been a fan of The Mandalorian since it was first announced. I’ve been a fan of The Mandalorian since the first trailer was released. I’ve been a fan of The Mandalorian since the first episode of the first season was\n", - "====================================================================================================\n", - "Benchmarking for batch_size=64 and total tokens = 1024\n", - "self.config.fp8 = False\n", - "Benchmark with context_length=128 and max_new_tokens=896 took 27791.4375 ms.\n", - "Peak GPU memory usage: 65.96 GB\n" + "A CPU (Central Processing Unit) is a computer chip that is\n", + "============================== Benchmarking ==============================\n", + "Benchmarking for batch_size = 64 and max total tokens = 1024\n", + "Time: 28.19 s.\n" ] } ], @@ -272,10 +268,7 @@ "# Import necessary packages and methods\n", "from utils import *\n", "\n", - "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", - "## !!! `model_name` attr must point to the location of the model weights !!!\n", - "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", - "hyperparams.model_name = \"../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/llama/weights\"\n", "hyperparams.qkv_format = \"thd\"\n", "\n", "# Init the model and accelerator wrapper\n", @@ -294,8 +287,8 @@ "\n", "| Models | Time | Speedup | \n", "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", - "| HF (baseline) | 42,0 sec | 1 |\n", - "| THD attention with TE | 27,8 sec | 1.51 | " + "| HF (baseline) | 82.04 sec | 1 |\n", + "| THD attention with TE | 28.19 | 2.91 | " ] }, { @@ -357,27 +350,24 @@ "name": "stdout", "output_type": "stream", "text": [ - "self.config.fp8 = False\n", - "Another string ... \n", - "\n", - "I have a 2007 1.9 TDI 105bhp and the engine management light came on.\n", - "\n", - "I have a code reader and it came up with the following:\n", + "============================== Generation example 1 ==============================\n", + "Tell me something about GPUs:\n", "\n", - "16885 - P0341 - Camshaft Position Sensor (G40) - No Signal\n", + "1. What is the difference between a GPU and a CPU?\n", + "2. What is the difference between a GPU and a graphics card?\n", + "3. What is the difference between a graphics card and a video card?\n", + "4. What is the\n", + "============================== Generation example 2 ==============================\n", + "Tell me something about NVIDIA:\n", "\n", - "I have replaced the camshaft sensor and the light is still on.\n", + "NVIDIA is a global technology company that designs and develops graphics processing units (GPUs) for the gaming and professional markets.\n", "\n", - "I have checked the wiring to the sensor and it is fine.\n", + "What is the difference between a CPU and a GPU?\n", "\n", - "I have checked the\n", - "====================================================================================================\n", - "I love the new Star Wars series The Mandalorian. I’ve been a fan of the franchise since I was a kid, and I’ve been a fan of The Mandalorian since it was first announced. I’ve been a fan of The Mandalorian since the first trailer was released. I’ve been a fan of The Mandalorian since the first episode of the first season was\n", - "====================================================================================================\n", - "Benchmarking for batch_size=64 and total tokens = 1024\n", - "self.config.fp8 = False\n", - "Benchmark with context_length=128 and max_new_tokens=896 took 16560.943359375 ms.\n", - "Peak GPU memory usage: 63.81 GB\n" + "A CPU (Central Processing Unit) is a computer chip that is\n", + "============================== Benchmarking ==============================\n", + "Benchmarking for batch_size = 64 and max total tokens = 1024\n", + "Time: 16.81 s.\n" ] } ], @@ -388,11 +378,15 @@ "\n", "from utils import *\n", "\n", - "hyperparams.model_name = \"../gemma-weights\"\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/gemma/weights\"\n", "hyperparams.qkv_format = \"thd\"\n", "\n", "hyperparams.generation_cuda_graphs = True\n", "\n", + "# It is necessary to preallocate a static buffer.\n", + "# CUDA graphs require static input tensors for every kernel.\n", + "# This approach may result in a slight increase in memory consumption;\n", + "# however, the substantial speedup achieved makes it worthwhile.\n", "hyperparams.cuda_graphs_static_batch_size = 64\n", "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", "hyperparams.cuda_graphs_static_max_context_len = 128\n", @@ -407,13 +401,32 @@ "id": "53bb430f", "metadata": {}, "source": [ - "We obtained the **2.51x** speedup!\n", + "We obtained the **4.88x** speedup!\n", "\n", "| Models | Time | Speedup | \n", "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", - "| HF (baseline) | 42,0 sec | 1 |\n", - "| THD attention with TE | 27,8 sec | 1.51 | \n", - "| THD attention + Cuda Graphs with TE | 16,7 sec | 2.51 | " + "| HF (baseline) | 82.04 | 1 |\n", + "| THD attention with TE | 28.19 | 2.91 | \n", + "| THD attention + Cuda Graphs with TE | 16.81 | 4.88 | " + ] + }, + { + "cell_type": "markdown", + "id": "0a11b75c", + "metadata": {}, + "source": [ + "Let's look at the screenshots from *NVIDIA Nsight System* profiler to see where this speedup comes from:\n", + "

\n", + "\n", + "
\n", + " \n", + "\"\"
\n", + " Fig. 7. Without CUDA Graphs. We can see that GPU(blue) is idle for most of the time.\n", + "


\n", + "\"\"
\n", + " Fig. 8. With CUDA Graphs. We can see that GPU(orange) is utilized.\n", + "
\n", + "
" ] }, { @@ -442,7 +455,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "aecee0e1", "metadata": {}, "outputs": [], @@ -450,8 +463,8 @@ "from utils import *\n", "import transformer_engine.pytorch as te\n", "\n", - "hyperparams.model_name = \"../gemma-weights\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", - "hyperparams.fuse_qkv_params = True\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "hyperparams.fuse_qkv_params = True # This is needed by the last improvement.\n", "\n", "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", @@ -469,7 +482,7 @@ "# Some parameters are in pointing to the same tensors, we do not want to double save them.\n", "dict_to_save = {k: v for k, v in model.state_dict().items() \\\n", " if (\"_context_phase\" not in k and \"_generation_phase\" not in k)}\n", - "torch.save(dict_to_save, '/root/model_calibrated_weights.pth') " + "torch.save(dict_to_save, '') " ] }, { @@ -492,22 +505,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "self.config.fp8 = True\n", - "Another string ... \n", - "====================================================================================================\n", - "I love a good list.\n", - "\n", - "I love a good list of things to do, a good list of things to buy, a good list of things to read, a good list of things to watch.\n", + "============================== Generation example 1 ==============================\n", + "Tell me something about GPUs:\n", "\n", - "I love a good list of things to do in a city.\n", + "* What is a GPU?\n", + "* What is a GPU used for?\n", + "* What is a GPU used for in machine learning?\n", + "* What is a GPU used for in deep learning?\n", + "* What is a GPU used for in computer vision\n", + "============================== Generation example 2 ==============================\n", + "Tell me something about NVIDIA:\n", "\n", - "I love a good list of things to do in a city that I’ve never been to before.\n", - "\n", - "I love a good list of things to do in a city that I’ve never been to before that I\n", - "====================================================================================================\n", - "Benchmarking for batch_size=64 and total tokens = 1024\n", - "self.config.fp8 = True\n", - "Benchmark with context_length=128 and max_new_tokens=896 took 19161.548828125 ms.\n", + "NVIDIA Corporation is an American multinational technology company headquartered in Santa Clara, California, that designs graphics processing units (GPUs) for the gaming and professional markets, as well as system on a chip units (SoCs) for the mobile computing and automotive market\n", + "============================== Benchmarking ==============================\n", + "Benchmarking for batch_size = 64 and max total tokens = 1024\n", + "Time: 19.32 s.\n", "Peak GPU memory usage: 63.82 GB\n" ] } @@ -519,12 +531,13 @@ "\n", "from utils import *\n", "\n", - "hyperparams.model_name = \"../gemma-weights\"\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/gemma/weights\"\n", "hyperparams.qkv_format = \"thd\"\n", + "hyperparams.fuse_qkv_params = True # This is needed by the last improvement.\n", "\n", - "hyperparams.fp8 = True\n", + "hyperparams.fp8 = True \n", "# We load calibrated fp8 weights directly from the file.\n", - "hyperparams.fp8_model_weights_filename = \"/root/model_calibrated_weights.pth\"\n", + "hyperparams.fp8_model_weights_filename = \"\"\n", "\n", "hyperparams.generation_cuda_graphs = True\n", "hyperparams.cuda_graphs_static_batch_size = 64\n", @@ -541,7 +554,11 @@ "id": "8cdbb56c", "metadata": {}, "source": [ - "We see that speedup is smaller than without fp8. It is because ... " + "We can observe that the outputs are coherent; however, the generation time has increased. Why is this the case? \n", + "\n", + "Running the model in FP8 does not imply that all weights are stored in FP8. By default, they are stored in higher precision and are cast to FP8, using saved scaling factors, before operations such as GEMMs.\n", + "\n", + "This approach is beneficial during training: we can perform one cast for both backward and forward passes, leading to speedups. However, performing a single cast for each forward pass introduces too much overhead to achieve a speedup. We will address this issue in the next section of the tutorial.\n" ] }, { @@ -557,8 +574,7 @@ "id": "2dd0cba9", "metadata": {}, "source": [ - "\n", - "As we have seen above, generation in FP8 precision results results in considerable speedup. Neverthless, memory usage is no different than without FP8. The reason of that is that TransformerEngine stores parameters in higher precision and only casts them to FP8. It is also true with the optimizer state. It is needed to maintain accucacy during training. However, we can get rid of high precision weights when doing inference. \n", + "TransformerEngine stores parameters in higher precision and only casts them to FP8. It is also true with the optimizer state. It is needed to maintain accucacy during training. However, we can get rid of high precision weights when doing inference. \n", "\n", "Transformer Engine supports maintaining only FP8 copy of weights with `fp8_model_init` decorator. Let's see an example\n", "```\n", @@ -579,22 +595,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "self.config.fp8 = True\n", - "Another string ... 
\n", - "====================================================================================================\n", - "I love a good list.\n", + "============================== Generation example 1 ==============================\n", + "Tell me something about GPUs:\n", "\n", - "I love a good list of things to do, a good list of things to buy, a good list of things to read, a good list of things to watch.\n", + "* What is a GPU?\n", + "* What is a GPU used for?\n", + "* What is a GPU used for in machine learning?\n", + "* What is a GPU used for in deep learning?\n", + "* What is a GPU used for in computer vision\n", + "============================== Generation example 2 ==============================\n", + "Tell me something about NVIDIA:\n", "\n", - "I love a good list of things to do in a city.\n", - "\n", - "I love a good list of things to do in a city that I’ve never been to before.\n", - "\n", - "I love a good list of things to do in a city that I’ve never been to before that I\n", - "====================================================================================================\n", - "Benchmarking for batch_size=64 and total tokens = 1024\n", - "self.config.fp8 = True\n", - "Benchmark with context_length=128 and max_new_tokens=896 took 11993.3818359375 ms.\n", + "NVIDIA Corporation is an American multinational technology company headquartered in Santa Clara, California, that designs graphics processing units (GPUs) for the gaming and professional markets, as well as system on a chip units (SoCs) for the mobile computing and automotive market\n", + "============================== Benchmarking ==============================\n", + "Benchmarking for batch_size = 64 and max total tokens = 1024\n", + "Time: 12.18 s.\n", "Peak GPU memory usage: 56.60 GB\n" ] } @@ -607,7 +622,7 @@ "# Import necessary packages and methods\n", "from utils import *\n", "\n", - "hyperparams.model_name = \"../gemma-weights\"\n", + "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", "hyperparams.fuse_qkv_params = True # Needed for fp8_model_init().\n", "hyperparams.qkv_format = \"thd\"\n", "\n", @@ -622,7 +637,7 @@ "model = init_te_gemma_model(hyperparams).cuda()\n", "\n", "print_sample_of_generated_texts(model)\n", - "benchmark_generation(model)" + "benchmark_generation(model, measure_memory=True)" ] }, { @@ -630,16 +645,16 @@ "id": "3e30ca5a", "metadata": {}, "source": [ - "We finally obtained the **??%** speedup.\n", + "We finally obtained the **6.74x** speedup.\n", "\n", "| Models | Time | Speedup | \n", "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", - "| HF (baseline) | 42,0 sec | 1 |\n", - "| THD attention with TE | 27,8 sec | 1.51 | \n", - "| THD attention + Cuda Graphs with TE | 16,7 sec | 2.51 |\n", - "| THD attention + FP8 with TE + fp8_model_init() | 12,0 sec | 3.50 | \n", + "| HF (baseline) | 82.04 | 1 |\n", + "| THD attention with TE | 28.19 | 2.91 | \n", + "| THD attention + Cuda Graphs with TE | 16.81 | 4.88 | \n", + "| THD attention + FP8 with TE + fp8_model_init() | 12.18 | 6.74 | \n", "\n", - "Total memory usage dropped by the **a%**! We can use it to increase batch size to obtain even larger speedup." + "Moreover the memory usage dropped from *63.82 GB* to the *56.60 GB*. We can potentially use that to increase batch size to obtain even larger speedup." 
] }, { From 27deface079ddb5d79c33ed4ce22c5220c6bdb16 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 16:45:54 -0700 Subject: [PATCH 143/244] fix Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 1 - transformer_engine/pytorch/attention.py | 172 +++++++++++++++--------- 2 files changed, 112 insertions(+), 61 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 937cd98780..6264a448fb 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -296,7 +296,6 @@ class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM): """ def __init__(self, config : GemmaConfig): super().__init__(config) - # Przekonwertuj siebie na bf16 chatgpt... # Preparation of the static buffers. self.config = config self.hidden_states_buffer = torch.empty( diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 1a380b88b2..11e6c91c29 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -121,7 +121,7 @@ def __init__(self, max_batch_size, max_sequence_length, qkv_format="bshd"): self.max_batch_size = max_batch_size self.key_value_memory_dict = {} self.qkv_format = qkv_format - + if qkv_format == "thd": self.seq_len = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) self.incoming_seq_len = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) @@ -153,8 +153,8 @@ def swap_key_value_dict(self, batch_indices): new_inference_key_memory, new_inference_value_memory, ) - - + + def thd_setup_before_new_input(self, new_input, reset=False, pad_token_id=None): """ After every context/generation phase, the parameters representing @@ -167,11 +167,11 @@ def thd_setup_before_new_input(self, new_input, reset=False, pad_token_id=None): new_input: torch.Tensor Tensor with token_ids (not embeddings!) on which we want to do next forward pass. reset: int - If reset=True, all previous sequence lengths will be set to 0. - It is supposed to be used after last generation phase to + If reset=True, all previous sequence lengths will be set to 0. + It is supposed to be used after last generation phase to allow inference_params to be reused. pad_token_id: int - Value of padding token - used to compute sequence_lengths. If pad_token_id=None, + Value of padding token - used to compute sequence_lengths. If pad_token_id=None, we assume that all new_input sequence lengths are equal to the corresponding dimension of new_input. """ @@ -179,14 +179,18 @@ def thd_setup_before_new_input(self, new_input, reset=False, pad_token_id=None): self.seq_len.copy_(self.seq_len + self.incoming_seq_len) if pad_token_id is not None: - self.incoming_seq_len.copy_(torch.sum(new_input.ne(pad_token_id), dim=-1, dtype=torch.int32).squeeze()) + self.incoming_seq_len.copy_( + torch.sum(new_input.ne(pad_token_id), dim=-1, dtype=torch.int32).squeeze() + ) else: - self.incoming_seq_len.copy_(torch.ones(new_input.shape[0], device="cuda") * new_input.shape[1]) + self.incoming_seq_len.copy_( + torch.ones(new_input.shape[0], device="cuda") * new_input.shape[1] + ) self.max_incoming_seq_len = new_input.shape[1] if reset: self.seq_len.copy_(torch.zeros_like(self.seq_len)) - + def save_new_key_and_value_layer(self, layer_number, key_layer, value_layer): """ Saves key_layer and value_layer in the cache. 
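As a quick illustration of the sequence-length bookkeeping performed by `thd_setup_before_new_input` above, the snippet below mirrors the non-pad-token count with made-up token ids and a made-up pad value.

```python
import torch

pad_token_id = 0                                   # illustrative pad value
new_input = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 0, 0, 0]])        # right-padded token ids
incoming_seq_len = torch.sum(new_input.ne(pad_token_id), dim=-1, dtype=torch.int32)
print(incoming_seq_len)                            # tensor([3, 2], dtype=torch.int32)
```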
@@ -197,26 +201,27 @@ def save_new_key_and_value_layer(self, layer_number, key_layer, value_layer): batch_size = key_layer.shape[0] channels = inference_key_memory.shape[2] * inference_key_memory.shape[3] # h * d tex.attention_copy( - inference_key_memory, - self.seq_len, + inference_key_memory, + self.seq_len, self.incoming_seq_len, - key_layer, + key_layer, self.max_incoming_seq_len, - self.max_sequence_length, + self.max_sequence_length, batch_size, channels) - + tex.attention_copy( - inference_value_memory, - self.seq_len, + inference_value_memory, + self.seq_len, self.incoming_seq_len, - value_layer, + value_layer, self.max_incoming_seq_len, - self.max_sequence_length, + self.max_sequence_length, batch_size, channels) else: - assert self.qkv_format in ["bshd", "sbhd"], "Attention format not supported by the inference." + assert self.qkv_format in ["bshd", "sbhd"], \ + "Attention format not supported by the inference." batch_start = self.batch_size_offset batch_end = batch_start + key_layer.size(1) assert batch_end <= inference_key_memory.size(1) @@ -232,8 +237,8 @@ def save_new_key_and_value_layer(self, layer_number, key_layer, value_layer): sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] - return key_layer, value_layer - + return key_layer, value_layer + @torch.no_grad() @@ -1522,11 +1527,12 @@ def apply_rotary_pos_emb( Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and dtype torch.int32. Only valid when `tensor_format` is 'thd'. begins: torch.Tensor, default = None. - We may not want begin all the sequences from the 0 embedding. This tensor argument allows that. + We may not want begin all the sequences from the 0 embedding. + This tensor argument allows that. 
""" assert not (begins is not None and not fused), \ """begins != None and fused=False is not supported""" - + if fused: assert ( tensor_format != "thd" or cu_seqlens is not None @@ -2441,11 +2447,13 @@ def backward(ctx, d_out): # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: - return (None, None, None, None, None, None,None, None, None, None, dqkv, None, None, None, + return (None, None, None, None, None, None,None, None, + None, None, dqkv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None) # else, return (dqkv, dbias) - return (None, None, None, None, None, None, None, None, None, None, dqkv, None, rest[0], None, + return (None, None, None, None, None, None, None, None, + None, None, dqkv, None, rest[0], None, None, None, None, None, None, None, None, None, None, None, None, None) @@ -2666,11 +2674,15 @@ def backward(ctx, d_out): # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: - return (None, None, None, None, None, None, None, None, None, None, None, None, dq, dkv, None, None, None, + return (None, None, None, None, None, None, + None, None, None, None, None, None, + dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None) # else, return (dqkv, dbias) - return (None, None, None, None, None, None, None, None, None, None, None, None, dq, dkv, None, rest[0], None, + return (None, None, None, None, None, None, + None, None, None, None, None, None, + dq, dkv, None, rest[0], None, None, None, None, None, None, None, None, None, None, None, None, None) @@ -2683,7 +2695,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, use_FAv2_bwd, fp8, fp8_meta): - + if fp8: if _NVTE_DEBUG: print('[DotProductAttention]: using FP8 forward') @@ -3463,11 +3475,12 @@ def __init__( self.unfused_attention = UnfusedDotProductAttention( norm_factor, **attn_kwargs, layer_number=layer_number) - + self._allocator = StaticBufferAllocator() def alloc(self, size, dtype, device): + """ Allocation of buffer, compatible with CUDA Graphs.""" return self._allocator(size, dtype, device) @@ -3711,7 +3724,7 @@ def forward( if qkv_format is None: qkv_format = self.qkv_format - + if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" @@ -3723,15 +3736,19 @@ def forward( ) = inference_params.key_value_memory_dict[self.layer_number] if qkv_format in ["bshd", "sbhd"]: - key_layer, value_layer = inference_params.save_new_key_and_value_layer(self.layer_number, key_layer, value_layer) + key_layer, value_layer = inference_params.save_new_key_and_value_layer( + self.layer_number, key_layer, value_layer + ) elif qkv_format == "thd": - inference_params.save_new_key_and_value_layer(self.layer_number, key_layer, value_layer) + inference_params.save_new_key_and_value_layer( + self.layer_number, key_layer, value_layer + ) """ We compute parameters needed by the THD attention with offsets. 
""" - batch_size = query_layer.shape[0] + batch_size = query_layer.shape[0] max_seqlen_q = inference_params.max_incoming_seq_len max_seqlen_kv = inference_params.max_sequence_length cu_seqlens_q = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") @@ -3742,17 +3759,35 @@ def forward( seq_offsets_o = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:].copy_(torch.cumsum(inference_params.incoming_seq_len, dim=0)) - cu_seqlens_kv[1:].copy_(torch.cumsum(inference_params.seq_len + inference_params.incoming_seq_len, dim=0)) + cu_seqlens_kv[1:].copy_( + torch.cumsum(inference_params.seq_len + inference_params.incoming_seq_len, + dim=0) + ) - seq_offsets_q.copy_(torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_q) + seq_offsets_q.copy_( + torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") + * self.channels * max_seqlen_q + ) seq_offsets_o.copy_(seq_offsets_q) - seq_offsets_k.copy_(torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv) + seq_offsets_k.copy_( + torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") + * self.channels * max_seqlen_kv + ) seq_offsets_v.copy_(seq_offsets_k) # qkv layers are reshaped to the format [t, h, d] - query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16) - key_layer = inference_key_memory.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]).to(torch.bfloat16) - value_layer = inference_value_memory.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]).to(torch.bfloat16) + query_layer = query_layer.view( + -1, + query_layer.shape[2], + query_layer.shape[3]).to(torch.bfloat16) + key_layer = inference_key_memory.view( + -1, + inference_key_memory.shape[2], + inference_key_memory.shape[3]).to(torch.bfloat16) + value_layer = inference_value_memory.view( + -1, + inference_value_memory.shape[2], + inference_value_memory.shape[3]).to(torch.bfloat16) if qkv_format == "bshd": @@ -3760,7 +3795,7 @@ def forward( value_layer = value_layer.transpose(0, 1) key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() - + assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition and value_layer.shape[-2] == self.num_gqa_groups_per_partition ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" @@ -3877,7 +3912,7 @@ def forward( use_fused_attention = False if (not _flash_attn_2_3_plus) or context_parallel: use_flash_attention = False - + # Filter: Attention mask type. 
@@ -3998,7 +4033,7 @@ def forward( and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]): if self.device_compute_capability == (9, 0): use_flash_attention = False - + if self.qkv_format == "thd": use_flash_attention = False use_fused_attention = True @@ -4079,7 +4114,7 @@ def forward( if q_size > 1: out = out.view((batch_size, -1, out.shape[2])).contiguous() - + return out assert (not context_parallel), \ @@ -4313,7 +4348,8 @@ def __init__( self.num_attention_heads = num_attention_heads self.return_bias = return_bias - self.attention_hidden_size = attention_hidden_size if attention_hidden_size else (hidden_size // num_attention_heads) + self.attention_hidden_size = attention_hidden_size if attention_hidden_size \ + else (hidden_size // num_attention_heads) if init_method is None: init_method = get_default_init_method() @@ -4483,6 +4519,9 @@ def _allocate_memory( ) def alloc(self, size, dtype, device): + """ + Allocation of the buffer compatible with CUDA Graphs. + """ return self._allocator(size, dtype, device) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: @@ -4672,7 +4711,7 @@ def forward( ) num_queries_per_key_value = (self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition) - + if self.qkv_weight_interleaved: # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn] new_tensor_shape = mixed_x_layer.size()[:-1] + ( @@ -4793,21 +4832,29 @@ def forward( rotary_pos_emb = ((rotary_pos_emb,) * 2) q_pos_emb, k_pos_emb = rotary_pos_emb - + if self.qkv_format == "thd" and inference_params is not None: # For thd attention incoming tokens can be on different positions, # so we need to copy different positional encoding freqency # for every sequence in a batch. # # For example if sequence lengths in context phase are: 2 and 5 (batch size=2), - # in first generation phase key_layer have shape [2, 1, d]. + # in first generation phase key_layer have shape [2, 1, d]. # key_layer[0, :] corresponds to the token with position 3 = 2 + 1, # and key_layer [1, :] corresponds to the token with position 6 = 5 + 1. key_layer = key_layer.contiguous() query_layer = query_layer.contiguous() - key_layer.copy_(apply_rotary_pos_emb(key_layer, k_pos_emb, "bshd", fused=True, begins=inference_params.seq_len)) - query_layer.copy_(apply_rotary_pos_emb(query_layer, q_pos_emb, "bshd", fused=True, begins=inference_params.seq_len)) + key_layer.copy_( + apply_rotary_pos_emb( + key_layer, k_pos_emb, "bshd", fused=True, begins=inference_params.seq_len + ) + ) + query_layer.copy_( + apply_rotary_pos_emb( + query_layer, q_pos_emb, "bshd", fused=True, begins=inference_params.seq_len + ) + ) else: # adjust key and value for inference if inference_params is not None: @@ -4818,12 +4865,16 @@ def forward( sequence_start = inference_params.sequence_len_offset sequence_end = sequence_start + sequence_length - + q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] 
- query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True) - key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True) + query_layer = apply_rotary_pos_emb( + query_layer, q_pos_emb, self.qkv_format, fused=True + ) + key_layer = apply_rotary_pos_emb( + key_layer, k_pos_emb, self.qkv_format, fused=True + ) query_layer = query_layer.contiguous() key_layer = key_layer.contiguous() @@ -4874,15 +4925,16 @@ def forward( class StaticBufferAllocator(torch.nn.Module): """ - This class is used when we use te.make_graphed_callable(). - CUDA Graphs require all tensors to be static. Neverthless, + This class is used when we use te.make_graphed_callable(). + CUDA Graphs require all tensors to be static. Neverthless, torch API make_graphed_callable() takes care of output of torch modules, and makes them static. Thus by wrapping allocation of memory into torch.nn.Module, we can greatly simplify our code. """ - def __init__(self): - super().__init__() - - def forward(self, size, dtype, device): - a = torch.zeros(size, dtype=dtype, device=device) - return a \ No newline at end of file + + @staticmethod + def forward(size, dtype, device): + """ + Allocate the buffers. + """ + return torch.zeros(size, dtype=dtype, device=device) From 600ff90f6143535ca6a6d5f9033b492d229d79c7 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 17:12:01 -0700 Subject: [PATCH 144/244] Images Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/media/graphs_2.png | Bin 40595 -> 15177 bytes docs/examples/te_gemma/media/speedups.png | Bin 0 -> 40595 bytes 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 docs/examples/te_gemma/media/speedups.png diff --git a/docs/examples/te_gemma/media/graphs_2.png b/docs/examples/te_gemma/media/graphs_2.png index 6f6e0b16732a4e07afbe8eb553f0a24faeb07f14..35c34ede5559bd0c26ce807789ee6d3fdb2bb062 100644 GIT binary patch literal 15177 zcmdVB2UL?=*YE4rt?aEJd#gwjuz*MtDG})^A|kz)P(%nV6d?qpTTw(nT7VEhDWQY} zArLy*%BDjQ0s#^NBApl_B}78I5AXYpao_j5_uO&6Gsbt$IE;~zXC+zBTx+hi=9=sG z{E0U=HQ+lga{SPtLwtt!?pPi=bQpT*&@cMG9RohWA1>Dg-hPGMHnjc?_>20@<26tg z4$-#{u?qAI33CgE9P$bX^oJ)Sl{Mul?bbIgK>0f0s`S0BNv)xGlW5I7{Z}uFyQ*ij1Mfn+} zUwF@7`2BaOUx)C9q4@D!CM0wowV#OBq{cg_dZ%D`=POaz^}|n}ME}720`%@r zUTlrO;qh}8Ig3>fgw`gzGwz7dpZ%wwYY@HqO{S!ueS*z!KJO6U{`=SSs$b$s+K+gE4hs$oIQ)tmonrC;3u31YR{%}wp1Jh-RS#d%X>Ik4j^k)XBM zNBRfq!!^XaSdb3;Pk%*!8ru zqORtlGp4m#QhQL(P+so)PL2mVP3Hs1_*=xiA*2|S>KeE-+#nM#-t$E=e9sVXCh0>m z)Fh^!;U5$xgZ=DLY?+$ zj0&3n*x3|nV9seBD&4-?&04R+v$Wn*Lcuy_qt$bEXRc@}kaT(7LMB`*yFrei<}jN= z$)$3s!F4RODx_Txq>W(S_+!Z3Rc3hhbHD)qH)}mZWGw6@Wzw4JwN*NN?XEWUHX&x+ z^0~{@&}AB8-t-4Uv&2L0Qv2)b)KBrX(YePOU2Q0@;GCYb^xgKTZq=FqFY~bGvG?4w zMrAsBt`5jDDHW?%+LQL%Ht@pYfyq`^cUaSL<*F~3DVKL@KQcPIT))**%7`=6C+wc= zEjq+{UTvNG{;s=ASZ3G31jy_9Y8yRPK*T~vb~7`sD7Fpvm?&^N+~M49P_c3&GmL;0 z&*bSyHxBH>!3RU~&zi7gsi&bLHsXts6xLQ^KT_$ke%c$d2W?1A0yC&?09jUBR~bH6 z!ycKl5y_dtSgP#7?CDqP#b5lQrU#!U8=N8 zj(3{09l=TLX}|Ao_tC}))`@65PjcmRpM@f`OB+2SLrFKuY&D0OBEUz!ceQr(st!oX~d;IhrA-x80PfcNwT9}6Du@u9*% z5QCIpgyLJ8Z%D*8h6E*fa5sYV%@c1eeJ>1_GF&PfN-Pm9Lw9GZ+UeMT0PJA5Y3#)j zclH-3Ji;G_|BMKs{+WGHk+R}FqL`1hFibKSNpv$&SYnRQ1BNTnp6*HC?Jjl~EdI zMqb63>Z?ULw&%=k^z7A_&;c44^V27I)CUW+P795QAjP9=6iemA%PF`cOpX09zv!i( zW5gG$3u=szy@Cmy(@HL&jR@ZEYa!bSd=a}dXSZ77R_fo|OQ}!Z8u?K)F06-pWtDbT zJCto7-b`|8jq-VF>+g#&&R=nMz|_^%l>^2_%@ob`FvUEsWoZx^(rI=yIa8o#L208* zf;pV(utc>JP4yAsc2{2%@QUP)x3)fH+PN7;Kwvqo(T*t|A`z>*XBG0YKV;9Kjx4^o zogedK>*7%b>B_>`dc6d6v5cRgu1?84-RMR~8s=GQQDy99*G>~7a*G_aHpni=Z@j#m 
zpYi>u1JJqeS=$-LRy>gp#3S~qIPr5B#;a_iW-|17|J1p@=f;(&?5RR6U^e}pJsQy6 zV27U%N#9=njI>V|X^=PclE|e6wY)<kEzMl)E(^y6t3+{oN2Iu;ktPv~x@^ZA3f!duJ2lYK zyVNro^X%xrPur2tkUG~*luejeylT~#Dfeb_M<~s_?Mrj_5|=5}Kt1hk_hz8bpi;(a ztQtPgYzpakcCP!{h)6X%VZF191qz|(%We#WfmN>$V2&iZ?aqHwlI>etESO>qcC3Yp zDJj_~^E&T)XGDp6T}B1pL0@%EE!o$^*b^V)9KHH_Ee&yI(qag#^kjLal%82si!mtC z&3R)xSOJ+=mal#Ki`U74p#?z?Muv$=BfH_r%HB}cWDbqkHM>mR8^nGnf49x*4xd#kGakAhC@|k<&8wEHz5o%&&Ecar*o@h zpgHKAskR7M;b@zsVO;|%cRHwp9+aa>EF-$wT_#@(p}}}U_626rGklk6$+o_8udo_h zAB}5kXpL*+qq1^8mZtZImvNa;__}&~Rh{A+`|Xj<^7n75Th_L7+dfZf7(rgjXcK1Y zg`Ky3$5O9qUT}r=TC%yF{f@!j=m2dkE7~Rx8*3n&%Md{at2A@F8`^!nYd2m_F)6Ef z)ZE}E=J%sIARhnRU0O@ZcEfbdZ7~%P_sHKj*E3HN%y?B|1tmt0rS1QDnrHRYgF0X; z1!Yb3G3p1PYIEnEls7CSk{#AA82F?1m*S z+c}g6@M3Cc_}uElB`UevI{FBy5*VV!8)=o*t*>LF7WdF|vt$I`I6vE3sSPF!Qv+n= zM&3FRG>_5DNFkV+%)2sW&2l;cp8ov(;*Aak3hA`Kn*CKd91jllwkX;bL zPDW?p);1ZhvJnuD>6nN3-mlVKlMlGvYBrKM-DMxp^!|=Y#|8>zU}8w#7&^1N=lY9w zX#Zowq2**`Ej-qo5t?p*vvZIbLi#H-vYS)yAz+tz_jgYXEsj3=l&-p^6?a@&`~Ir1 zulMhA(k%(pE2KE$Of`2;d4%8Y`=}vI8Q!U1-1z;)M3$&X zbz|t;r;hv!utkm37P+=!(&&9`jU-_1`i=Pd86jlR?r8j1)}gKVA)iCf2~6bSnQ0kEUH4 zo_U~rZWZTN(+*h6bD4S%OInGvm34sLupak8TSlXMpeUYf9-I(nTTuejIN!>Xig;hl z0-}`~%C?lIC_nN|cu-ec0E@_d=v zzED=vnr^eBTrVPKz1?zmb|N)v=J?CPfsDNlEz4t#W^eE8RY~sDU%$VlLe%U)J=yAS zPBL2o{O5}&d|&k2oz8f&`726H@yi)os~mwPktR$zKQOD)Ax4<%Aog$H3y*uXQT$+& zUe~hj&YEGM@J|I}tmh(WjKq*WE2ACR?#<;-c860wr;pb3zP&yf_tA0}9VM0o;!xIEcN=DTd$9$Gj zR35iTf%%F*j1-Wo3-NFoe#Dpg_}Z8={i;fjriCb2p|+ORb5CjHi?2^Vfg;X>A|bV! zS}&~ehP>y}Dm&&PZAQMZ->%HIDYNGhp1Wbt1j!N~Y$>09bVF0QPElWy7im9+Sl0Qt z&iWXHX6hN}52c`1^|pInlyl$t(R>PWs3?(9-fs@;H#zZl<`ff2I!a?ZuQWQdr{c4L zKs1;;ntLi2R+rptpTZJ!xG`QSmpQwLe6=YfM400UHF{eh&}YMILJgY8<)Ka->cH&M zcCwu;kS$IAZ9Vc*gW^t$h=k(v9Kn@@P_SCA&H1)aguzNLA)*3frSE)CG?cJt;3c@R z1ul7MWyz)>F6YJ+tv9atF9*88gDq+s_ z`4)epSO;f}Yo06WhsfO(q5&WF*65?@3uQc&`fq{5&;Hf^LAoMO$flSk1R1wBWU~9% z!ELnlOHfXMv7W1aRxyM)UR)qwiAFpg89JL(!7{5ny76i9I;13e= zkYW*E{}YCU;zsnls|_ViAQLy|h^VBF1!Q?e?ZXwjet!h1Ioiv^Now4M9_CQi+8UJ~ z6JnYIBx$J-qgR>5%fPm&9_YPt_3qv-^RyzrT5O&8=&|fr$b0_>Fs=KfK9)ruFnwP1 zLF;Dyem(pPIMq_O*{m#wl_0WV?zHtv*11*{ye^$c666mk^BeUa7+uaYdUMw!8X8u! zG?~F}WDTV_xN6ms>Mc|6*2J6wohu(Ke}EM{yo;|!t8zT*}k2m#BjRQ_&mI1dUk@@=lJ+Vtixn##XAqc?%$>g zsV7Aj2H$|-8j*XyMO~^ENL86n={s-wo#gfOl6I)?4aq`Ve#2@5_%a zt1F5#QR@`@EB?6 zWNy}I8yN&PG5a=fk2m##&)%<6^m+rQ1x~A^+sXVck;_5ak}_@5{qnX$NxI|fX2|G^ zhA>m=R&G&v*h3e7vm%ERbA73wS47SoyPU4lyT#9j9bHbbpOui9J?ADT z+kg7Kq>H%TeJ6?M`5Dw>^lkUEn)@X&zZ?7~JTQr$|2Ri)vYbRF;xv!Gb}`Psy#*5x zX|g8$?itUO`r}_nQ?g_ikOd#m7`0VfTieSKyujt_rz-#8JM-+v5qkmdH?)+A=E*;o zD=x8NJpb0LzVMJ%q&QIH(zAcXRB!(~3VmUd;RYC6?EQ@!2YBUWc3N86q7ra<{usWx z+A=kSkItA;S&ECCDVa3fLQ7MeJ z(xuOTW$C_N}9^dXhzjeY?%c{SVE@e18 zaP;4-s4v{Z&(k^OJ5Bj%#;zATpWj+z#h8AIuDrWvwtFh!U!htKY{i$deBDr-!^?YQ z0ZW_RdBtrw{4ayUUpGI&RxT9?wGmYRBZG~wt$>c$A&R6UuEaxL=EkN~#0*K-#2GsA z8#>|(owpeZxf*IxgL=Y9$pKxT&3MUvb)SW>)OmV*xV9FMjk#?KP0T=%>@jfvaRs}?$4i1=5j{sYL-S7L3gg!z48YUs1ZxtZ{nkPzon+w z%0Z6V0K&LZcKI`R9M7f(otJETXR4`XlBq~KxA}N<<+tTYI}!Ml)K>If=w>btkiGUU zEIAT(-W-OMv~Rt-@HT0i0cyC5iR0jW(=Qxr-p52)D~r*WWu0d}`8 zpN(MCPOKekT@%_eIyb7p4s#Oy=%TJdYfqSM z{pwS5kSxcBGX1H?6{fL7wDsM(T@qnrX6>)6)TPAn+N;PFa>0_be&j=w!k#1116*xl zm4ix0k+)?K(htm?cL!-Klp61P#>4@~&DO@|`|j&Bc+UjxYRtVE5H?B4K)BLvG3mq! 
zP@(E4vHHQPnK03rQp&ZSU3g~Ng?S=N+3&(9R1byM#Mkd@nVQv0lJz9qlZPe}k_DFQ)_2q(@;B>^Y9o*f1DyG1c z*i?E?rSoS^)8(RGiHe~bB_EN@QTU?J;BF7EoOXMmbfuVPmq6(0OR-5%`YYN%kpKIL zxxSqhPt}8T`B0nKRybV-AWqu&%?-_8oOjBc<{;`N@tLGenJ*7ZrM?dWgkyK3%(wU_ zuR?Y>zM4~_X<^ag3>?l%A#zxs#n-hSE zVc%|9Mk?%ntvIZpC8@A{NC6zNf%T+^2gTzr(6#ESY%;HsQjGI;aKDu+UCE8M#!1R( z+cd6jTpUqVPn+-IkxO4zse#8fiMK>lh(};eEBym($5Kr-FIoW=2dQFM?i?N@+r3)a zs8mrV6%);Rd$^ky(`v(|xE`U_sFh~B?{M40+m#cOo+Re|Quop?dYYu&IET8}FPvYX zIElAAxJ~>k&&8w#=tc+PYu!>?5+nLw)c06-QuIV(diyfFamt+8e2E`GE~+h(-+j%H7=}E71TMt6jV~MtF!&-1j|dAQUGv+Y2pMeVP7am?2@bk8V!=W!5cq5N(&c89kD1<^geLu?>xP z_9H3{`FBfPPL$~D?rtrkFN;~G77R7FvJ@W2nwgV3(T3=<7C7SaS#IRpFzQ%Ew4)v&CTLFx$SiT&RxkIq0unmr!{P6h zDyE#Y?!I&o2l$FM)qWCp=`#9$+y{mPHg#a8?}b%T*YKRt+D5FG07G!VhVeGsLAIN{ zRMb|@{u|7$_E?Os%q&AIRV2i?;`GTXY`oWDZr`*6cjuX3OLduHosQrax~m8|61FXv z?c_|;Z3L;>O!+caj5^C(+lPdS)Zzre1q1p8Uy(LPe2a-%o<+r^-D^8Ov;0gcLyb_! zxYnTiDFM;kCwY@!Dta?#`O4$k$_(5w#SiRdm#$~xRkTR;ngMC^YHSXp$3rA8Aado& z3h^PT+R`L?d+wNM{FQ19VDmY;NlR2_;fp1gtFx zb}>Yl{TSG*tm0a6XXrV#LZ~YrcnBgr3k`56ciFFZFgn>kjD6+hOWS+j%Dz_&b8fQ^ zh_lmGue0aFB3gem(KTWacr9uS=)+$n^Y5|v*e11L-Gei%JNjaV-HqZKY zaDdl+`0c#Q;qM5QbPY`HKXGO;?~+d^n?;KRDnVP&x8(A!VV z8kR^;1PA)QwCv+7c4dtKIGYg~NwBj08iK}mmULCmhLBRvn!t{$G<20LcbNC}?bS4z zj-nQ4o0YJa@|g&epkIZP*UKRjB{=^ zb4mWtVQW+?HX}>glx<(o@{WxNjhj>m92sjiHj&vh$ZI$0;*>QiofGc$7fko=$jVDw z2e1Xw1z9k)FZaF^C#-1bxNCtYtnmBNlh(S}vJ5>(JXXIh&+b%rFzLlK(e<{-uE^D_ z)XVvu=C{xWGCF#E3T|^n9MnYW##jzq1LlH&g<^6ZQ$AuGV0pL9WY>Hap5heM&w2B` zY8!9GNu1bP+l>_PJ}&1m9d1X0v-!mabVAkZ!0L<}Tl0<6}uYrUu8M5OswtB5xr1V&94)doJ?7}HUj*r7)_ zi5l~|6+bLHxclyn)hL0(F zrM}cUGxwb+Oi#~tG)d`?11nfMWf~HlBS?O9T>QAEoV?r*AFz!%f_Wwk;2j}s!!2K5 zJa;0Lxp`K*1$V2?*}VLKmelVomK1J_vThOz)xEn_QEE=10Hcf24rV|8$BToF2N-4d z0_b98plDpgIperXlc2YRdcYc)AsRv&j2^fz1?77stJ!Z|K42Z5xPE^9F9CP>{+zqb z=mh8H(5adYIuNbm$j9N^FN+K%(r*(IGj0HXlD zdZq)q1zQ0UhUQf{O3TO1$Q(c)xzx68Zbe{h3mG!18V4D-w8|--e=rqK&lUCkxxqaqrm3JE ze5|t01M@+(#xcY(G=vZkNpQr9@gHl#&>(lirF{YW0MihbL5nylV80CLvZiRx`+3T2S#tos`m@;=yt0 zV>$cg3pds^m4Kk6-MHRpRRCMrOc9h@+r4-~dK*5Mdd_+Mi_XzD!j8C}phC~k@Wo`{ ziBXv2ND{+kJ6kJyfAv0)IrZmcl{cDqM)-jP1N|vhP{s%)fPuyz_ZTCsbw!WcCd1+7 zr6+vaNN0B!zdzH7Wz;;nUp`44Dk3L`xQwx4KIpxP=FTsyU2L**tW)qYDqs;z-7RAYg!pWfa@Bu0Om`zMq>zh22(!{?V|(4D=d z-2+UY+;`zj1S;|7`3FXp%da>x@bzem=i+bV_RWhi2v6u<(csp%PYM~#M&_0y9cDQA zx+0L7be-N*?ZR=^P-Y~QP=xL7551K4merNgQN*`;JLf}3p zza0_Stq+5=NG6VPk6p}OmspB^sJbqZ2t4+vQmSl{q&2UoqEJn_4kwa-DE5L=1yU;p zx1Noa=?}ThXgOA)O}z#ro(A3_+z#&0N)`S)p1vjjsHoGF@0G{!@&&`q(ts)Tb;}Uu zrC7oZ6K7!6!$45?vvN14&uoS1WOuDTc=zn|fVFqu z!-f4}#p(Zq3c4aG=LpS_c0FMHAiU(zV8W8sl~~+u_skykPx2PTUkk+J_4ih9kyiC5 zS5tmH$cO=U@%8d#PfyQ(2vZ*Y6Sdr3pqTf%fo5v(zpo%n@TdI; z6H{rW?&$wxfHJ>9%z*|8{*T9L1_Ow>06|(;)2`%yvC7WZt(;<_P|DF3#ZE^Z%R0@m z;JO{U!pyzlDY^!3e{`SJEKVn;1wPz2-Mh5TdG4$@I0M_BdexMdA|#>s{P6#`TU_aY z-@ZPa?9Wk?{x_J}5%~`Z-T$f){y)G@bB4p$6j(ynMn*qKlD&t z2}o7@I(kHK-$6L8hUCspk_1B0xY2=w1U`Ay-6iwXBESrgygmP>@KZhfoJV zZ`RkR$dMny(rpkyUlq)Xe)~0MNgNRd8GiwQot)=PdXL28X{#_fvn!xLkld@GW0tbUC*g#TgZ>d$3#jC&1 za}I=vfE-Z>i`rI&inG2TYd2t5hdQkG4PZpBfMM#uf>mT>HJeMGiW$6X61}PABkDDL*KqG z2lH#ZoPXVIjH4tpAirH_8u{&SH>|tl@*U?0k^9bf3r>ED3&nD&IT~C1ot(%PB)>2F zlSExi+!%%W>5pzxo%ck$XU=R% zOx&9@q3m7)Iu$^=6z18_NTFke*SB}Ik$`6^A6ye^*y)k0yhCbD95$+?3%lsWjqf+g z?SYG(_o;<5F&buj=ZLbC6+=|lR?KOZj_q=c3Rqi_e|FDA4_(LF6-BEM2FK=OI^B77 z7gB^&;7g#q3&Rk}4(|xQrSlxRt>|zgZHoLkC&`?YE)QZt2bWnmo0Vbt;>;qrbI=v| zkC3;8VPeOw>7If}cHY_JQL}VE&)7?WMe%a>_^B0GSE5S{(G~ACuoB+VFM&~4G$?Ct zCp~jfjO0eOif2ZuK?(kt+FH~O%icZP z?)0-!jBCYbi@L7yF~hNztbMnHXM9YJHO8y*#n(~7;+K?hD-RP_YXh-^C`B{7$-Z+V z{?IR6&B@W!DY3#4HJr_6k?}~BIflT+u?>3VVNQ)6Lfo&(;F?(fTI2k(m%2Bce8gzl 
z=DF_ZtE-rVeZesA(Px)zBEF^Sz(3wbT;GeqpY*Fp=?wIxv|HIUe!Mb2hLm%x+00~> zVRcHIm(-Pcy&|9UF{zJ6FN1)kEGZVTq?#drWw=#s6SI+Zajw5BMG}$$S?q%HXZRa| zBq)5PCozFem~$)SZSjRpy#e?$H)x?Ig<~A+!OD?H~WmY{QdDp9-Y>EOu92!q3#Q z`ej=p2Rs$HgIXn#tolSB5dZ;K7qP>l-n+*>E($JzFj}_NFyXNnwUfgIQtITL; z`CD{ki^v)*@;X$^ePVMqz~PpWwCK&zCWMDdfkHw<4jSWuornd!-WzBbFoRidAjKHO~+0v)}(}diu-Ar)x#u z!jB?1##fI=S(kuxr%KZ#Zn&V`N?T-#%y#{rQblh^tX7Yz?(XNv7t$U#B>a{=Uw2Bt0jcS!si%C(BlbRwI%KDR z0zs8VX(l6aB4TWMo0OsZ*|RdlLeXS&kQv@Gc&8^9pRx8q*FC^FIEt<&gJmtT;3$fLj%D&-nT6x%B=OgP=LR{lU;2p+Lq%H%!oa<|% z-eRm=RP)(7_wi_zvQ|}2$`RI4RhB3z4Y%Xq%o>-|&~nW|1P+Q~WF={~;9M0?S~Vo> z%)zmr7m5M8MqNQ03iP_!>3zdY2 zH;Br2ozD;(Hzo%nD#+i1&{w2JSxze@XYZ!-{}H`07CMdy`QC5f`-u;BsaK@~jrmY` za^`I2a?y6ORky{<>SxnB4552slHDgAsuo;yA#Xq(zeHusJfTb(=r)7}TkTa*2En)2 zC7xuuL>S81;lATh6$Y1~`?(ySeJI9#FBfItZQ8k`@C}k%OF;yz3kWg}-Ja>budwPwBqK7fx>WE%#3GA1@!bmNx7;mxCkRG&d1YNGEChVgVKKdzN&jA^{G%zv22*{r57?(@#$1ggdW;pA7HQltGir_2~++~WFR;J4 zqHC?=2g;+qDghFS4A?SC1r}_mzq<)|c2_9Uwesz`uN{h4NA}nw2isigx=sS3+*YQ? zDE>&994-oAA7T{;9$`+0gT8mERt{C#ew(WlQ&q|{wIFT z#W&FTW`EA2cib!V?i=qNAvI*$C+m4=D?3E)ug5K4k3F;A+8zF4KE!dqp9)N%W?t%e zI=gj-l&NS{xre^4sBL;WGnXFp{L?G4Ge;%R_6FB3qOZZXk+RP{620+v=$3Lbl9Q-j z{Tbou9Q@j4n*v@#N2I$TFUR7{l#AOcvy0UNT@+&xW|OMP#`!XGem{Rs_wdOErmjDSe99N=@5UbTk- z<#hCJl`^AMfdJ_xe9vkCz*AV@;QW9*Gf1RvojWX*1YhN79~s@^#OI5ZBy#*HM4XdC zTe@d#c%yOtpD;f>QOkr85bH|Ir!%A0%4b^n)iK6_Q(RnVYmI@v!P&0NX`4t`Qor1lcIV%X|^t42*4vhl&_JzRmk|{RJA^N63pzRW{v5%Bj zwv-(#(!kxQ=bkRt^pC>U8gHA%=N`n0smPVT9SsWBD(#&?A5oEpW@dKI&d&MG-4QRU zO<6SFwDl{q=ZYOf14t?PIn-)MLMUqzOYD!jXAr1)VGU-L+iMp6<4^w^)Ar144In94 z#si{FK0I1qypK%Wa`w*xnY@Z+`Wgn-K)=DdYvNMyYKSmy}@2%F0+%ZQ8Cei+dok)QG4B$Eode+{ibl z`Uz=~J4z4@wo6l}|LpFe?wZ{EvBoP6gkFN5!-rWBn0$9VUpFF_!?`Zqo24nGPZ3jO z1&)2Z48VYiy>3;{MNiJv?_Vo2xpa0{%>F?KqCM12p;C#riTS!V| zl*u_YLm7Upis*dD(l0@scMvI@F651X)S7|tqwogSmxkYr)}K zd4G1R-^Wwdt(|9i?J?_tv&P`gwYtd&qv=QX`QqQQYYktfu_qDi9%Pm}iuu=XWmDs& zZ3**63W?eXtCF8@hClenCeQuTy#CtNN8)yljwrJyX`HS%vlHK=x}z6(cDlkn;My@7 zQ$^-BHpNsb6)*{Obwx6lp5FkgbHuAOkt#qM{F8iqBm+;)fr`0>1t@2Q9_?~WY%#l* zvA>|nU2798Bl((BeopV3tHt5m!qS?5A>xR$nnU@OX+>Sq_)@E#jnXy=6#w^T;9uu_ zeSAKgEWF~lVq;^YY?3x%a4Wevzr`?h7-*7=YIrhATwQS=iLaOm0z~Jg$*BFGHinkA zwg_J+)IidUWhQab`~D_Rk6B>Gz(U%{YF;e>gbMprO3JZ*sCD!}62HLk=xDEMe^)l4 zyj<4H#6NnVXlr+`tK`n}(;b`j6O^EaD;RrzryZ}|<@S6Z+BU5QIz_tR|65#%U*8=u z3oX}C*WkkA$|o+daTrXRo6xX~DR=k-Em9K<2A8#6iO-1i_s911_QIFTN>aCfTTbJ| zEEwN0GD>K6QCyFDvC-Wyt}a8~Qc0cFQs&N1d_(n>mp*DW@|O(0{(U+sLmMC|cQeA) z=llPIe}3)z2fbtO9<>F&sz3JM4)gwBIN{hnKUs{Ia{7`{4bX-|hWe&=aJL^k`@aD7 CQjcT+ literal 40595 zcmeFZbySyY)GdmI1u7;8f(RChibzO`AczWxbfZXjON%HNgi_L=pmZqRC<;>2AT81$ zUD9X1xcB#+Z;U(cIRD)H-E+tBkIfc;@xITqp0(zhbFTNeD{Mc-MRaAN#{@8E3{}g|xvbZL1A!V#(VSU$3lSJFd*ie(% zT*FLL)5u)M*kXKBAs-3JQ4&!h0qF-JzdLM9_78-vO;%S97LbT;Gkq#})Ra#6;Rg?9 z`!6RxkWaRX|N1fc%ADhwlqh>}YtQSd?DSl}`y8isQUKbJ zuH^6E#cg=PYQ+ziP=;4)tFEagTl-46d@z!J#F5#S7{2mbdw6j0C1YQ2uMWO6^3{H_ zt72jv-$O1({{C$s7Rr@gx|G@Fn9HbJ?t7XcH0`dGf5MEyKB7p?X?#c)x7iHv{EM<44>Y(aU;!Ux?f;? zt!UHkL04*{Z5q#AjXpo#-apH4Gu108PF|8DxxLeqjz8)B`xDoMh0FWQVk081%(t6+ z-Me>BnBg*I#BToO+e>q!QU(SFH9to+=#~N)%d}EGEb2awwP%~Re$a6H9?pM0y!s^8 zC?n#Acm0sGxp`__e7wx<+YysJr6Zko=eG-S2P%#kkByJk(;hl`!>S4WL8@E zv6OM?yJQWnkUZ7o1!?sZzULB~&#UCz?>xgd_rdElGqc#<1;?0;;djVOdl=&5;%?A-oxFd1$o4eD=fupOx@MLm;;dSg+b(-1Wn{=27c9JI)b0NK zMDfGDYc-=KUUd8hL)7ug=Uy*Zo0(k{XuBZ)@qXFE@V8s^EocqdgFVbBR@i@Ce)sX? 
zokQo8UotYQ6(}Plxa|L=>^i~~6>OCA?6aZ!v~7O>LT=cBLx=Q7ni9>jl|QzP<}8in zTGoqkoMvK*6NwL~9`Gscsa;)JO4qKY%~n3QGMnrqV{4mpk(E_s=nzR{bWVvjn$(i2JwIa3#Ka_xEgl}$!@^Ycf5xHl z)v@t3i=jKBb@PR54x(~OUuf+o1g2T%M%$u^W822QWy=;W$K|Zyk&!;LY~PnJBLZ2q zhntj)M=gcFr)pOZe|yD{W77R7Cnv`j3yVb=c^_i3T~d3iiIPs3{~g=-uS zQ%z>)^H&!~`p516bi1Wn4As3%ytm6u@YSnVIV*E*e)w#T&If|{ji#n%KB+NQiuh(6 zEaU#-$B#Ex1+WNWx5~eL+m_SSUllB)s2Dd?A3d^D?a!m2prADCaaBKm|G4~o6;o5w zAU=m2GjsDQBQE60)?!fsneF8Cn4vHYhiam`nCbLsNn~JzGxEYE(_cbFw)*<~z;jQ} zu@7Acs@Zn)YEkd72fwy!YH7tqMG4v1*u=%gzCt>%ut+1~s4-qnAIXv4sZ_mUo9JTq zZl_b>)zNLng(}(E*{zwz;hwa-asB;Tef|A+ZkxQbfaVXbHebUZ1LG!~PgD=;fcWrr^ z>HPWT%0SlH1}TXl{A^9t_CnCH8Ma>H#qT;Y=x!`b^w6!zIYwfB1b9>G&iq~#Vr#g&RtNDOiF+ECr zMh1IhbMpx5Cz0OxmRhQr&Yd&5Qu-cSc)sAzTYSHP)r)8D?$UO4xj1PSW9^rZa9La* zX-*!my;jt165vy4VQD$kD5oZ4YWe|1Y}wH%bIfCObhI^1*XNva*3snVwoyu)qIi5H za-3qC&MUszSc$+VZzZpgI{J4QadhoI%j1!@*@`NAE+2`|u;t@5?e8Hh^|}^<3Gwmu zu~N}9U(Y+T3>I~8blFQS&5v``8>hDPT5wq|OldkGq1g7m5@~2^dhhA!DX^t(VPSy) z1B@Le#pH?^1`36WVhqQQrRg;?i3PKZVTb8mcwbk%<6jw)KnlO;BzXURRJ&28qM{=5 z((wcP_vhv;+PUxzjSLRnR#FPa3ePNct*KU3RrQ-*n$i#zeO55@?!2fsyHR^g zNr|vIb#T>tAP6L(Ve{~%kdGff3NLd<#rt(1;kH6(SyD^rxF6@S9t)Y8o}T^2;ANrz zkn*WX?`yyzT!s_!_NC+vZpClr=B(H*U5ZIdlg0sRVDRFpi~OC^Hrg?@vNT?pOr&Kc zs{6_V$d9clcLi&bv;|KyM8#`(A_POkqt{b+r|nIg6UV|rrQ`ME^CPGwL6BMja};(NWmaH!7m{LP zt`nD{;l=wtDk{ooX={H6Zq%?nOY-%3HTKek8*&v(FAklL!|T)*%H#BNr zH`LX=1#$sO^lI&;cOF3fRwP0ZTZDp46>->AkuR&m2=I`|_ox@PL+@TUATy3@>>XCT zQ2r8^CnGC6T=PcIcBWROij<=#WJ_51?G3`XZB(YaR@JYgVq#8R4i#r(XOBY$*}99S z221@G<;|)s6Ng^EA_Dt12z1pyJ3hbEh$|v;FamhK%c)Soh}5Va(bn3kh>IMV4qvmc zS?2Jyt-CTM%=sj;(j;5CHP0rsyrQD2$2haC5ii984Fl!c{(5#YN9@7Dv*g~kmBNRU zE*hSiXK`m{IGlInvYC)krV{F*R^MkpiIXQq2tXws&X?HV-><6FGXJ|p;~*E2x~n+4 z5`2C4;7_qPZ{J?U8Lf4vV!dQNrYtJA_4DBjqmH<^I3@s{MEMU_c@ zaLax3OESo9)}1yYNNYqnQ!hQla?KuPc3y^Bpu6r;1BaP|bIM$yh+ zTV0AvPL>lAdQ5Bgn+4sE%(>S$=^X#OA$Ef3WfC(xhbiPezE~?~nFu&7RHzYaYuZzy z+mT}##G>H}TtVcpw{MT-Pk-muGcsyAlbrX~*lDS4lMw}(h~s5s4~3zd5)#e4^X-v{ z)gX?3sc7*b?3Rp*N+=Nl$xbUHH?Cj5Fu!}b5}$(@zD1y>va(w2L}P93oBC)8{pKVU zq_ev3q1*%xGm=PmRTSR*J~A>=5P0t}ZqT1uNsDimm9jYkr7A%o)uQx>^(izW2u?i% zgN8SPZk5pi#(9&jv`)*W=3S+9PRe}Y)%L$$g@z_dI4#zDFYX~YEI+Cp{3z_e~#h9T1HYet8qwl4PS*OT{qaj%O= zA!kvM8XBZK%&U2WcC` zy~D9xWYN`@KqdD+1vig0b{D%HS2@x^^W&Gdk56Pru64zz)fs+%H5`2Hnra;*BX%VV z0QYDSS{oq*(#IBq8{*>ZN*B5w^pGPB4+Det|K=5J$xKg|$+a5wZX0bPh>=(cha0F6 z{a+7JQ^%tOssWhSBYV8V3L~G)+LZ|OJTm2^SzI3APX91tHYRiczth~PgNI()2h8#7>`0V-%;(_Y$ z*Y!ict;Em@(OfqBVj$))@oB4{pPy!iY-wrf88$WjrUXf0VOK}Qb=-#!lH-MI1!MX1oe|*#3ql-!gZ&pg zwQt;be)#ZVi$4>0v937SaZ2-LZ~Q6hy2_sssR`{bU`h1;ef#Po+g{ZcVTa$9(%Rjc zoSf9k%>MJ|f#Hu70W@_)w*ulQhb>D*$;OJ4W)l-bl5+N2ho!l)o=Kg>nW4q@0lw(L zK^;~H;A_Q9qfp%X+;~?qO4)&f2la3q`n5Ort9X@0CfoO&A(CETCer^*V70w4KR|Hr ztgR(J#&N4L+8YH01tOJU!z!aMkYaD9rp#Il*gGEdc(i7lMy*JCnu7q z=|_~NAtT{ki@&w`^*w9Za%AxG{N<^NDcgKpdS%RJ|cP1bNkqIn^o_292n!m%a(p*_N2r5H#PpW#_?fLl$}!{>)~%9R#!0lHZ!ef1HtV&?kI3D$HR{0eQjC*Vm#Vw>#$K zEQj&qw%ZkC8dST{$0B(VfexB10ikwJLn9q+(z~cAo%;7;^e1hpzt4zv=X>3q@sOXe zTAa~MJ=2=*P~qe4eFf)F{=A z+)_T4rKa<>?hq74N)6We#< z)?jT;ct}VB2)ChMUtgeSpL$pq5E^Po+G1&G*&iS=q<6cir6u;+zLR?3Ue^y_MTMQc z>!ii^EYHtn1JgtM!v6?mnx4n01(!5EQ2n|}%cpCcZeo1bp&n%m3kw|}*`SM>H)`~c zhxEDN8cIP4ye$1M#4#l=OVaB5H6J5JNW5rD4*$DMDqZ_tx_ti~|2;8!Y6};U5#vPDQ&KjuZEM+J0lHG z&t5j~pW)JVp@=!GA>!h90OxXh%h#+-d%EuzkAWX<7L`095oTaTz3ccy9+~Az@|BWV zY>Z!M=n=Ers3=A-TcWt|xTK`9$<381>YpA(BvU6gI`4CyX%4w;-fO~%QXnHKSMPAL4vvkj_Ti5t* z8^@>#|J`L}Wj7=wb}yL4F%Ht%88sCdgKpDw9VN*Ep2MQqc1 zxfFNb*3x=^Spa#)u;yC39l{ITFqtjSrix*AclQV%RZ8-_Po6TIBccv{9#@Pq?XKu= zn@b;6FLy7Z$`Y(<`)R_NX3+AHev9#c#LV72#B5pKb-4t2inOU$etz$^?>iBo!Sra0 
zvGC1mGzoS`-1aBbP?Z-5pAo)!^M&I~B;AvCxkaUG?~%XMBoc89DmU7LBYx`no_kIJCLi!Xl>R}FQ z6|G0lyu9Q923q1S?{V&GmfMn%mX=n1(ms5#A-XcJy}jMwU>qX30g6_!>O5Q6emAb> zU2Y8b!)J=@tw$g4axJ={5i`vb(skG3$z@btEIs@GusSlmz+bq^q(*kZq86& z9hNH`o)t}yJUy|rqNk&?&9;6gjgxI+Q4iGvUe{drLBoqjhK6k>7wnmsnPrG`J~GnS z($Zq+?KU?z7twWD+PTldxJL&Ojk+r>B^BxExy#bRVn)Z{LyJLv60;V~xbu5{c6M2m zo4AAoX&aktkcYz4ypbFArC%bllUTLC$Gm;}nXD58#)8lf%HaLt&Jsl~&z?RVl~d=9|#t;c}Kw*{bMP&w#kHTHL@PH+&5->$JmB}V0is}eCm@8-*t{}abad%v`eNM zmYVzY{(L()hXRW2oxAAx48BjXYm~UtdbN=QuYj?T!u5mmYmdCbsu;eN)nn-P9Gl{d zyF&M3G15N%RR3h25S5UVlPgpb6im08XDjFhdDxm^7~I;{mRmLTz_8V=O|34x!cG($w=l%gfN0yDuZ5+Z+SuxFtNPbL6uf;w;zCsb!ha-?u(HG%VGt zF7dP)cb$~K@6?33Y?5-U&A2EvIZqj!3t^4-M~;vqi_f4$u0 z@^0I|UI2o|z-Db)+faO;@o>mSLM{RG@Q<)~M@kBu1v6A+z{b>L#+yxrIm(I(>+9-l z9roJgkBkISk4?Sd2bwq`78E5y=OFFj;Q;`C2jK~}kf1V^O%-H-$bkO>?qK(d2o=5_ZDN!sdX^+4j;47T6OoycREJmZ-}P?cYBb zW;Zq@yW!chXOi050}+awTt6BEZlS}e{OVAAdg2$-#@)5$E68}P)nBnc36Oe%Sw$7@ z-quqeQBIkCvvEttCGy88@iRm9pdH_X!Nm_(Wwo=%jvdq4sTs5Je9E=&BeT0kvS~-4 zALhzb5@bMdssd^MS7&$4ELqN}n?JKHpn98q7yjtZY4tlGTW6d(v;8f0E0DEhR=XCu zGc)T;?m`PSYJf!Og;0g$AkY~VENFopBm{z!;vqEvVvVh>am{K@Dn2JA#;oPaiMt%66)l8|cQ?eY7Q21UZ%L=JARL zpNGt{7efYnZRJ!{46RP*PnJ*%^1Y)iw*g=VnU>2Z!X2 zbDK)v%(K7OyF{+dtDCShulW)i+Y_3TT9tn|>qqtb9NjtNaX6qYr(^ZF_tB zXq1qNIw{>e>8l-PBMh2CVrE2bjqB{vVF^_{=UA0Xx*u6R1zMh>(yzx#mLy$GUef3MQjMj ze#N{8A8Z_4nU-9(pH+gu+cuH`(CnSWCorUM=2Lr-2-$NcR?VGbYz`fQt`+iGUi5jc zZ2#)|Xz`E_Taq1!wyT(&V3b|N|J=TF2jZF@2sMk~SGuQ` zh{{N2HTSz^pVl_)7lqv11^ ztkpAiA75N4NpL2XbbcIkI{m(9PGNv7LRP zw){CEN>yT`b_&_-{*$2Mh^FFeUUHIGGkBaph~zdJ2$)MJqp}DhBEz8NE%Y&6H2<4! zXsZcFy1#tH-#D>gzX*=DZd7rbm8!ehy$H=OQ+*XxW{U`roeItmK%-U7C#jIwMNxLp zLQs{Fmp=uD1RdI3Ky3B4eu6YvQb1FXoT^L8`a?tS8iOR?%sYrKg@3-DCm5}ml#~m- zRscNx*7NK)d(t;*$Mcx=K7|_h$jBN5{WN7ebiBK2ipP{)Q^PJbZ4nQ+mr!Pt2+s2q zKfe>BkLEwQXp1q{mSBdW;;C{oW4n_z&;A;nUdT~tM}yRFD<>l(^N%b;C^ArUiS{YE zU@;D=5S4;WL?YR6PQ6*QluF-##f4O@%I8y(%QGc{SXsWQWoAM_6gNgVhLCgf&P5NCD|hL2 z6*%+5q;`1zF@C>QYPe7I_qnEaW#o4}Pm5ID7E&LBV0 ztxjx(i+41I)Q{iC2lem&fhfj;5yCOy$l6N#pr}15D=SNc;j!s=!cc{d zsi~>S(K7uOKpQILv7zAw$38zfLbpx5*Z`Rkl1IRcDJ@M+!iltLHMIMsuWuAub<_>k zh+Q<}S^?-jfu^dCN*q>xUG|Aq>dfJ1Q*X|r|CPIU?=5Jg*#Z7Fh4F%=r0k|3ECWgz z`r=uFdeyIeqR7Y-@bu}n4!iMuFxYy=#?8~!{K@F0QzbtSS*%NrkjVO`7ku%x++$e` z28HL_U30_s-;k2>!7^-}QTr+AY_ct6@*|FbQl=4G?lARid{r=qKC(*jMg#-oyogD+K6tRyd`6joUU;6_dObPX$3wVJrxPaQkisL?!c;JHi zT3R~tN`!Xk;Due5T`&&CAqokdlW?k_NCnns96m4(%Lv4j;aU;8Be$K#p+yZ%meMu9 zIA0Kh_I9)*R}Q6p_!T?YJ>p(^A5}S&RaG@2M>Vv!C&1uygx|3}2OF1zrUKaGGWq!sgtNXUZvgZy@MFY2@h<&N#oFZ%MheH3MS*wvM~$-k4ifeq z#3`i7#ZhCY_?Va&qn*#4`dPzk5w=K#$oel1GT$*X`)D=R9t$J@(EvoZ1T?Gu`VH_a z!hn|y+O!#OJ=I}W!A!W|{6a$7BklAE{RJxfh?moHB-Jub#5XA0;hJZ;xQsh_c#*A} z`wV;l0SPNoy4}1H2mxK_48UA0)sYY)o*y`4F%I82N4ZU%ry(26voVCkGH>&(8HsTCg z6`Dx!^D8%iBDBJ07VH%Z9rNLKj6wnhCqa0^7RT3CjiGWAZ}(;-^jzrfvT|~-A^4p| z%MH*!NZs{EIJpa2vnA?QJa!tc50>QJk)L}eA5<>05H8Z?iQ>pP#?M-*r_Y?ZNl-O= z4j4Fhfn&+9HbJ4x+sCF(?#yLAc)2@SG_FSxTCw%$Kimv24ZNvm1N(gI2_n^+ExX>1F>n>~UY zGCgU6=DRRDgZ;;-6`?Idq#z_rz*fKD;4?#d;p4~^S5Ofe;B+7?DiFuPhaEq8QXd_Y zV!ECm-fz3vU9g1%feIsG{pTk;36>L52;u!1m_AKUFW_@RtPZpz;mzDo-K&w>;`{2= zn`DPkeJas_@_C1$o3<#1P+H!?c>|iyK7Cq~rdvPrr^jmqw?cGAaCi|^^uf-oEY5$7 zhW)=15!D_3s#*S$6#Rs?6vm_5FNmP|hxtOSmy(rjvg96YLtucI?J4z=Q&kNk9EmW@ z5qc7M2gL&WY+&01R8;F?Bj_mf94{;S4%f^c@Lc}2&N5v@7!9H!C0G;`z=g+_|5(Wg z9uSEC?)D6kIQI*WT6XjOxk7N1z)9dk_U_xqhBHX)DP#gE98^MG2h0&ekdK+F!p7HU z!6gHsAJy+`Ny#pqg9U^#Mx+7=d8i4Eot@`Ud~wr6hgtjf8qu!DROh`0w?&u(Mgmm1 zJF8v>2F8Le#I`_@Ak+K&@C}VrJXI_bKl+^}{^er+K^SMb?B)`PMYrHW2H$_=NMdjA 
[remaining base85 GIT binary patch data omitted]

diff --git a/docs/examples/te_gemma/media/speedups.png b/docs/examples/te_gemma/media/speedups.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f6e0b16732a4e07afbe8eb553f0a24faeb07f14
GIT binary patch
literal 40595
[base85 image data omitted]

literal 0
HcmV?d00001

From b5ba6d6275ac4fe784a045fbdcf1736582133902 Mon Sep 17 00:00:00 2001
From: Pawel Gadzinski
Date: Wed, 22 May 2024 17:46:40 -0700
Subject: [PATCH 145/244] fix

Signed-off-by: Pawel
Gadzinski --- .../common/transpose/cast_transpose_fusion.cu | 34 ------------------- 1 file changed, 34 deletions(-) diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index 40618e7171..07fd3760b0 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -1768,37 +1768,3 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, reinterpret_cast(transposed_output), stream); } - -void nvte_cast_transpose_dbias_dswish(const NVTETensor input, - const NVTETensor swish_input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_cast_transpose_dbias_dswish); - using namespace transformer_engine; - cast_transpose_dbias_dgelu>( - *reinterpret_cast(input), - *reinterpret_cast(swish_input), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), - reinterpret_cast(dbias), - reinterpret_cast(workspace), - stream); -} - -void nvte_dswiglu_cast_transpose(const NVTETensor input, - const NVTETensor swiglu_input, - NVTETensor cast_output, - NVTETensor transposed_output, - cudaStream_t stream) { - NVTE_API_CALL(nvte_dswiglu_cast_transpose); - using namespace transformer_engine; - dgeglu_cast_transpose, swish>( - *reinterpret_cast(input), - *reinterpret_cast(swiglu_input), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), - stream); -} From fcfda2c38739c6c699b8b132783fc8416e2bfa02 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 17:48:15 -0700 Subject: [PATCH 146/244] fix Signed-off-by: Pawel Gadzinski --- .../include/transformer_engine/transpose.h | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index 2f77738466..0d55be5d40 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -222,53 +222,6 @@ void nvte_dsreglu_cast_transpose(const NVTETensor input, NVTETensor transposed_output, cudaStream_t stream); -/*! \brief Compute backward of SiLU operation on the input, then cast and transpose. Additionally, - * reduce the result of the SiLU backward along the first dimension. - * - * This function produces 3 results: - * - `cast_output` is equal to `cast(dSiLU(input))` - * - `transposed_output` is equal to `transpose(cast(dSiLU(input)))` - * - `dbias` is equal to `reduce(dSiLU(input), axis=0)` - * - * Calling this function with workspace being an empty tensor will not perform the operation, - * but instead set the shape and type of the workspace tensor to the required values. - * - * \param[in] input Input tensor of shape [N, H]. - * \param[in] swish_input Tensor used as input to the forward of SiLU operation. - * Shape [N, H]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H]. - * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N]. - * \param[out] dbias Result of the reduction of the dSiLU(input) along the - * first dimension. Shape: [H]. - * \param[out] workspace Workspace tensor. - * \param[in] stream CUDA stream used for the operation. 
- */ -void nvte_cast_transpose_dbias_dswish(const NVTETensor input, - const NVTETensor swish_input, - NVTETensor cast_output, - NVTETensor transposed_output, - NVTETensor dbias, - NVTETensor workspace, - cudaStream_t stream); - -/*! \brief Compute dswiglu of the input, additionally does cast and transpose the dswiglu output. - * - * This function produces 2 results: - * - `cast_output` is the result of the cast - * - `transposed_output` is the transposed result of the cast. - * - * \param[in] input Input tensor of shape [N, H]. - * \param[in] swiglu_input Tensor used as input to the forward of SwiGLU operation. - * Shape [N, H * 2]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H * 2]. - * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N]. - * \param[in] stream CUDA stream used for the operation. - */ -void nvte_dswiglu_cast_transpose(const NVTETensor input, - const NVTETensor swiglu_input, - NVTETensor cast_output, - NVTETensor transposed_output, - cudaStream_t stream); #ifdef __cplusplus } // extern "C" #endif From 1d7c997e3c5c5273f020565d382bae465549ff7f Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 17:49:10 -0700 Subject: [PATCH 147/244] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/jax/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/mlp.py b/transformer_engine/jax/mlp.py index 6f17d9258a..946a9e289a 100644 --- a/transformer_engine/jax/mlp.py +++ b/transformer_engine/jax/mlp.py @@ -15,7 +15,7 @@ from .cpp_extensions import dact_lu_dbias_cast_transpose, dgated_act_lu_cast_transpose from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd -from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize, dequantize +from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize from .layernorm import canonicalize_layernorm_type from .fp8 import FP8Helper, FP8MetaPackage from .sharding import with_sharding_constraint_by_logical_axes From 117f2f94adccc6960ebce63f6045bc40a0b47015 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 17:55:28 -0700 Subject: [PATCH 148/244] fix Signed-off-by: Pawel Gadzinski --- .../pytorch/csrc/extensions/attention.cu | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 6ef10e6b67..76a849d642 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -2293,3 +2293,35 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, return output; } + +// Kernel used to update KV chache when attention layout is "thd". 
+extern "C" +__global__ void attention_copy_kernel( + __nv_bfloat16* cache_tensor, + int* seq_len, + int* incoming_seq_len, + __nv_bfloat16* hidden_tensor, + int max_incoming_seq_len, + int max_seq_len, + int b, + int s + ) { + for(int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int to_copy = s * incoming_seq_len[batch_idx]; + int offset = seq_len[batch_idx]; + + __nv_bfloat16* begin_cache_copy = cache_tensor + max_seq_len * s * batch_idx + s * offset; + __nv_bfloat16* begin_hidden_copy = hidden_tensor + s * batch_idx * max_incoming_seq_len; + + for(int i = threadIdx.x; i < to_copy; i += blockDim.x) { + *(begin_cache_copy + i) = *(begin_hidden_copy + i); + } + } +} + +void attention_copy(torch::Tensor A, torch::Tensor seq_len, torch::Tensor incoming_seq_len, torch::Tensor B, int max_incoming_seq_len, int max_seq_len, int b, int s) { + attention_copy_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(reinterpret_cast<__nv_bfloat16*>(A.data_ptr()), + seq_len.data_ptr(), + incoming_seq_len.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(B.data_ptr()), max_incoming_seq_len, max_seq_len, b, s); +} From c65eee73cce1ae12731bd393140b6a828f62b786 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 18:47:52 -0700 Subject: [PATCH 149/244] git fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 495 +++++++----------- .../pytorch/cpp_extensions/fused_attn.py | 6 + 2 files changed, 206 insertions(+), 295 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 231a40337f..7308f1812c 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5,14 +5,14 @@ """Attention.""" import collections from contextlib import nullcontext -from importlib.metadata import version as get_pkg_version +from importlib.metadata import version import math import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings import numpy as np -from packaging.version import Version as PkgVersion +from pkg_resources import packaging import torch import torch.nn.functional as F @@ -67,13 +67,13 @@ from transformer_engine.pytorch.graph import is_graph_capturing -_flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) -_flash_attn_version_required = PkgVersion("2.0.6") -_flash_attn_max_version = PkgVersion("2.5.8") -_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1") -_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") -_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") -_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") +_flash_attn_version = packaging.version.Version(version("flash-attn")) +_flash_attn_version_required = packaging.version.Version("2.0.6") +_flash_attn_max_version = packaging.version.Version("2.5.8") +_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1") +_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3") +_flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4") +_flash_attn_2_4_1_plus = _flash_attn_version >= packaging.version.Version("2.4.1") if _flash_attn_version >= _flash_attn_version_required: from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module @@ -121,7 +121,7 @@ def __init__(self, max_batch_size, max_sequence_length, qkv_format="bshd"): self.max_batch_size = max_batch_size 
self.key_value_memory_dict = {} self.qkv_format = qkv_format - + if qkv_format == "thd": self.seq_len = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) self.incoming_seq_len = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) @@ -153,8 +153,8 @@ def swap_key_value_dict(self, batch_indices): new_inference_key_memory, new_inference_value_memory, ) - - + + def thd_setup_before_new_input(self, new_input, reset=False, pad_token_id=None): """ After every context/generation phase, the parameters representing @@ -167,11 +167,11 @@ def thd_setup_before_new_input(self, new_input, reset=False, pad_token_id=None): new_input: torch.Tensor Tensor with token_ids (not embeddings!) on which we want to do next forward pass. reset: int - If reset=True, all previous sequence lengths will be set to 0. - It is supposed to be used after last generation phase to + If reset=True, all previous sequence lengths will be set to 0. + It is supposed to be used after last generation phase to allow inference_params to be reused. pad_token_id: int - Value of padding token - used to compute sequence_lengths. If pad_token_id=None, + Value of padding token - used to compute sequence_lengths. If pad_token_id=None, we assume that all new_input sequence lengths are equal to the corresponding dimension of new_input. """ @@ -179,18 +179,14 @@ def thd_setup_before_new_input(self, new_input, reset=False, pad_token_id=None): self.seq_len.copy_(self.seq_len + self.incoming_seq_len) if pad_token_id is not None: - self.incoming_seq_len.copy_( - torch.sum(new_input.ne(pad_token_id), dim=-1, dtype=torch.int32).squeeze() - ) + self.incoming_seq_len.copy_(torch.sum(new_input.ne(pad_token_id), dim=-1, dtype=torch.int32).squeeze()) else: - self.incoming_seq_len.copy_( - torch.ones(new_input.shape[0], device="cuda") * new_input.shape[1] - ) + self.incoming_seq_len.copy_(torch.ones(new_input.shape[0], device="cuda") * new_input.shape[1]) self.max_incoming_seq_len = new_input.shape[1] if reset: self.seq_len.copy_(torch.zeros_like(self.seq_len)) - + def save_new_key_and_value_layer(self, layer_number, key_layer, value_layer): """ Saves key_layer and value_layer in the cache. @@ -201,27 +197,26 @@ def save_new_key_and_value_layer(self, layer_number, key_layer, value_layer): batch_size = key_layer.shape[0] channels = inference_key_memory.shape[2] * inference_key_memory.shape[3] # h * d tex.attention_copy( - inference_key_memory, - self.seq_len, + inference_key_memory, + self.seq_len, self.incoming_seq_len, - key_layer, + key_layer, self.max_incoming_seq_len, - self.max_sequence_length, + self.max_sequence_length, batch_size, channels) - + tex.attention_copy( - inference_value_memory, - self.seq_len, + inference_value_memory, + self.seq_len, self.incoming_seq_len, - value_layer, + value_layer, self.max_incoming_seq_len, - self.max_sequence_length, + self.max_sequence_length, batch_size, channels) else: - assert self.qkv_format in ["bshd", "sbhd"], \ - "Attention format not supported by the inference." + assert self.qkv_format in ["bshd", "sbhd"], "Attention format not supported by the inference." batch_start = self.batch_size_offset batch_end = batch_start + key_layer.size(1) assert batch_end <= inference_key_memory.size(1) @@ -237,8 +232,8 @@ def save_new_key_and_value_layer(self, layer_number, key_layer, value_layer): sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] 
value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] - return key_layer, value_layer - + return key_layer, value_layer + @torch.no_grad() @@ -318,7 +313,7 @@ def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor: the samples in a batch. """ mask = mask.squeeze(1).squeeze(1) - reduced_mask = mask.logical_not().sum(dim=1) + reduced_mask = mask.sum(dim=1) cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) zero = torch.zeros(1, dtype=torch.int32, device="cuda") cu_seqlens = torch.cat((zero, cu_seqlens)) @@ -336,13 +331,13 @@ def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch. mask = mask.squeeze(1).squeeze(1) bs, seqlen = mask.shape - reduced_mask = mask.logical_not().sum(dim=1) + reduced_mask = mask.sum(dim=1) cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) zero = torch.zeros(1, dtype=torch.int32, device="cuda") cu_seqlens = torch.cat((zero, cu_seqlens)) mask = mask.reshape(-1) - indices = mask.logical_not().nonzero() + indices = mask.nonzero() indices = indices.unsqueeze(-1) num_nonzeros = indices.shape[0] @@ -502,7 +497,7 @@ def forward( *tensors: Tuple[torch.Tensor, ...] ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported." - ctx.save_for_backward(indices) + ctx.indices = indices ctx.dim0 = tensors[0].shape[0] if len(tensors) == 1: return pack_tensor(indices, *tensors) @@ -512,12 +507,11 @@ def forward( @staticmethod def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]): - (indices,) = ctx.saved_tensors if len(grad_outputs) == 1: - return None, unpack_tensor(indices, ctx.dim0, *grad_outputs) + return None, unpack_tensor(ctx.indices, ctx.dim0, *grad_outputs) if len(grad_outputs) == 2: - return None, *unpack_2_tensors(indices, ctx.dim0, *grad_outputs) - return None, *unpack_3_tensors(indices, ctx.dim0, *grad_outputs) + return None, *unpack_2_tensors(ctx.indices, ctx.dim0, *grad_outputs) + return None, *unpack_3_tensors(ctx.indices, ctx.dim0, *grad_outputs) class UnpackTensor(torch.autograd.Function): @@ -531,13 +525,12 @@ def forward( dim0: int, tensor: torch.Tensor, ) -> torch.Tensor: - ctx.save_for_backward(indices) + ctx.indices = indices return unpack_tensor(indices, dim0, tensor) @staticmethod def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - return None, None, pack_tensor(indices, grad_output) + return None, None, pack_tensor(ctx.indices, grad_output) def flash_attn_p2p_communicate(rank, send_tensor, send_dst, @@ -772,13 +765,8 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_inputs[i%2] = q.view(-1, *q.shape[-2:]) - if qkv_format == "thd": - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i%2] = tex.thd_read_half_tensor( - kv_inputs[i%2], cu_seqlens_k, 0) - else: - # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous() + # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] + kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous() # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:]) if _flash_attn_2_3_plus: @@ -824,13 +812,8 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, if len(rest) > 0: attn_biases[i] = rest[0] else: - if qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) - else: - # [b, 2, sq//2, np, hn]->[b, sq//2, 
np, hn]->[b*sq//2, np, hn] - q_inputs[i%2] = \ - q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] + q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:]) if _flash_attn_2_3_plus: @@ -888,7 +871,7 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, if i == 1: out = torch.empty_like(q).zero_() softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) - if causal and qkv_format != "thd": + if causal: # [b, np, sq] -> [b, np, 2, sq//2] softmax_lse_ = softmax_lse.view( *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2 @@ -897,14 +880,8 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step[i-1]) else: - if qkv_format == "thd": - tex.thd_second_half_lse_correction(softmax_lse, - softmax_lse_per_step[i-1], - cu_seqlens_q, - q.size(0)) - else: - flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :], - softmax_lse_per_step[i-1]) + flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :], + softmax_lse_per_step[i-1]) if i < cp_size: flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done) @@ -912,8 +889,7 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) softmax_lse = softmax_lse.to(torch.float) - if qkv_format in ["bshd", "sbhd"]: - seq_dim = qkv_format.index("s") + seq_dim = qkv_format.index("s") for i in range(cp_size): if qkv_format == "bshd": out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:]) @@ -921,39 +897,18 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, elif qkv_format == "sbhd": out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:]) out_ = out[1] - if i <= rank or not causal: - if qkv_format in ["bshd", "sbhd"]: - flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape), - out_per_step[i], - seq_dim, - softmax_lse, - softmax_lse_per_step[i]) - elif qkv_format == "thd": - tex.thd_out_correction(out, - out_per_step[i], - softmax_lse, - softmax_lse_per_step[i], - cu_seqlens_q, - False) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" + flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape), + out_per_step[i], + seq_dim, + softmax_lse, + softmax_lse_per_step[i]) else: - if qkv_format in ["bshd", "sbhd"]: - flash_attn_fwd_out_correction(out_, - out_per_step[i], - seq_dim, - softmax_lse_[..., 1, :], - softmax_lse_per_step[i]) - elif qkv_format == "thd": - tex.thd_out_correction(out, - out_per_step[i], - softmax_lse, - softmax_lse_per_step[i], - cu_seqlens_q, - True) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" 
+ flash_attn_fwd_out_correction(out_, + out_per_step[i], + seq_dim, + softmax_lse_[..., 1, :], + softmax_lse_per_step[i]) kv = p2p_comm_buffers[-1] if use_fused_attention: @@ -963,9 +918,8 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, out = out.view(-1, *out.shape[-3:]) else: out = out.view(-1, *out.shape[-2:]) - - ctx.save_for_backward(q, kv, out, softmax_lse, - cu_seqlens_q, cu_seqlens_k, *rng_states, *attn_biases) + ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) + ctx.rng_states = rng_states ctx.cp_group = cp_group ctx.cp_global_ranks = cp_global_ranks ctx.dropout_p = dropout_p @@ -976,17 +930,16 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, ctx.qkv_format = qkv_format ctx.attn_bias_type = attn_bias_type ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape + ctx.attn_biases = attn_biases ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention return out @staticmethod def backward(ctx, dout): - (q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6] - cp_size = get_distributed_world_size(ctx.cp_group) - rng_states = ctx.saved_tensors[6:6+cp_size] - attn_biases = ctx.saved_tensors[6+cp_size:6+cp_size*2] + q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size] @@ -994,12 +947,12 @@ def backward(ctx, dout): qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format - if attn_biases[0] is not None: + if ctx.attn_biases[0] is not None: # [b, np, sq, 2*cp, sk//(2*cp)] attn_dbias = torch.zeros( *ctx.attn_bias_shape, - dtype=attn_biases[0].dtype, - device=attn_biases[0].device + dtype=ctx.attn_biases[0].dtype, + device=ctx.attn_biases[0].device ) # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] attn_dbias_ = attn_dbias.view( @@ -1009,17 +962,12 @@ def backward(ctx, dout): attn_dbias = None if ctx.causal: - if ctx.qkv_format == "thd": - softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0)) - else: - # [b, np, sq] -> [b, np, 2, sq//2] - softmax_lse_ = \ - softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2) - softmax_lse_ = softmax_lse_[..., 1, :].contiguous() - if ctx.use_fused_attention: - # [b, np, sq//2] -> [b, np, sq//2, 1] - softmax_lse_.unsqueeze_(-1) - + # [b, np, sq] -> [b, np, 2, sq//2] + softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2) + softmax_lse_ = softmax_lse_[..., 1, :].contiguous() + if ctx.use_fused_attention: + # [b, np, sq//2] -> [b, np, sq//2, 1] + softmax_lse_.unsqueeze_(-1) if ctx.use_fused_attention: # [b, np, sq] -> [b, np, sq, 1] softmax_lse.unsqueeze_(-1) @@ -1082,9 +1030,9 @@ def backward(ctx, dout): # [2, sq//2, b, np, hn] -> [sq, b, np, hn] out_ = out.view(-1, *out.shape[-3:]) dout_ = dout.view(-1, *dout.shape[-3:]) - aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]] + aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]] if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size-i-1]] + aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_k, cu_seqlens_q, cu_seqlens_k, @@ -1114,7 +1062,7 @@ def backward(ctx, dout): dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k, 
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, True, - rng_state=rng_states[cp_size-i-1], + rng_state=ctx.rng_states[cp_size-i-1], **fa_optional_backward_kwargs ) elif i >= (cp_size-rank-1): @@ -1135,9 +1083,9 @@ def backward(ctx, dout): # [2, sq//2, b, np, hn] -> [sq, b, np, hn] out_ = out.view(-1, *out.shape[-3:]) dout_ = dout.view(-1, *dout.shape[-3:]) - aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]] + aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]] if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size-i-1]] + aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_k//2, cu_seqlens_q, cu_seqlens_k//2, @@ -1154,12 +1102,8 @@ def backward(ctx, dout): # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_ = q.view(-1, *q.shape[-2:]) dq_ = torch.empty_like(q_) - if ctx.qkv_format == "thd": - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_k, 0) - else: - # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn] - kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:]) + # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] + kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:]) dkv_ = torch.empty_like(kv_) # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] out_ = out.view(-1, *out.shape[-2:]) @@ -1171,7 +1115,7 @@ def backward(ctx, dout): dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k//2, ctx.max_seqlen_q, ctx.max_seqlen_k//2, ctx.dropout_p, ctx.softmax_scale, False, - rng_state=rng_states[cp_size-i-1], + rng_state=ctx.rng_states[cp_size-i-1], **fa_optional_backward_kwargs ) else: @@ -1192,9 +1136,9 @@ def backward(ctx, dout): # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] out_ = out[1].contiguous() dout_ = dout[1].contiguous() - aux_ctx_tensors = [softmax_lse_, rng_states[cp_size-i-1]] + aux_ctx_tensors = [softmax_lse_, ctx.rng_states[cp_size-i-1]] if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size-i-1]] + aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q//2, ctx.max_seqlen_k, cu_seqlens_q//2, cu_seqlens_k, @@ -1208,23 +1152,15 @@ def backward(ctx, dout): attn_bias_type=ctx.attn_bias_type, ) else: - if ctx.qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_ = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) - else: - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] - q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] + q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) dq_ = torch.empty_like(q_) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_ = kv.view(2, -1, *kv.shape[-2:]) dkv_ = torch.empty_like(kv_) - if ctx.qkv_format == "thd": - out_ = tex.thd_read_half_tensor(out, cu_seqlens_q, 1) - dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q, 1) - else: - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] - out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) - dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:]) + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] + out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) + dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:]) if _flash_attn_2_3_plus: fa_optional_backward_kwargs["window_size"] = [-1, -1] _flash_attn_backward( @@ -1232,14 +1168,14 @@ def backward(ctx, dout): dq_, dkv_[0], dkv_[1], 
cu_seqlens_q//2, cu_seqlens_k, ctx.max_seqlen_q//2, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, False, - rng_state=rng_states[cp_size-i-1], + rng_state=ctx.rng_states[cp_size-i-1], **fa_optional_backward_kwargs ) else: if ctx.use_fused_attention: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]] + aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]] if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size-i-1]] + aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_k, cu_seqlens_q, cu_seqlens_k, @@ -1297,22 +1233,16 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq[0].copy_(dq_[0]) dq[1].add_(dq_[1]) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "copy", "add") elif i > 0: if ctx.qkv_format == "bshd": dq[:, 1, ...].add_(dq_) elif ctx.qkv_format == "sbhd": dq[1].add_(dq_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "add") else: if ctx.qkv_format == "bshd": dq[:, 1, ...].copy_(dq_) elif ctx.qkv_format == "sbhd": dq[1].copy_(dq_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dq, dq_, cu_seqlens_q, "none", "copy") else: if i == 0: dq.copy_(dq_) @@ -1363,8 +1293,6 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dkv[:, 0, ...].add_(dkv_[:, 0, ...]) dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "copy") else: dkv.add_(dkv_) elif i >= (cp_size-rank-1): @@ -1373,15 +1301,11 @@ def backward(ctx, dout): dkv[:, :, 0, ...].copy_(dkv_) elif ctx.qkv_format == "sbhd": dkv[:, 0, ...].copy_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "copy", "none") else: if ctx.qkv_format == "bshd": dkv[:, :, 0, ...].add_(dkv_) elif ctx.qkv_format == "sbhd": dkv[:, 0, ...].add_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction(dkv, dkv_, cu_seqlens_k, "add", "none") elif i > 0: dkv.add_(dkv_) else: @@ -1419,12 +1343,10 @@ def attn_forward_func_with_cp( use_fused_attention=False ) -> torch.Tensor: """Attention implementation with context parallelism""" - assert(qkv_format in ["bshd", "sbhd", "thd"] + assert(qkv_format in ["bshd", "sbhd"] ), f"QKV format of {qkv_format} is not supported with context parallelism!" assert(qkv_format != "sbhd" or use_fused_attention ), "FlashAttention does not support sbhd format!" - assert(not(qkv_format == "thd" and use_fused_attention) - ), "FusedAttention does not support thd format!" assert (attn_mask_type in ["causal", "no_mask"] ), f"Mask type of {attn_mask_type} is not supported with context parallelism!" assert (attn_bias is None or use_fused_attention @@ -1600,12 +1522,11 @@ def apply_rotary_pos_emb( Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and dtype torch.int32. Only valid when `tensor_format` is 'thd'. begins: torch.Tensor, default = None. - We may not want begin all the sequences from the 0 embedding. - This tensor argument allows that. + We may not want begin all the sequences from the 0 embedding. This tensor argument allows that. 
""" assert not (begins is not None and not fused), \ """begins != None and fused=False is not supported""" - + if fused: assert ( tensor_format != "thd" or cu_seqlens is not None @@ -2230,14 +2151,12 @@ def forward( key_layer.device, ) elif qkv_format == 'thd': - assert (cu_seqlens_q is not None and cu_seqlens_kv is not None - ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" - if max_seqlen_q is None: - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - max_seqlen_q = seqlens_q.max().item() - if max_seqlen_kv is None: - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - max_seqlen_kv = seqlens_kv.max().item() + assert not context_parallel, "thd format not supported with context parallelism!" + assert (max_seqlen_q is not None + and max_seqlen_kv is not None + and cu_seqlens_q is not None + and cu_seqlens_kv is not None + ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" if context_parallel: assert ( @@ -2360,6 +2279,7 @@ def forward(ctx, is_training, max_seqlen, cu_seqlens, out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked( is_training, max_seqlen, cu_seqlens, qkv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale[META_S], @@ -2400,6 +2320,7 @@ def forward(ctx, is_training, max_seqlen, cu_seqlens, out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked( is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, fused_attention_backend, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, None, None, None, None, None, None, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen) @@ -2408,8 +2329,11 @@ def forward(ctx, is_training, max_seqlen, cu_seqlens, ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) - ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors, *aux_ctx_tensors) + ctx.save_for_backward(*qkvo_tensors, cu_seqlens, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + *fp8_tensors, *aux_ctx_tensors) ctx.fp8_meta = fp8_meta + ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen = max_seqlen ctx.qkv_dtype = qkv_dtype ctx.attn_scale = attn_scale @@ -2433,12 +2357,14 @@ def backward(ctx, d_out): d_out = d_out._data d_out = d_out.contiguous() - (qkv, out, cu_seqlens, qkv_fp8, out_fp8, - fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors - if not aux_ctx_tensors[0].is_contiguous(): - aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() + (qkv, out, cu_seqlens, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + qkv_fp8, out_fp8, + fwd_scales, fwd_scale_invs) = ctx.saved_tensors + if not ctx.aux_ctx_tensors[0].is_contiguous(): + ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: - softmax_lse, rng_state = aux_ctx_tensors + softmax_lse, rng_state = ctx.aux_ctx_tensors dqkv = torch.empty_like(qkv) maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x d_out, q, k, v, out = [maybe_contiguous(x) @@ -2470,8 +2396,9 @@ def backward(ctx, d_out): dqkv_fp8, *rest = fused_attn_bwd_qkvpacked( ctx.max_seqlen, cu_seqlens, qkv_fp8, out_fp8, d_out_fp8, - fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors, + fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors, ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, 
seq_offsets_o, fwd_scale_invs[META_QKV], # d_scale_qkv, fwd_scale_invs[META_S], # d_scale_s, fwd_scale_invs[META_O], # d_scale_o, @@ -2505,21 +2432,20 @@ def backward(ctx, d_out): d_out = d_out_f8tensor.from_float8(qkv.dtype) dqkv, *rest = fused_attn_bwd_qkvpacked( ctx.max_seqlen, cu_seqlens, qkv, out, d_out, - ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors, + ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors, ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, None, None, None, None, None, None, None, None, None, None, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: - return (None, None, None, None, None, None,None, None, - None, None, dqkv, None, None, None, + return (None, None, None, None, None, None,None, None, None, None, dqkv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None) # else, return (dqkv, dbias) - return (None, None, None, None, None, None, None, None, - None, None, dqkv, None, rest[0], None, + return (None, None, None, None, None, None, None, None, None, None, dqkv, None, rest[0], None, None, None, None, None, None, None, None, None, None, None, None, None) @@ -2560,6 +2486,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q_fp8, kv_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale[META_S], @@ -2603,6 +2530,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype, fused_attention_backend, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, None, None, None, None, None, None, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen) @@ -2612,8 +2540,10 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, *fp8_tensors, *aux_ctx_tensors) ctx.fp8_meta = fp8_meta + ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.qkv_dtype = qkv_dtype @@ -2638,12 +2568,14 @@ def backward(ctx, d_out): d_out = d_out._data d_out = d_out.contiguous() - (q, kv, out, cu_seqlens_q, cu_seqlens_kv, q_fp8, kv_fp8, out_fp8, + (q, kv, out, cu_seqlens_q, cu_seqlens_kv, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + q_fp8, kv_fp8, out_fp8, fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors - if not aux_ctx_tensors[0].is_contiguous(): - aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() + if not ctx.aux_ctx_tensors[0].is_contiguous(): + ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: - softmax_lse, rng_state = aux_ctx_tensors + softmax_lse, rng_state = ctx.aux_ctx_tensors dq = torch.empty_like(q) dkv = torch.empty_like(kv) 
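For orientation, the fwd_scale_invs[META_*] arguments threaded into the FP8 backward calls above are dequantization factors: in the FP8 recipe a tensor is stored together with a per-tensor scale, and scale_inv (its reciprocal) recovers an approximation of the original values. A minimal, self-contained sketch of that round trip, simplified and not the actual Transformer Engine kernels, with an assumed E4M3-style maximum of 448:

import torch

def quantize(x: torch.Tensor, amax: torch.Tensor, fp8_max: float = 448.0):
    # Larger observed amax -> smaller scale, so the data fits into the FP8 range.
    scale = fp8_max / torch.clamp(amax, min=1e-12)
    scale_inv = 1.0 / scale
    x_q = torch.clamp(x * scale, -fp8_max, fp8_max)   # values as they would be stored
    return x_q, scale_inv

def dequantize(x_q: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    # The backward pass only needs scale_inv (a "d_scale_*" argument) to undo the scaling.
    return x_q * scale_inv

x = torch.randn(4, 8)
x_q, s_inv = quantize(x, x.abs().max())
assert torch.allclose(dequantize(x_q, s_inv), x, atol=1e-4)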
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x @@ -2677,8 +2609,9 @@ def backward(ctx, d_out): dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q_fp8, kv_fp8, out_fp8, d_out_fp8, - fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors, + fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors, ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, fwd_scale_invs[META_QKV], # d_scale_qkv, fwd_scale_invs[META_S], # d_scale_s, fwd_scale_invs[META_O], # d_scale_o, @@ -2724,23 +2657,20 @@ def backward(ctx, d_out): dq, dkv, *rest = fused_attn_bwd_kvpacked( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q, kv, out, d_out, - ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors, + ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors, ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, None, None, None, None, None, None, None, None, None, None, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) # if no_bias or alibi, return dqkv if ctx.attn_bias_type in ["no_bias", "alibi"]: - return (None, None, None, None, None, None, - None, None, None, None, None, None, - dq, dkv, None, None, None, + return (None, None, None, None, None, None, None, None, None, None, None, None, dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None) # else, return (dqkv, dbias) - return (None, None, None, None, None, None, - None, None, None, None, None, None, - dq, dkv, None, rest[0], None, + return (None, None, None, None, None, None, None, None, None, None, None, None, dq, dkv, None, rest[0], None, None, None, None, None, None, None, None, None, None, None, None, None) @@ -2753,6 +2683,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend, use_FAv2_bwd, fp8, fp8_meta): + if fp8: if _NVTE_DEBUG: print('[DotProductAttention]: using FP8 forward') @@ -2801,6 +2732,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql out_fp8, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q_fp8, k_fp8, v_fp8, fp8_dtype_forward, fused_attention_backend, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, fp8_meta["scaling_fwd"].scale_inv[META_QKV], fp8_meta["scaling_fwd"].scale_inv[META_S], fp8_meta["scaling_fwd"].scale[META_S], @@ -2872,6 +2804,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql out_ret, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype, fused_attention_backend, attn_bias, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, None, None, None, None, None, None, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen) @@ -2889,8 +2822,10 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, *fp8_tensors, *aux_ctx_tensors) 
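The save_for_backward call just above relies on torch.autograd.Function's contract that everything passed to ctx.save_for_backward comes back, in order, as ctx.saved_tensors, so variable-length groups such as *fp8_tensors and *aux_ctx_tensors can be recovered later with starred unpacking. A toy, self-contained illustration of the pattern (the Function below is hypothetical, not part of this file):

import torch

class ScaleAndKeep(torch.autograd.Function):
    """Doubles the input while saving a variable number of auxiliary tensors."""

    @staticmethod
    def forward(ctx, x, *aux):
        ctx.save_for_backward(x, *aux)   # fixed tensors first, then the variable tail
        ctx.num_aux = len(aux)
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        x, *aux = ctx.saved_tensors      # the tail comes back in the same order
        assert len(aux) == ctx.num_aux
        # One gradient per forward input: x gets a real gradient, aux tensors get None.
        return (grad_out * 2.0,) + (None,) * ctx.num_aux

y = ScaleAndKeep.apply(torch.ones(3, requires_grad=True), torch.zeros(1), torch.zeros(2))
y.sum().backward()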
ctx.fp8_meta = fp8_meta + ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.qkv_dtype = qkv_dtype @@ -2915,12 +2850,14 @@ def backward(ctx, d_out): d_out = d_out._data d_out = d_out.contiguous() - (q, k, v, out, cu_seqlens_q, cu_seqlens_kv, q_fp8, k_fp8, v_fp8, out_fp8, + (q, k, v, out, cu_seqlens_q, cu_seqlens_kv, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, + q_fp8, k_fp8, v_fp8, out_fp8, fwd_scales, fwd_scale_invs, *aux_ctx_tensors) = ctx.saved_tensors - if not aux_ctx_tensors[0].is_contiguous(): - aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() + if not ctx.aux_ctx_tensors[0].is_contiguous(): + ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous() if ctx.use_FAv2_bwd: - softmax_lse, rng_state = aux_ctx_tensors + softmax_lse, rng_state = ctx.aux_ctx_tensors dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) @@ -2955,8 +2892,9 @@ def backward(ctx, d_out): dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q_fp8, k_fp8, v_fp8, out_fp8, d_out_fp8, - fp8_dtype_forward, fp8_dtype_backward, aux_ctx_tensors, + fp8_dtype_forward, fp8_dtype_backward, ctx.aux_ctx_tensors, ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, fwd_scale_invs[META_QKV], # d_scale_qkv, fwd_scale_invs[META_S], # d_scale_s, fwd_scale_invs[META_O], # d_scale_o, @@ -3038,8 +2976,9 @@ def backward(ctx, d_out): dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, q, k, v, out, d_out, - ctx.qkv_dtype, ctx.qkv_dtype, aux_ctx_tensors, + ctx.qkv_dtype, ctx.qkv_dtype, ctx.aux_ctx_tensors, ctx.fused_attention_backend, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, None, None, None, None, None, None, None, None, None, None, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type) @@ -3294,10 +3233,12 @@ def forward( f"""fp8_recipe.fp8_dpa={self.fp8_meta["recipe"].fp8_dpa}""" f"""{forced_fp8_dpa} and """ f"""NVTE_FP8_DPA_BWD={int(os.getenv("NVTE_FP8_DPA_BWD", "1"))}""") + output = FusedAttnFunc.apply( self.training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, query_layer, key_layer, value_layer, qkv_dtype, core_attention_bias, @@ -3314,6 +3255,8 @@ def forward( self.fp8_meta, ) + + # ...hd -> ...(hd) return output.view(*output.shape[:-2], -1) @@ -3518,12 +3461,11 @@ def __init__( self.unfused_attention = UnfusedDotProductAttention( norm_factor, **attn_kwargs, layer_number=layer_number) - + self._allocator = StaticBufferAllocator() def alloc(self, size, dtype, device): - """ Allocation of buffer, compatible with CUDA Graphs.""" return self._allocator(size, dtype, device) @@ -3654,9 +3596,7 @@ def forward( a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is - broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value - means the corresponding position is masked out and a `False` means that position is - allowed to participate in attention. + broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. qkv_format: str, default = `None` If provided, overrides :attr:`qkv_format` from initialization. 
cu_seqlens_q: Optional[torch.Tensor], default = `None` @@ -3763,21 +3703,13 @@ def forward( graph_safe_rng_available() ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." - if self.rng_states_tracker is not None and is_graph_capturing(): - assert ( - isinstance(self.rng_states_tracker, CudaRNGStatesTracker) - ), "Unsupported RNG states tracker." - assert ( - graph_safe_rng_available() - ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." - if window_size is None: window_size = self.window_size if qkv_format is None: qkv_format = self.qkv_format - + if inference_params is not None: assert self.layer_number is not None, "Layer number must be set!" @@ -3789,19 +3721,15 @@ def forward( ) = inference_params.key_value_memory_dict[self.layer_number] if qkv_format in ["bshd", "sbhd"]: - key_layer, value_layer = inference_params.save_new_key_and_value_layer( - self.layer_number, key_layer, value_layer - ) + key_layer, value_layer = inference_params.save_new_key_and_value_layer(self.layer_number, key_layer, value_layer) elif qkv_format == "thd": - inference_params.save_new_key_and_value_layer( - self.layer_number, key_layer, value_layer - ) + inference_params.save_new_key_and_value_layer(self.layer_number, key_layer, value_layer) """ We compute parameters needed by the THD attention with offsets. """ - batch_size = query_layer.shape[0] + batch_size = query_layer.shape[0] max_seqlen_q = inference_params.max_incoming_seq_len max_seqlen_kv = inference_params.max_sequence_length cu_seqlens_q = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") @@ -3812,35 +3740,17 @@ def forward( seq_offsets_o = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:].copy_(torch.cumsum(inference_params.incoming_seq_len, dim=0)) - cu_seqlens_kv[1:].copy_( - torch.cumsum(inference_params.seq_len + inference_params.incoming_seq_len, - dim=0) - ) + cu_seqlens_kv[1:].copy_(torch.cumsum(inference_params.seq_len + inference_params.incoming_seq_len, dim=0)) - seq_offsets_q.copy_( - torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") - * self.channels * max_seqlen_q - ) + seq_offsets_q.copy_(torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_q) seq_offsets_o.copy_(seq_offsets_q) - seq_offsets_k.copy_( - torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") - * self.channels * max_seqlen_kv - ) + seq_offsets_k.copy_(torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv) seq_offsets_v.copy_(seq_offsets_k) # qkv layers are reshaped to the format [t, h, d] - query_layer = query_layer.view( - -1, - query_layer.shape[2], - query_layer.shape[3]).to(torch.bfloat16) - key_layer = inference_key_memory.view( - -1, - inference_key_memory.shape[2], - inference_key_memory.shape[3]).to(torch.bfloat16) - value_layer = inference_value_memory.view( - -1, - inference_value_memory.shape[2], - inference_value_memory.shape[3]).to(torch.bfloat16) + query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16) + key_layer = inference_key_memory.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]).to(torch.bfloat16) + value_layer = inference_value_memory.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]).to(torch.bfloat16) if qkv_format == "bshd": @@ -3848,7 +3758,7 @@ def forward( value_layer = value_layer.transpose(0, 1) key_layer = key_layer.contiguous() value_layer 
= value_layer.contiguous() - + assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition and value_layer.shape[-2] == self.num_gqa_groups_per_partition ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" @@ -3965,7 +3875,7 @@ def forward( use_fused_attention = False if (not _flash_attn_2_3_plus) or context_parallel: use_flash_attention = False - + # Filter: Attention mask type. @@ -4086,7 +3996,7 @@ def forward( and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]): if self.device_compute_capability == (9, 0): use_flash_attention = False - + if self.qkv_format == "thd": use_flash_attention = False use_fused_attention = True @@ -4129,6 +4039,10 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, + seq_offsets_q=seq_offsets_q, + seq_offsets_k=seq_offsets_k, + seq_offsets_v=seq_offsets_v, + seq_offsets_o=seq_offsets_o, attn_mask_type=attn_mask_type, attention_mask=attention_mask, fused_attention_backend=fused_attention_backend, @@ -4139,7 +4053,7 @@ def forward( cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, is_first_microbatch=is_first_microbatch) - return self.fused_attention( + out = self.fused_attention( query_layer, key_layer, value_layer, @@ -4148,6 +4062,10 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, + seq_offsets_q=seq_offsets_q, + seq_offsets_k=seq_offsets_k, + seq_offsets_v=seq_offsets_v, + seq_offsets_o=seq_offsets_o, attn_mask_type=attn_mask_type, attention_mask=attention_mask, fused_attention_backend=fused_attention_backend, @@ -4158,6 +4076,13 @@ def forward( cp_global_ranks=self.cp_global_ranks, cp_stream=self.cp_stream, is_first_microbatch=is_first_microbatch) + if qkv_format == "thd": + out = out.unsqueeze(1) + if q_size > 1: + out = out.view((batch_size, -1, out.shape[2])).contiguous() + + + return out assert (not context_parallel), \ "Context parallelism is only implemented with Flash Attention and Fused Attention!" @@ -4390,8 +4315,7 @@ def __init__( self.num_attention_heads = num_attention_heads self.return_bias = return_bias - self.attention_hidden_size = attention_hidden_size if attention_hidden_size \ - else (hidden_size // num_attention_heads) + self.attention_hidden_size = attention_hidden_size if attention_hidden_size else (hidden_size // num_attention_heads) if init_method is None: init_method = get_default_init_method() @@ -4561,9 +4485,6 @@ def _allocate_memory( ) def alloc(self, size, dtype, device): - """ - Allocation of the buffer compatible with CUDA Graphs. - """ return self._allocator(size, dtype, device) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: @@ -4638,9 +4559,7 @@ def forward( a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is - broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value - means the corresponding position is masked out and a `False` means that position is - allowed to participate in attention. + broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'}, default = `None` type of attention mask passed into softmax operation. 
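To make the documented mask shapes concrete, the sketch below (illustrative only, using the convention stated elsewhere in this file that a True entry marks a position that is masked out) builds a padding mask in the [batch_size, 1, 1, seqlen_q] layout and an 'arbitrary' mask that broadcasts to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]:

import torch

batch_size, num_heads, max_seqlen_q, max_seqlen_kv = 2, 4, 6, 6
seqlens = torch.tensor([4, 6])                      # actual lengths of the two sequences

# Padding mask for self-attention: [b, 1, 1, sq], True = padded position, masked out.
positions = torch.arange(max_seqlen_q)
padding_mask = (positions.unsqueeze(0) >= seqlens.unsqueeze(1)).view(batch_size, 1, 1, max_seqlen_q)

# 'arbitrary' mask: any shape broadcastable to [b, h, sq, skv]; here a causal pattern.
arbitrary_mask = torch.triu(
    torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool), diagonal=1
).view(1, 1, max_seqlen_q, max_seqlen_kv)

assert padding_mask.shape == (batch_size, 1, 1, max_seqlen_q)
assert arbitrary_mask.expand(batch_size, num_heads, -1, -1).shape == (
    batch_size, num_heads, max_seqlen_q, max_seqlen_kv)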
@@ -4755,7 +4674,7 @@ def forward( ) num_queries_per_key_value = (self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition) - + if self.qkv_weight_interleaved: # [sq, b, ng * (np/ng + 2) * hn] --> [sq, b, ng, (np/ng + 2), hn] new_tensor_shape = mixed_x_layer.size()[:-1] + ( @@ -4876,29 +4795,21 @@ def forward( rotary_pos_emb = ((rotary_pos_emb,) * 2) q_pos_emb, k_pos_emb = rotary_pos_emb - + if self.qkv_format == "thd" and inference_params is not None: # For thd attention incoming tokens can be on different positions, # so we need to copy different positional encoding freqency # for every sequence in a batch. # # For example if sequence lengths in context phase are: 2 and 5 (batch size=2), - # in first generation phase key_layer have shape [2, 1, d]. + # in first generation phase key_layer have shape [2, 1, d]. # key_layer[0, :] corresponds to the token with position 3 = 2 + 1, # and key_layer [1, :] corresponds to the token with position 6 = 5 + 1. key_layer = key_layer.contiguous() query_layer = query_layer.contiguous() - key_layer.copy_( - apply_rotary_pos_emb( - key_layer, k_pos_emb, "bshd", fused=True, begins=inference_params.seq_len - ) - ) - query_layer.copy_( - apply_rotary_pos_emb( - query_layer, q_pos_emb, "bshd", fused=True, begins=inference_params.seq_len - ) - ) + key_layer.copy_(apply_rotary_pos_emb(key_layer, k_pos_emb, "bshd", fused=True, begins=inference_params.seq_len)) + query_layer.copy_(apply_rotary_pos_emb(query_layer, q_pos_emb, "bshd", fused=True, begins=inference_params.seq_len)) else: # adjust key and value for inference if inference_params is not None: @@ -4909,16 +4820,12 @@ def forward( sequence_start = inference_params.sequence_len_offset sequence_end = sequence_start + sequence_length - + q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] - query_layer = apply_rotary_pos_emb( - query_layer, q_pos_emb, self.qkv_format, fused=True - ) - key_layer = apply_rotary_pos_emb( - key_layer, k_pos_emb, self.qkv_format, fused=True - ) + query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True) + key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True) query_layer = query_layer.contiguous() key_layer = key_layer.contiguous() @@ -4969,16 +4876,14 @@ def forward( class StaticBufferAllocator(torch.nn.Module): """ - This class is used when we use te.make_graphed_callable(). - CUDA Graphs require all tensors to be static. Neverthless, + This class is used when we use te.make_graphed_callable(). + CUDA Graphs require all tensors to be static. Neverthless, torch API make_graphed_callable() takes care of output of torch modules, and makes them static. Thus by wrapping allocation of memory into torch.nn.Module, we can greatly simplify our code. """ - - @staticmethod - def forward(size, dtype, device): - """ - Allocate the buffers. 
- """ - return torch.zeros(size, dtype=dtype, device=device) + def __init__(self): + super().__init__() + + def forward(self, size, dtype, device): + return torch.zeros(size, dtype=dtype, device=device) \ No newline at end of file diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index b16b4d8355..8b7299670d 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -237,6 +237,7 @@ def fused_attn_fwd_qkvpacked( max_seqlen, is_training, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], cu_seqlens, qkv, qkv_dtype, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, rng_gen, rng_elts_per_thread, ) @@ -385,6 +386,7 @@ def fused_attn_bwd_qkvpacked( max_seqlen, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], cu_seqlens, qkv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) @@ -565,6 +567,7 @@ def fused_attn_fwd_kvpacked( max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], cu_seqlens_q, cu_seqlens_kv, q, kv, qkv_dtype, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, rng_gen, rng_elts_per_thread, ) @@ -727,6 +730,7 @@ def fused_attn_bwd_kvpacked( max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], cu_seqlens_q, cu_seqlens_kv, q, kv, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) @@ -914,6 +918,7 @@ def fused_attn_fwd( max_seqlen_q, max_seqlen_kv, is_training, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], cu_seqlens_q, cu_seqlens_kv, q, k, v, qkv_dtype, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, d_scale_qkv, d_scale_s, q_scale_s, q_scale_o, amax_s, amax_o, attn_bias, rng_gen, rng_elts_per_thread, ) @@ -1084,6 +1089,7 @@ def fused_attn_bwd( max_seqlen_q, max_seqlen_kv, attn_scale, dropout, fast_zero_fill, QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], cu_seqlens_q, cu_seqlens_kv, q, k, v, o, d_o, qkv_dtype, dqkv_dtype, aux_ctx_tensors, + seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o, d_scale_qkv, d_scale_s, d_scale_o, d_scale_do, d_scale_dp, q_scale_s, q_scale_dp, q_scale_dqkv, amax_dp, amax_dqkv, ) From 6ec8926b046b1751af2a9a767ed0c54951d9ffe8 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 18:53:28 -0700 Subject: [PATCH 150/244] git fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 33 +++++++++++++------------ 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 7308f1812c..ff76532f99 100644 --- 
a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -5,14 +5,14 @@ """Attention.""" import collections from contextlib import nullcontext -from importlib.metadata import version +from importlib.metadata import version as get_pkg_version import math import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings import numpy as np -from pkg_resources import packaging +from packaging.version import Version as PkgVersion import torch import torch.nn.functional as F @@ -67,13 +67,13 @@ from transformer_engine.pytorch.graph import is_graph_capturing -_flash_attn_version = packaging.version.Version(version("flash-attn")) -_flash_attn_version_required = packaging.version.Version("2.0.6") -_flash_attn_max_version = packaging.version.Version("2.5.8") -_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1") -_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3") -_flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4") -_flash_attn_2_4_1_plus = _flash_attn_version >= packaging.version.Version("2.4.1") +_flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) +_flash_attn_version_required = PkgVersion("2.0.6") +_flash_attn_max_version = PkgVersion("2.5.8") +_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1") +_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") +_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") +_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") if _flash_attn_version >= _flash_attn_version_required: from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module @@ -313,7 +313,7 @@ def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor: the samples in a batch. """ mask = mask.squeeze(1).squeeze(1) - reduced_mask = mask.sum(dim=1) + reduced_mask = mask.logical_not().sum(dim=1) cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) zero = torch.zeros(1, dtype=torch.int32, device="cuda") cu_seqlens = torch.cat((zero, cu_seqlens)) @@ -331,13 +331,13 @@ def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch. mask = mask.squeeze(1).squeeze(1) bs, seqlen = mask.shape - reduced_mask = mask.sum(dim=1) + reduced_mask = mask.logical_not().sum(dim=1) cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) zero = torch.zeros(1, dtype=torch.int32, device="cuda") cu_seqlens = torch.cat((zero, cu_seqlens)) mask = mask.reshape(-1) - indices = mask.nonzero() + indices = mask.logical_not().nonzero() indices = indices.unsqueeze(-1) num_nonzeros = indices.shape[0] @@ -497,7 +497,7 @@ def forward( *tensors: Tuple[torch.Tensor, ...] ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported." 
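The mask.logical_not() calls introduced in the get_cu_seqlens helpers above assume the padding-mask convention in which True marks a padded position; the number of valid tokens per sequence is therefore the count of False entries, and their cumulative sum gives cu_seqlens. A standalone sketch of that computation, kept on CPU purely for illustration:

import torch

# Padding mask in the [b, 1, 1, s] layout; True = padded / masked out.
mask = torch.tensor([[[[False, False, False, True]]],
                     [[[False, False, True,  True]]]])

mask2d = mask.squeeze(1).squeeze(1)                 # [b, s]
valid_per_seq = mask2d.logical_not().sum(dim=1)     # tokens actually present: [3, 2]
cu_seqlens = torch.cat((
    torch.zeros(1, dtype=torch.int32),
    valid_per_seq.cumsum(dim=0).to(torch.int32)))   # [0, 3, 5]

print(cu_seqlens)                                   # tensor([0, 3, 5], dtype=torch.int32)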
- ctx.indices = indices + ctx.save_for_backward(indices) ctx.dim0 = tensors[0].shape[0] if len(tensors) == 1: return pack_tensor(indices, *tensors) @@ -507,11 +507,12 @@ def forward( @staticmethod def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]): + (indices,) = ctx.saved_tensors if len(grad_outputs) == 1: - return None, unpack_tensor(ctx.indices, ctx.dim0, *grad_outputs) + return None, unpack_tensor(indices, ctx.dim0, *grad_outputs) if len(grad_outputs) == 2: - return None, *unpack_2_tensors(ctx.indices, ctx.dim0, *grad_outputs) - return None, *unpack_3_tensors(ctx.indices, ctx.dim0, *grad_outputs) + return None, *unpack_2_tensors(indices, ctx.dim0, *grad_outputs) + return None, *unpack_3_tensors(indices, ctx.dim0, *grad_outputs) class UnpackTensor(torch.autograd.Function): From 4da9feee0ec134bd9df31ebb28530770d0f45b3f Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 18:57:24 -0700 Subject: [PATCH 151/244] git fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index ff76532f99..37a24f1a69 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -526,12 +526,13 @@ def forward( dim0: int, tensor: torch.Tensor, ) -> torch.Tensor: - ctx.indices = indices + ctx.save_for_backward(indices) return unpack_tensor(indices, dim0, tensor) @staticmethod def backward(ctx, grad_output): - return None, None, pack_tensor(ctx.indices, grad_output) + (indices,) = ctx.saved_tensors + return None, None, pack_tensor(indices, grad_output) def flash_attn_p2p_communicate(rank, send_tensor, send_dst, From f16868bf9fac044c44587a37e1d105cc7f32f8d0 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 19:00:06 -0700 Subject: [PATCH 152/244] git fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 65 ++++++++++++++++++------- 1 file changed, 48 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 37a24f1a69..c5453da5ea 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -102,7 +102,6 @@ __all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"] - class InferenceParams: # pylint: disable=too-few-public-methods """ Inference parameters that are passed to the main model in order @@ -814,8 +813,13 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, if len(rest) > 0: attn_biases[i] = rest[0] else: - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] - q_inputs[i%2] = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) + if qkv_format == "thd": + # [t, np, hn] -> [t/2, np, hn] + q_inputs[i%2] = tex.thd_read_half_tensor(q, cu_seqlens_q, 1) + else: + # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn] + q_inputs[i%2] = \ + q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:]) if _flash_attn_2_3_plus: @@ -873,7 +877,7 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, if i == 1: out = torch.empty_like(q).zero_() softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) - if causal: + if causal and qkv_format != "thd": # [b, np, sq] -> [b, np, 2, sq//2] softmax_lse_ = softmax_lse.view( 
*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2 @@ -882,8 +886,14 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step[i-1]) else: - flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :], - softmax_lse_per_step[i-1]) + if qkv_format == "thd": + tex.thd_second_half_lse_correction(softmax_lse, + softmax_lse_per_step[i-1], + cu_seqlens_q, + q.size(0)) + else: + flash_attn_fwd_softmax_lse_correction(softmax_lse_[..., 1, :], + softmax_lse_per_step[i-1]) if i < cp_size: flash_attn_streams[(i-1)%2].record_event(fwd_results_correction_done) @@ -891,7 +901,8 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) softmax_lse = softmax_lse.to(torch.float) - seq_dim = qkv_format.index("s") + if qkv_format in ["bshd", "sbhd"]: + seq_dim = qkv_format.index("s") for i in range(cp_size): if qkv_format == "bshd": out_per_step[i] = out_per_step[i].view(out.shape[0], -1, *out.shape[-2:]) @@ -900,17 +911,37 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:]) out_ = out[1] if i <= rank or not causal: - flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape), - out_per_step[i], - seq_dim, - softmax_lse, - softmax_lse_per_step[i]) + if qkv_format in ["bshd", "sbhd"]: + flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape), + out_per_step[i], + seq_dim, + softmax_lse, + softmax_lse_per_step[i]) + elif qkv_format == "thd": + tex.thd_out_correction(out, + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + cu_seqlens_q, + False) + else: + assert False, f"{qkv_format} is an unsupported qkv_format!" else: - flash_attn_fwd_out_correction(out_, - out_per_step[i], - seq_dim, - softmax_lse_[..., 1, :], - softmax_lse_per_step[i]) + if qkv_format in ["bshd", "sbhd"]: + flash_attn_fwd_out_correction(out_, + out_per_step[i], + seq_dim, + softmax_lse_[..., 1, :], + softmax_lse_per_step[i]) + elif qkv_format == "thd": + tex.thd_out_correction(out, + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + cu_seqlens_q, + True) + else: + assert False, f"{qkv_format} is an unsupported qkv_format!" 
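The per-step corrections above (flash_attn_fwd_softmax_lse_correction, flash_attn_fwd_out_correction and their THD counterparts) implement the usual log-sum-exp merge of partial attention results: given partial outputs out_i with softmax LSE statistics lse_i, the combined statistic is lse = log(sum_i exp(lse_i)) and each partial output is reweighted by exp(lse_i - lse) before accumulation. A small reference sketch of that math, not the fused or THD kernels used here:

import torch

def merge_partial_attention(outs, lses):
    """outs: list of [b, s, h, d] partial outputs; lses: list of [b, h, s] LSE statistics."""
    lse = torch.logsumexp(torch.stack(lses), dim=0)              # combined statistic
    merged = torch.zeros_like(outs[0])
    for out_i, lse_i in zip(outs, lses):
        weight = torch.exp(lse_i - lse)                          # [b, h, s]
        merged += out_i * weight.permute(0, 2, 1).unsqueeze(-1)  # broadcast over head dim
    return merged, lse

b, s, h, d = 2, 8, 4, 16
outs = [torch.randn(b, s, h, d) for _ in range(3)]
lses = [torch.randn(b, h, s) for _ in range(3)]
merged, lse = merge_partial_attention(outs, lses)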
kv = p2p_comm_buffers[-1] if use_fused_attention: From c439a7684f9d089ca0c2ee0eb47ae429ad76ecb8 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 19:12:30 -0700 Subject: [PATCH 153/244] git fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index c5453da5ea..ab2a083726 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -766,8 +766,13 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, else: # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] q_inputs[i%2] = q.view(-1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous() + if qkv_format == "thd": + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_inputs[i%2] = tex.thd_read_half_tensor( + kv_inputs[i%2], cu_seqlens_k, 0) + else: + # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] + kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous() # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] kv_inputs[i%2] = kv_inputs[i%2].view(2, -1, *k.shape[-2:]) if _flash_attn_2_3_plus: @@ -911,7 +916,7 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:]) out_ = out[1] if i <= rank or not causal: - if qkv_format in ["bshd", "sbhd"]: + if qkv_format in ["bshd", "sbhd"]: flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape), out_per_step[i], seq_dim, From 448df78585ac9cb2e2591d0e5b527520bc82cd97 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 19:16:08 -0700 Subject: [PATCH 154/244] git fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 43 ++++++++++++++----------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index ab2a083726..e9d4f159da 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -915,6 +915,7 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, elif qkv_format == "sbhd": out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:]) out_ = out[1] + if i <= rank or not causal: if qkv_format in ["bshd", "sbhd"]: flash_attn_fwd_out_correction(out.view(*out_per_step[i].shape), @@ -956,7 +957,8 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, out = out.view(-1, *out.shape[-3:]) else: out = out.view(-1, *out.shape[-2:]) - ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) + ctx.save_for_backward(q, kv, out, softmax_lse, + cu_seqlens_q, cu_seqlens_k, *rng_states, *attn_biases) ctx.rng_states = rng_states ctx.cp_group = cp_group ctx.cp_global_ranks = cp_global_ranks @@ -968,16 +970,17 @@ def forward(ctx, is_training, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, ctx.qkv_format = qkv_format ctx.attn_bias_type = attn_bias_type ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape - ctx.attn_biases = attn_biases ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention return out @staticmethod def backward(ctx, dout): - q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors - + (q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) = ctx.saved_tensors[:6] cp_size = 
get_distributed_world_size(ctx.cp_group) + rng_states = ctx.saved_tensors[6:6+cp_size] + attn_biases = ctx.saved_tensors[6+cp_size:6+cp_size*2] + rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size] recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size] @@ -985,12 +988,12 @@ def backward(ctx, dout): qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format - if ctx.attn_biases[0] is not None: + if attn_biases[0] is not None: # [b, np, sq, 2*cp, sk//(2*cp)] attn_dbias = torch.zeros( *ctx.attn_bias_shape, - dtype=ctx.attn_biases[0].dtype, - device=ctx.attn_biases[0].device + dtype=attn_biases[0].dtype, + device=attn_biases[0].device ) # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] attn_dbias_ = attn_dbias.view( @@ -1000,12 +1003,16 @@ def backward(ctx, dout): attn_dbias = None if ctx.causal: - # [b, np, sq] -> [b, np, 2, sq//2] - softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2) - softmax_lse_ = softmax_lse_[..., 1, :].contiguous() - if ctx.use_fused_attention: - # [b, np, sq//2] -> [b, np, sq//2, 1] - softmax_lse_.unsqueeze_(-1) + if ctx.qkv_format == "thd": + softmax_lse_ = tex.thd_read_second_half_lse(softmax_lse, cu_seqlens_q, q.size(0)) + else: + # [b, np, sq] -> [b, np, 2, sq//2] + softmax_lse_ = \ + softmax_lse.view(*softmax_lse.shape[:-1], 2, softmax_lse.shape[-1]//2) + softmax_lse_ = softmax_lse_[..., 1, :].contiguous() + if ctx.use_fused_attention: + # [b, np, sq//2] -> [b, np, sq//2, 1] + softmax_lse_.unsqueeze_(-1) if ctx.use_fused_attention: # [b, np, sq] -> [b, np, sq, 1] softmax_lse.unsqueeze_(-1) @@ -1068,9 +1075,9 @@ def backward(ctx, dout): # [2, sq//2, b, np, hn] -> [sq, b, np, hn] out_ = out.view(-1, *out.shape[-3:]) dout_ = dout.view(-1, *dout.shape[-3:]) - aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]] + aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]] if attn_dbias is not None: - aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]] + aux_ctx_tensors += [attn_biases[cp_size-i-1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_k, cu_seqlens_q, cu_seqlens_k, @@ -1100,7 +1107,7 @@ def backward(ctx, dout): dq_, dkv_[0], dkv_[1], cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, True, - rng_state=ctx.rng_states[cp_size-i-1], + rng_state=rng_states[cp_size-i-1], **fa_optional_backward_kwargs ) elif i >= (cp_size-rank-1): @@ -1121,9 +1128,9 @@ def backward(ctx, dout): # [2, sq//2, b, np, hn] -> [sq, b, np, hn] out_ = out.view(-1, *out.shape[-3:]) dout_ = dout.view(-1, *dout.shape[-3:]) - aux_ctx_tensors = [softmax_lse, ctx.rng_states[cp_size-i-1]] + aux_ctx_tensors = [softmax_lse, rng_states[cp_size-i-1]] if attn_dbias is not None: - aux_ctx_tensors += [ctx.attn_biases[cp_size-i-1]] + aux_ctx_tensors += [attn_biases[cp_size-i-1]] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_k//2, cu_seqlens_q, cu_seqlens_k//2, From 63a98b717e98184cfb2d789aef4ce9f29fecacee Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 May 2024 19:19:05 -0700 Subject: [PATCH 155/244] git fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index e9d4f159da..9caf6088d1 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -4604,7 +4604,9 @@ def 
forward( a single tensor of [batch_size, 1, 1, seqlen_q] for self-attention, and a tuple of two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for cross-attention. For the 'arbitrary' mask type, it should be in a shape that is - broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. + broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]. A `True` value + means the corresponding position is masked out and a `False` means that position is + allowed to participate in attention. attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'arbitrary'}, default = `None` type of attention mask passed into softmax operation. From f64acd3306ae025bdeae50bad7e7872616c2fb6d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 24 May 2024 11:21:50 -0700 Subject: [PATCH 156/244] Attention.py refactoring Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 201 +++++++++++++----------- 1 file changed, 107 insertions(+), 94 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9caf6088d1..4e5279787f 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -109,21 +109,25 @@ class InferenceParams: # pylint: disable=too-few-public-methods Parameters ---------- - max_batch_size : int + max_batch_size: int maximum batch size during inference. - max_sequence_length : int - maximum sequence length during inference. + max_sequence_length: int + maximum sequence length during inference. + qkv_format: str + {'bshd', 'sbhd', 'thd'} """ def __init__(self, max_batch_size, max_sequence_length, qkv_format="bshd"): + assert qkv_format in ["bsdh", "sbhd", "thd"] + self.max_sequence_length = max_sequence_length self.max_batch_size = max_batch_size self.key_value_memory_dict = {} self.qkv_format = qkv_format if qkv_format == "thd": - self.seq_len = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) - self.incoming_seq_len = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) + self.cached_sequence_lengths = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) + self.input_sequence_lengths = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) else: self.sequence_len_offset = 0 self.batch_size_offset = 0 @@ -176,44 +180,44 @@ def thd_setup_before_new_input(self, new_input, reset=False, pad_token_id=None): """ assert self.qkv_format == "thd" - self.seq_len.copy_(self.seq_len + self.incoming_seq_len) + self.cached_sequence_lengths.copy_(self.cached_sequence_lengths + self.input_sequence_lengths) if pad_token_id is not None: - self.incoming_seq_len.copy_(torch.sum(new_input.ne(pad_token_id), dim=-1, dtype=torch.int32).squeeze()) + self.input_sequence_lengths.copy_(torch.sum(new_input.ne(pad_token_id), dim=-1, dtype=torch.int32).squeeze()) else: - self.incoming_seq_len.copy_(torch.ones(new_input.shape[0], device="cuda") * new_input.shape[1]) + self.input_sequence_lengths.copy_(torch.ones(new_input.shape[0], device="cuda") * new_input.shape[1]) self.max_incoming_seq_len = new_input.shape[1] if reset: - self.seq_len.copy_(torch.zeros_like(self.seq_len)) + self.cached_sequence_lengths.copy_(torch.zeros_like(self.cached_sequence_lengths)) - def save_new_key_and_value_layer(self, layer_number, key_layer, value_layer): + def save_to_kv_cache(self, layer_number, key_layer, value_layer): """ Saves key_layer and value_layer in the cache. 
""" (inference_key_memory, inference_value_memory, ) = self.key_value_memory_dict[layer_number] if self.qkv_format == "thd": - batch_size = key_layer.shape[0] - channels = inference_key_memory.shape[2] * inference_key_memory.shape[3] # h * d + channels = inference_key_memory.shape[1] * inference_key_memory.shape[2] # h * d tex.attention_copy( inference_key_memory, - self.seq_len, - self.incoming_seq_len, + self.cached_sequence_lengths, + self.input_sequence_lengths, key_layer, self.max_incoming_seq_len, self.max_sequence_length, - batch_size, + self.max_batch_size, channels) tex.attention_copy( inference_value_memory, - self.seq_len, - self.incoming_seq_len, + self.cached_sequence_lengths, + self.input_sequence_lengths, value_layer, self.max_incoming_seq_len, self.max_sequence_length, - batch_size, + self.max_batch_size, channels) + key_layer, value_layer = inference_key_memory, inference_value_memory else: assert self.qkv_format in ["bshd", "sbhd"], "Attention format not supported by the inference." batch_start = self.batch_size_offset @@ -231,8 +235,63 @@ def save_new_key_and_value_layer(self, layer_number, key_layer, value_layer): sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] - return key_layer, value_layer + return key_layer, value_layer + + def allocate_memory_for_kv_cache_if_empty( + self, + layer_number, + num_gqa_groups_per_partition, + hidden_size_per_attention_head, + dtype): + + if layer_number in self.key_value_memory_dict: + return # Already allocated + + s = self.max_sequence_length + b = self.max_batch_size + + def _allocate_memory(dims): + return torch.empty( + *dims, + num_gqa_groups_per_partition, + hidden_size_per_attention_head, + dtype=dtype, + device=torch.cuda.current_device(), + ) + + if self.qkv_format == "thd": + inference_key_memory = _allocate_memory((b * s,)) + inference_value_memory = _allocate_memory((b * s,)) + else: + inference_key_memory = _allocate_memory((s, b)) + inference_value_memory = _allocate_memory((s, b)) + self.key_value_memory_dict[layer_number] = ( + inference_key_memory, + inference_value_memory, + ) + def set_params_to_thd_attention(self, buffers, channels): + max_seqlen_q, max_seqlen_kv = self.max_incoming_seq_len, self.max_sequence_length + + # Allocation of buffers, works with CUDA Graphs. 
+ cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = \ + buffers + + cu_seqlens_q[1:].copy_(torch.cumsum(self.input_sequence_lengths, dim=0)) + cu_seqlens_kv[1:].copy_( + torch.cumsum( + self.cached_sequence_lengths + self.input_sequence_lengths, dim=0 + ) + ) + + # If layer has shape [b * s_layer, h, d] + # offsets are of the form [k * s_layer * h * d for k = 0, ..., batch_size] + seq_offsets_q.copy_(torch.arange(0, self.max_batch_size + 1, device="cuda") * channels * max_seqlen_q) + seq_offsets_k.copy_(torch.arange(0, self.max_batch_size + 1, device="cuda") * channels * max_seqlen_kv) + seq_offsets_v.copy_(seq_offsets_k) + seq_offsets_o.copy_(seq_offsets_q) + + return max_seqlen_q, max_seqlen_kv, buffers @torch.no_grad() @@ -3762,45 +3821,25 @@ def forward( key_layer = key_layer.transpose(0, 1) value_layer = value_layer.transpose(0, 1) - (inference_key_memory, inference_value_memory, - ) = inference_params.key_value_memory_dict[self.layer_number] - - if qkv_format in ["bshd", "sbhd"]: - key_layer, value_layer = inference_params.save_new_key_and_value_layer(self.layer_number, key_layer, value_layer) - elif qkv_format == "thd": - - inference_params.save_new_key_and_value_layer(self.layer_number, key_layer, value_layer) - - """ - We compute parameters needed by the THD attention with offsets. - """ - batch_size = query_layer.shape[0] - max_seqlen_q = inference_params.max_incoming_seq_len - max_seqlen_kv = inference_params.max_sequence_length - cu_seqlens_q = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") - cu_seqlens_kv = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") - seq_offsets_q = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") - seq_offsets_k = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") - seq_offsets_v = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") - seq_offsets_o = self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") - - cu_seqlens_q[1:].copy_(torch.cumsum(inference_params.incoming_seq_len, dim=0)) - cu_seqlens_kv[1:].copy_(torch.cumsum(inference_params.seq_len + inference_params.incoming_seq_len, dim=0)) + key_layer, value_layer = inference_params.save_to_kv_cache( + self.layer_number, key_layer, value_layer + ) - seq_offsets_q.copy_(torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_q) - seq_offsets_o.copy_(seq_offsets_q) - seq_offsets_k.copy_(torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * self.channels * max_seqlen_kv) - seq_offsets_v.copy_(seq_offsets_k) + if qkv_format == "thd": + # Allocation of buffers, works with CUDA Graphs. 
+ buffers = [self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") for _ in range(6)] - # qkv layers are reshaped to the format [t, h, d] - query_layer = query_layer.view(-1, query_layer.shape[2], query_layer.shape[3]).to(torch.bfloat16) - key_layer = inference_key_memory.view(-1, inference_key_memory.shape[2], inference_key_memory.shape[3]).to(torch.bfloat16) - value_layer = inference_value_memory.view(-1, inference_value_memory.shape[2], inference_value_memory.shape[3]).to(torch.bfloat16) + max_seqlen_q, max_seqlen_kv, buffers = inference_params.set_params_to_thd_attention(buffers, self.channels) + cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = \ + buffers + # query_layer is reshaped to the format [t, h, d] + query_layer = query_layer.view(-1, *query_layer.shape[2:]) if qkv_format == "bshd": key_layer = key_layer.transpose(0, 1) value_layer = value_layer.transpose(0, 1) + key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() @@ -4515,20 +4554,6 @@ def __init__( self._allocator = StaticBufferAllocator() - - - def _allocate_memory( - self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype - ) -> torch.Tensor: - return torch.empty( - inference_max_sequence_len, - batch_size, - self.num_gqa_groups_per_partition, - self.hidden_size_per_attention_head, - dtype=dtype, - device=torch.cuda.current_device(), - ) - def alloc(self, size, dtype, device): return self._allocator(size, dtype, device) @@ -4670,33 +4695,13 @@ def forward( # Pre-allocate memory for key-values for inference # ================================================= - if inference_params and self.layer_number is not None: - if self.layer_number not in inference_params.key_value_memory_dict: - inf_max_seq_len = inference_params.max_sequence_length - inf_max_batch_size = inference_params.max_batch_size - if self.qkv_format == "thd": - inference_key_memory = self._allocate_memory( - inf_max_batch_size, inf_max_seq_len, hidden_states.dtype - ) - inference_value_memory = self._allocate_memory( - inf_max_batch_size, inf_max_seq_len, hidden_states.dtype - ) - else: - inference_key_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, hidden_states.dtype - ) - inference_value_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, hidden_states.dtype - ) - inference_params.key_value_memory_dict[self.layer_number] = ( - inference_key_memory, - inference_value_memory, - ) - else: - ( - inference_key_memory, - inference_value_memory, - ) = inference_params.key_value_memory_dict[self.layer_number] + + inference_params.allocate_memory_for_kv_cache_if_empty( + self.layer_number, + self.num_gqa_groups_per_partition, + self.hidden_size_per_attention_head, + hidden_states.dtype + ) # ====================== # Query, Key, and Value @@ -4855,8 +4860,16 @@ def forward( key_layer = key_layer.contiguous() query_layer = query_layer.contiguous() - key_layer.copy_(apply_rotary_pos_emb(key_layer, k_pos_emb, "bshd", fused=True, begins=inference_params.seq_len)) - query_layer.copy_(apply_rotary_pos_emb(query_layer, q_pos_emb, "bshd", fused=True, begins=inference_params.seq_len)) + key_layer.copy_( + apply_rotary_pos_emb( + key_layer, k_pos_emb, "bshd", fused=True, begins=inference_params.cached_sequence_lengths + ) + ) + query_layer.copy_( + apply_rotary_pos_emb( + query_layer, q_pos_emb, "bshd", fused=True, begins=inference_params.cached_sequence_lengths + ) + ) else: # adjust key and value for inference if inference_params is not None: 
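# ---------------------------------------------------------------------------
# Editor's illustration (not part of the patches above or below): a minimal,
# self-contained sketch of the THD bookkeeping introduced above. It assumes
# the KV-cache buffers are laid out as [max_batch_size * max_seqlen, h, d]
# and that per-sequence cached/incoming lengths are known. The helper name is
# hypothetical; only the tensor names mirror the patch.
# ---------------------------------------------------------------------------
import torch

def thd_attention_metadata(cached_lengths, input_lengths, max_seqlen_q, max_seqlen_kv, channels):
    """Compute cu_seqlens and per-sequence buffer offsets for thd attention."""
    b = input_lengths.shape[0]
    cu_seqlens_q = torch.zeros(b + 1, dtype=torch.int32)
    cu_seqlens_kv = torch.zeros(b + 1, dtype=torch.int32)
    # Cumulative number of new (query) tokens and of cached + new (key/value) tokens.
    cu_seqlens_q[1:] = torch.cumsum(input_lengths, dim=0)
    cu_seqlens_kv[1:] = torch.cumsum(cached_lengths + input_lengths, dim=0)
    # Sequence i starts i * max_seqlen * h * d elements into the flat cache buffer.
    seq_offsets_q = torch.arange(0, b + 1, dtype=torch.int32) * channels * max_seqlen_q
    seq_offsets_kv = torch.arange(0, b + 1, dtype=torch.int32) * channels * max_seqlen_kv
    return cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_kv

# Example: cached lengths [3, 5] plus [2, 1] new tokens give
# cu_seqlens_q == [0, 2, 3] and cu_seqlens_kv == [0, 5, 11].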
From c8e4510e577ee3829d1649a6d62d628a84684255 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 24 May 2024 13:18:34 -0700 Subject: [PATCH 157/244] Attention.py refactoring Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 79 +++++++++++++++++++++---- 1 file changed, 67 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 4e5279787f..e31c0c716d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -114,18 +114,29 @@ class InferenceParams: # pylint: disable=too-few-public-methods max_sequence_length: int maximum sequence length during inference. qkv_format: str - {'bshd', 'sbhd', 'thd'} + Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for + the sequence length dimension, `b` batch size, `h` the number of attention heads, + `d` head size, and `t` the total number of sequences in a batch, i.e. + `t = sum(s_i) for i = 0...b-1`. """ def __init__(self, max_batch_size, max_sequence_length, qkv_format="bshd"): assert qkv_format in ["bsdh", "sbhd", "thd"] - + self.max_sequence_length = max_sequence_length self.max_batch_size = max_batch_size - self.key_value_memory_dict = {} + + # self.key_value_memory_dict[layer number] = (key_cache, value_cache) + # if qkv_format in ["bshd", "sbhd"]: (key/value)_cache.shape = [b/s, s/b, h, d] + # # if qkv_format = "thd": (key/value)_cache.shape = [t, h, d] + self.key_value_memory_dict = {} self.qkv_format = qkv_format if qkv_format == "thd": + # In thd attention layout input sequences can have different lenghts. + # self.input_sequence_lengths stores tensor of shape [b] with lengths of input sequences + # and self.cached_sequence_lengths is the sum of all previous input lengths tensors - + # equivalently it contains total lengths of cached sequences. self.cached_sequence_lengths = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) self.input_sequence_lengths = torch.empty((max_batch_size,), device="cuda", dtype=torch.int32) else: @@ -160,10 +171,8 @@ def swap_key_value_dict(self, batch_indices): def thd_setup_before_new_input(self, new_input, reset=False, pad_token_id=None): """ - After every context/generation phase, the parameters representing - for example sequence lengths and incmoing sequence lengths, - need to be updated. This function does exactly that. - + Updates parameters representing incoming sequence lengths and lengths + of sequence in the cache. Should be called before every forward pass in inference. Parameters ---------- @@ -174,7 +183,7 @@ def thd_setup_before_new_input(self, new_input, reset=False, pad_token_id=None): It is supposed to be used after last generation phase to allow inference_params to be reused. pad_token_id: int - Value of padding token - used to compute sequence_lengths. If pad_token_id=None, + Value of padding token - used to compute sequence lengths. If pad_token_id=None, we assume that all new_input sequence lengths are equal to the corresponding dimension of new_input. """ @@ -193,11 +202,23 @@ def thd_setup_before_new_input(self, new_input, reset=False, pad_token_id=None): def save_to_kv_cache(self, layer_number, key_layer, value_layer): """ Saves key_layer and value_layer in the cache. + + Parameters + ---------- + layer_number: input + layer number of the current `TransformerLayer` when multiple such modules are + concatenated to form a transformer block. 
+ key_layer: torch.Tensor + Tensor of format corresponding to self.qkv_format with current key_layer. + value_layer: int + Tensor of format corresponding to self.qkv_format with current value_layer. """ (inference_key_memory, inference_value_memory, ) = self.key_value_memory_dict[layer_number] if self.qkv_format == "thd": channels = inference_key_memory.shape[1] * inference_key_memory.shape[2] # h * d + # This kernels copies kernels from input layers into cache, + # taking into account the thd format and sequence lengths. tex.attention_copy( inference_key_memory, self.cached_sequence_lengths, @@ -243,12 +264,24 @@ def allocate_memory_for_kv_cache_if_empty( num_gqa_groups_per_partition, hidden_size_per_attention_head, dtype): + """ + Allocates memory for kv_cache for given layer, if it hasn't been alocated before. + + Parameters + ---------- + layer_number: input + layer number of the current `TransformerLayer` when multiple such modules are + concatenated to form a transformer block. + num_gqa_groups_per_partition: torch.Tensor + This will be third dimension of cache tensor. + hidden_size_per_attention_head: int + This will be fourth dimension of cache tensor. + """ if layer_number in self.key_value_memory_dict: return # Already allocated - s = self.max_sequence_length - b = self.max_batch_size + b, s = self.max_batch_size, self.max_sequence_length def _allocate_memory(dims): return torch.empty( @@ -271,9 +304,31 @@ def _allocate_memory(dims): ) def set_params_to_thd_attention(self, buffers, channels): + """ + Fused attention with q/k/v of thd layout needs some parameters which give information + about sequence lengths. This method computes them and saves them into fiven buffers. + + Parameters + ---------- + buffers: List[torch.Tensor] + buffers of size [batch_size + 1] for the parameters: + cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, + seq_offsets_k, seq_offsets_v, seq_offsets_o + respectively. + channels: int + value of num_heads * hidden_dim_for_each_head. + + Returns + ---------- + max_seqlen_q: int + Maximal value of query sequence length. + max_seqlen_kv: int + Maximal value of key/value sequence length. + buffers: torch.Tensor + Tensor with filled buffers. + """ max_seqlen_q, max_seqlen_kv = self.max_incoming_seq_len, self.max_sequence_length - # Allocation of buffers, works with CUDA Graphs. cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = \ buffers @@ -3826,7 +3881,7 @@ def forward( ) if qkv_format == "thd": - # Allocation of buffers, works with CUDA Graphs. + # Allocation of buffers, it works correctly with CUDA Graphs. buffers = [self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") for _ in range(6)] max_seqlen_q, max_seqlen_kv, buffers = inference_params.set_params_to_thd_attention(buffers, self.channels) From 954257d83ad6331d6156d454b483ca10e0a7f0d6 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 24 May 2024 13:48:18 -0700 Subject: [PATCH 158/244] te_gemma.py refactoring Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 54 +++++++++++++++++++----------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 6264a448fb..86fd7cafed 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -56,7 +56,8 @@ def forward(self, *args, **kwargs): # We need to pass positional encoding. 
keys_to_remove = ["position_ids", "past_key_value", "output_attentions", "use_cache", "cache_position"] for key in keys_to_remove: kwargs.pop(key, None) - return (super().forward(*args, rotary_pos_emb=self.te_rope_emb, **kwargs),) # We need to return tuple to be compatible with HF. + # We need to return tuple to be compatible with HF. + return (super().forward(*args, rotary_pos_emb=self.te_rope_emb, **kwargs),) class StaticGemmaModel(torch.nn.Module): """ @@ -119,14 +120,14 @@ def forward(self, hidden_states : torch.Tensor): logits = logits[:, -1, :] next_tokens = torch.argmax(logits, dim=1) - hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1)) + hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1)) # static copy for CUDA graphs # self.inference_params contains for example kv_cache # This needs to be called before every pass, # to update the information of sequence lengths. # Here we increase sequence offsets by one, # because we generated one token for every sequence. - self.inference_params.thd_setup_before_new_input(next_tokens.unsqueeze(1)) + self.inference_params.thd_setup_before_new_input(next_tokens.unsqueeze(1)) # static copy for CUDA graphs return next_tokens @@ -161,9 +162,9 @@ def __init__(self, config: GemmaConfig): self._model_generation_phase = GemmaGenerator( lm_head=self.lm_head, model=self.model, - dtype=torch.float32, + dtype=torch.bfloat16, ) - self._model_context_phase = StaticGemmaModel(self.model, torch.float32, 'padding_causal', self.lm_head) + self._model_context_phase = StaticGemmaModel(self.model, torch.bfloat16, 'padding_causal', self.lm_head) if self.config.fp8: self.fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max") @@ -200,6 +201,10 @@ def _create_hidden_states_buffer(self, input_ids : torch.Tensor): # This function is overriden in TeGEmmaForCausalLMCudaGraphs. def _create_inference_params(self, max_batch_size : int, max_sequence_length : int): return InferenceParams(max_batch_size, max_sequence_length, qkv_format="thd") + + # This function is overriden in TeGEmmaForCausalLMCudaGraphs. + def _get_max_input_seq_len(self, input_ids): + return input_ids.shape[1] # The buffer for generation is some part (beginning) of hidden states buffer. # This function returns pointer to it and also copies there data if provided. @@ -214,7 +219,6 @@ def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None): generation_buffer = output.view((hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2])) return generation_buffer - def _generate_context_phase( self, input_ids : torch.Tensor, @@ -231,7 +235,7 @@ def _generate_context_phase( # We choose logits coresponding with last token in each sequence, # which have various lengths - they are stored in (inference_params.incoming_seq_len - 1) Tensor. - logits = logits[torch.arange(logits.size(0)), inference_params.incoming_seq_len - 1, :] + logits = logits[torch.arange(logits.size(0)), inference_params.input_sequence_lengths - 1, :] next_tokens = torch.argmax(logits, dim=1) # self.hidden_states have shape [b, s, hd]. 
@@ -239,9 +243,6 @@ def _generate_context_phase( hidden_states = self._get_generation_buffer(hidden_states, self.model.embed_tokens(next_tokens)) return hidden_states, next_tokens - def _get_max_input_seq_len(self, input_ids): - return input_ids.shape[1] - @torch.no_grad() def generate( self, @@ -252,8 +253,13 @@ def generate( ): self.eval() assert self.config.qkv_format == "thd", "Generation using other qkv_layouts than thd is not provided in this tutorial" + + # We need both autocasts: FP8 for operations that can run in lower precision + # and BF16 for those that cannot. with autocast(dtype=torch.bfloat16, cache_enabled=False), \ - te.pytorch.fp8_autocast(enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None): + te.pytorch.fp8_autocast( + enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None): + batch_size, max_input_sequence_len = input_ids.shape[0], self._get_max_input_seq_len(input_ids) lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # [s] input_ids = F.pad(input_ids, (max_input_sequence_len - input_ids.shape[1], 0), "constant", 0) @@ -280,10 +286,10 @@ def generate( inference_params.thd_setup_before_new_input(next_tokens.unsqueeze(1)) output_tokens = [next_tokens] - for _ in range(max_new_tokens): next_tokens = self._model_generation_phase(hidden_states) - output_tokens.append(next_tokens.clone()) + # next_tokens is static output tensor, so we need to clone it - it gets changed every iteration. + output_tokens.append(next_tokens.clone()) result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) return result @@ -300,29 +306,34 @@ def __init__(self, config : GemmaConfig): self.config = config self.hidden_states_buffer = torch.empty( (config.cuda_graphs_static_batch_size, config.cuda_graphs_static_max_context_len, config.hidden_size)).cuda() - self.generation_buffer = self._get_generation_buffer(self.hidden_states_buffer) # in fact part of the buffer for hidden_states + # This is in fact part of the buffer for hidden_states. + self.generation_buffer = self._get_generation_buffer(self.hidden_states_buffer) self.inference_params = InferenceParams( - max_batch_size=config.cuda_graphs_static_batch_size, max_sequence_length=config.cuda_graphs_static_max_seq_len, qkv_format="thd") + max_batch_size=config.cuda_graphs_static_batch_size, + max_sequence_length=config.cuda_graphs_static_max_seq_len, + qkv_format="thd" + ) self._model_generation_phase.set_inference_params(self.inference_params) self._model_context_phase.set_inference_params(self.inference_params) def record(self): - self.eval() + self.eval() # We want to record model in training=False, because it will be used in generation. + # Here "the trick" happens. We override methods from TEGemmaForCausalLM # with their recorded version. After invocation of each of them, # captured graph will be replayed with minimal usage of CPU, # what will lead to huge speedup. 
- - input_shape = (self.config.cuda_graphs_static_batch_size, self.config.cuda_graphs_static_max_context_len) self.inference_params.thd_setup_before_new_input(torch.randn(input_shape), reset=True) - self._model_context_phase = self.record_graph(self._model_context_phase, self.hidden_states_buffer) # CUDA Graphs recording + self._model_context_phase = self.record_graph( + self._model_context_phase, self.hidden_states_buffer) # CUDA Graphs recording input_shape = torch.randn((self.config.cuda_graphs_static_batch_size, 1)) self.inference_params.thd_setup_before_new_input(input_shape, reset=True) - self._model_generation_phase = self.record_graph(self._model_generation_phase, self.generation_buffer) # CUDA Graphs recording + self._model_generation_phase = self.record_graph( + self._model_generation_phase, self.generation_buffer) # CUDA Graphs recording """ Functions _create_hidden_states_buffer and _create_inference_params from base class are overriden @@ -344,6 +355,9 @@ def record_graph(self, function, input_tensor): # record_graph() returns captured function, which can be run later with minimal use of th CPU. fp8_format = Format.HYBRID fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=1024, amax_compute_algo="max") + + # We need both autocasts: FP8 for operations that can run in lower precision + # and BF16 for those that cannot. with autocast(dtype=torch.bfloat16, cache_enabled=False): graphed_function = te.pytorch.make_graphed_callables( function, From 6e35fcb6841e7afe74e68413047b961f52633b30 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 28 May 2024 11:28:54 -0700 Subject: [PATCH 159/244] Not THD attention generation Signed-off-by: Pawel Gadzinski --- docs/examples/te_gemma/te_gemma.py | 69 +++++++++++++++++-------- transformer_engine/pytorch/attention.py | 64 +++++++++++++---------- 2 files changed, 86 insertions(+), 47 deletions(-) diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 86fd7cafed..baa037dd28 100644 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -80,13 +80,13 @@ def __init__( def set_inference_params(self, inference_params): self.inference_params = inference_params - def forward(self, hidden_states : torch.Tensor): + def forward(self, hidden_states : torch.Tensor, attention_mask : torch.Tensor = None): with torch.no_grad(): hidden_states.data[:] = hidden_states.data[:] * self.normalizer # static operation - for CUDA graphs for decoder_layer in self.model.layers: hidden_states.data[:] = decoder_layer( hidden_states, - attention_mask=None, + attention_mask=attention_mask, self_attn_mask_type=self.mask, inference_params=self.inference_params )[0] # static copy - for CUDA graphs @@ -102,17 +102,18 @@ class GemmaGenerator(torch.nn.Module): GemmaGenerator gets one layer of embeddins, makes forward pass and returns next tokens. 
""" - def __init__(self, model : GemmaModel, lm_head: torch.nn.Module, dtype : torch.dtype): + def __init__(self, model : GemmaModel, lm_head: torch.nn.Module, dtype : torch.dtype, qkv_format : str): super().__init__() self.model = model self.gemma_layers = StaticGemmaModel(model, dtype, 'padding', lm_head) + self.qkv_format = qkv_format def set_inference_params(self, inference_params): self.inference_params = inference_params self.gemma_layers.set_inference_params(inference_params) - def forward(self, hidden_states : torch.Tensor): - logits = self.gemma_layers(hidden_states) + def forward(self, hidden_states : torch.Tensor, mask : torch.Tensor = None): + logits = self.gemma_layers(hidden_states, attention_mask=mask) assert logits.shape[0] == hidden_states.shape[0] # b assert logits.shape[1] == hidden_states.shape[1] # seq_len @@ -127,7 +128,7 @@ def forward(self, hidden_states : torch.Tensor): # to update the information of sequence lengths. # Here we increase sequence offsets by one, # because we generated one token for every sequence. - self.inference_params.thd_setup_before_new_input(next_tokens.unsqueeze(1)) # static copy for CUDA graphs + self.inference_params.setup_before_new_input(next_tokens.unsqueeze(1)) return next_tokens @@ -163,6 +164,7 @@ def __init__(self, config: GemmaConfig): lm_head=self.lm_head, model=self.model, dtype=torch.bfloat16, + qkv_format=config.qkv_format ) self._model_context_phase = StaticGemmaModel(self.model, torch.bfloat16, 'padding_causal', self.lm_head) @@ -200,7 +202,7 @@ def _create_hidden_states_buffer(self, input_ids : torch.Tensor): # This function is overriden in TeGEmmaForCausalLMCudaGraphs. def _create_inference_params(self, max_batch_size : int, max_sequence_length : int): - return InferenceParams(max_batch_size, max_sequence_length, qkv_format="thd") + return InferenceParams(max_batch_size, max_sequence_length, qkv_format=self.config.qkv_format) # This function is overriden in TeGEmmaForCausalLMCudaGraphs. def _get_max_input_seq_len(self, input_ids): @@ -226,16 +228,24 @@ def _generate_context_phase( ): hidden_states = self._create_hidden_states_buffer(input_ids) hidden_states.data[:] = self.model.embed_tokens(input_ids) - - # We need to update offsets before every forward pass to make cache work properly. - inference_params.thd_setup_before_new_input(input_ids, pad_token_id=0, reset=True) + + # We need to update offsets before every forward pass to make cache work properly. + inference_params.setup_before_new_input(input_ids, pad_token_id=0, reset=True) + hidden_states.data[:] = self.model.embed_tokens(input_ids) - logits = self._model_context_phase(hidden_states) + logits = self._model_context_phase( + hidden_states, + attention_mask=((input_ids == 0) if self.config.qkv_format != "thd" else None) + ) # We choose logits coresponding with last token in each sequence, - # which have various lengths - they are stored in (inference_params.incoming_seq_len - 1) Tensor. - logits = logits[torch.arange(logits.size(0)), inference_params.input_sequence_lengths - 1, :] + # which have various lengths - they are stored in (inference_params.incoming_seq_len - 1) Tensor + # when qkv_format == "thd" and they are the last token in the sequence when qkv_format != "thd". + if self.config.qkv_format == "thd": + logits = logits[torch.arange(logits.size(0)), inference_params.input_sequence_lengths - 1, :] + else: + logits = logits[:, -1, :] next_tokens = torch.argmax(logits, dim=1) # self.hidden_states have shape [b, s, hd]. 
@@ -243,6 +253,12 @@ def _generate_context_phase( hidden_states = self._get_generation_buffer(hidden_states, self.model.embed_tokens(next_tokens)) return hidden_states, next_tokens + def _make_mask_one_token_longer(self, mask): + return torch.cat( + [mask, torch.zeros(mask.size(0), 1, 1, 1, dtype=torch.bool, device=mask.device)], + dim=-1 + ) + @torch.no_grad() def generate( self, @@ -252,7 +268,6 @@ def generate( *args, **kwargs ): self.eval() - assert self.config.qkv_format == "thd", "Generation using other qkv_layouts than thd is not provided in this tutorial" # We need both autocasts: FP8 for operations that can run in lower precision # and BF16 for those that cannot. @@ -274,8 +289,9 @@ def generate( self._model_context_phase.set_inference_params(inference_params) self._model_generation_phase.set_inference_params(inference_params) - # Context phase - TEGemmaForCausalLM._padding_to_end(input_ids, lengths) + if self.config.qkv_format == "thd": + # For thd layout padding is at the end, otherwise at the beginning. + TEGemmaForCausalLM._padding_to_end(input_ids, lengths) hidden_states, next_tokens = self._generate_context_phase( input_ids, @@ -283,13 +299,22 @@ def generate( ) # Generation phase. - inference_params.thd_setup_before_new_input(next_tokens.unsqueeze(1)) + + inference_params.setup_before_new_input(next_tokens.unsqueeze(1)) + output_tokens = [next_tokens] + if self.config.qkv_format != "thd": + mask = (input_ids == 0).unsqueeze(1).unsqueeze(1) + for _ in range(max_new_tokens): - next_tokens = self._model_generation_phase(hidden_states) + if self.config.qkv_format != "thd": + # It will not work with cuda graphs, but it is not used for thd qkv_format. + mask = self._make_mask_one_token_longer(mask) + + next_tokens = self._model_generation_phase(hidden_states, mask) # next_tokens is static output tensor, so we need to clone it - it gets changed every iteration. - output_tokens.append(next_tokens.clone()) + output_tokens.append(next_tokens.clone()) result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) return result @@ -302,6 +327,8 @@ class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM): """ def __init__(self, config : GemmaConfig): super().__init__(config) + assert config.qkv_format == "thd", "Generation with CUDA Graphs are implemented only for thd format." + # Preparation of the static buffers. self.config = config self.hidden_states_buffer = torch.empty( @@ -326,12 +353,12 @@ def record(self): # captured graph will be replayed with minimal usage of CPU, # what will lead to huge speedup. 
input_shape = (self.config.cuda_graphs_static_batch_size, self.config.cuda_graphs_static_max_context_len) - self.inference_params.thd_setup_before_new_input(torch.randn(input_shape), reset=True) + self.inference_params.setup_before_new_input(torch.randn(input_shape), reset=True) self._model_context_phase = self.record_graph( self._model_context_phase, self.hidden_states_buffer) # CUDA Graphs recording input_shape = torch.randn((self.config.cuda_graphs_static_batch_size, 1)) - self.inference_params.thd_setup_before_new_input(input_shape, reset=True) + self.inference_params.setup_before_new_input(input_shape, reset=True) self._model_generation_phase = self.record_graph( self._model_generation_phase, self.generation_buffer) # CUDA Graphs recording diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index e31c0c716d..6dc332b801 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -121,7 +121,7 @@ class InferenceParams: # pylint: disable=too-few-public-methods """ def __init__(self, max_batch_size, max_sequence_length, qkv_format="bshd"): - assert qkv_format in ["bsdh", "sbhd", "thd"] + assert qkv_format in ["bshd", "sbhd", "thd"] self.max_sequence_length = max_sequence_length self.max_batch_size = max_batch_size @@ -142,6 +142,7 @@ def __init__(self, max_batch_size, max_sequence_length, qkv_format="bshd"): else: self.sequence_len_offset = 0 self.batch_size_offset = 0 + self.input_sequence_length = None def swap_key_value_dict(self, batch_indices): """ @@ -169,7 +170,7 @@ def swap_key_value_dict(self, batch_indices): ) - def thd_setup_before_new_input(self, new_input, reset=False, pad_token_id=None): + def setup_before_new_input(self, new_input, reset=False, pad_token_id=None): """ Updates parameters representing incoming sequence lengths and lengths of sequence in the cache. Should be called before every forward pass in inference. @@ -187,17 +188,21 @@ def thd_setup_before_new_input(self, new_input, reset=False, pad_token_id=None): we assume that all new_input sequence lengths are equal to the corresponding dimension of new_input. 
""" - assert self.qkv_format == "thd" + if self.qkv_format == "thd": + self.cached_sequence_lengths.copy_(self.cached_sequence_lengths + self.input_sequence_lengths) + if pad_token_id is not None: + self.input_sequence_lengths.copy_(torch.sum(new_input.ne(pad_token_id), dim=-1, dtype=torch.int32).squeeze()) + else: + self.input_sequence_lengths.copy_(torch.ones(new_input.shape[0], device="cuda") * new_input.shape[1]) + self.max_incoming_seq_len = new_input.shape[1] - self.cached_sequence_lengths.copy_(self.cached_sequence_lengths + self.input_sequence_lengths) - if pad_token_id is not None: - self.input_sequence_lengths.copy_(torch.sum(new_input.ne(pad_token_id), dim=-1, dtype=torch.int32).squeeze()) + if reset: + self.cached_sequence_lengths.copy_(torch.zeros_like(self.cached_sequence_lengths)) else: - self.input_sequence_lengths.copy_(torch.ones(new_input.shape[0], device="cuda") * new_input.shape[1]) - self.max_incoming_seq_len = new_input.shape[1] + if self.input_sequence_length is not None: + self.sequence_len_offset += self.input_sequence_length + self.input_sequence_length = new_input.shape[1] - if reset: - self.cached_sequence_lengths.copy_(torch.zeros_like(self.cached_sequence_lengths)) def save_to_kv_cache(self, layer_number, key_layer, value_layer): """ @@ -1606,21 +1611,24 @@ def forward( freqs: torch.Tensor, tensor_format: str = "sbhd", cu_seqlens: Union[torch.Tensor, None] = None, - begins: Union[torch.Tensor, None] = None, + beginning_offsets: Union[torch.Tensor, None] = None, ) -> torch.Tensor: - if begins is None: - begins = torch.Tensor() + if beginning_offsets is None: + # Each sequence will start from positional encoding corresponding to 0. + # Otherwise sequence i will start from positional encoding + # corresponding to beginning_offsets[i]. + beginning_offsets = torch.Tensor() if tensor_format == "sbhd": - output = tex.fused_rope_forward(t, freqs, begins, False) + output = tex.fused_rope_forward(t, freqs, beginning_offsets, False) elif tensor_format == "bshd": output = tex.fused_rope_forward( - t.transpose(0, 1), freqs, begins, True + t.transpose(0, 1), freqs, beginning_offsets, True ).transpose(0, 1) elif tensor_format == "thd": - output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, begins) + output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, beginning_offsets) else: raise ValueError(f"Unsupported tensor_format: {tensor_format}.") - ctx.save_for_backward(freqs, cu_seqlens, begins) + ctx.save_for_backward(freqs, cu_seqlens, beginning_offsets) ctx.tensor_format = tensor_format return output @@ -3884,7 +3892,8 @@ def forward( # Allocation of buffers, it works correctly with CUDA Graphs. 
buffers = [self.alloc(batch_size + 1, dtype=torch.int32, device="cuda") for _ in range(6)] - max_seqlen_q, max_seqlen_kv, buffers = inference_params.set_params_to_thd_attention(buffers, self.channels) + max_seqlen_q, max_seqlen_kv, buffers = \ + inference_params.set_params_to_thd_attention(buffers, self.channels) cu_seqlens_q, cu_seqlens_kv, seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = \ buffers @@ -4139,7 +4148,10 @@ def forward( if self.qkv_format == "thd": use_flash_attention = False use_fused_attention = True - fused_attention_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + + if self.qkv_format == "bshd" and query_layer.shape[1] != value_layer.shape[1]: + use_flash_attention = False # Flash attention does not support max_seqlen_q != max_seqlen_kv + if use_flash_attention: if _NVTE_DEBUG: @@ -4750,13 +4762,13 @@ def forward( # Pre-allocate memory for key-values for inference # ================================================= - - inference_params.allocate_memory_for_kv_cache_if_empty( - self.layer_number, - self.num_gqa_groups_per_partition, - self.hidden_size_per_attention_head, - hidden_states.dtype - ) + if inference_params is not None: + inference_params.allocate_memory_for_kv_cache_if_empty( + self.layer_number, + self.num_gqa_groups_per_partition, + self.hidden_size_per_attention_head, + hidden_states.dtype + ) # ====================== # Query, Key, and Value From 4a2a936a88ff333beabfe2a42748df963906b3e5 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 29 May 2024 10:30:33 -0700 Subject: [PATCH 160/244] Tutorial fixes Signed-off-by: Pawel Gadzinski --- ...celerate_hf_gemma_finetuning_with_te.ipynb | 299 ++++++++++++++++++ .../tutorial_generation_gemma_with_te.ipynb | 210 +++++++----- 2 files changed, 434 insertions(+), 75 deletions(-) create mode 100644 docs/examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb diff --git a/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb new file mode 100644 index 0000000000..dcdd28c30a --- /dev/null +++ b/docs/examples/te_gemma/tutorial_accelerate_hf_gemma_finetuning_with_te.ipynb @@ -0,0 +1,299 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Accelerating a Hugging Face Gemma model with Transformer Engine" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the previous [tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), we demonstrated how to accelerate HF Llama models using the Transformer Engine library. We replaced `GemmaDecoderLayer` with `TransformerLayer` from the Transformer Engine, achieving a **25%** speedup. Furthermore, we conducted the finetuning in FP8 precision, which yielded an additional **39%** speedup from the baseline.\n", + "\n", + "Now, we will undertake a similar enhancement for the Google's [Gemma](https://blog.google/technology/developers/gemma-open-models/) model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Dependencies for this tutorial\n", + "\n", + "Following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_gemma.py`\n", + " - This file contains the code to load a Hugging Face Gemma checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n", + "2. 
`utils.py`\n",
+    "   - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n",
+    "3. `media/`\n",
+    "   - This directory contains the images used in the following tutorial."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Differences between Llama and Gemma"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The Llama and Gemma models are very similar - both are based on the Transformer decoder architecture. The most important architectural differences between them are the following:\n",
+    "\n",
+    "\n",
+    "| Feature | Llama | Gemma |\n",
+    "|----------------------------------------------|------------------------------------|--------------------------------------------|\n",
+    "| **Norm Layer** | Standard RMSNorm
$ y = \\frac{x}{ \\sqrt{\\mathrm{E}[x^2] + \\varepsilon}} * \\gamma $ | RMSNorm with zero-centered gamma parameter
$ y = \\frac{x}{ \\sqrt{\\mathrm{E}[x^2] + \\varepsilon}} * (\\textcolor{red}{1 +} \\gamma) $ |\n",
+    "| **Embedding Dimension/Head Dimension** | 4096/4096 | 3072/4096 |\n",
+    "| **Activation Function** | SwiGLU | GeGLU |\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# [Baseline] Running HF `GemmaModel` (Precision: `BF16`)\n",
+    "\n",
+    "Similarly to the Llama tutorial, we begin the experiments by running baseline Hugging Face Gemma model finetuning in BF16 precision.\n",
+    "\n",
+    "

\n", + "\n", + "Note\n", + " \n", + "This tutorial loads and trains a Gemma 7B model which takes up most of the GPU memory and therefore, we need to restart the jupyter notebook each time before running the following sections. A small utility method `restart_jupyter_notebook` is defined in the accompanying `utils.py` file. This function restarts the jupyter notebook so that the GPU memory is flushed before the model is loaded again from the checkpoint in order to avoid running into OOM (Out Of Memory) errors.\n", + "\n", + "If the utility doesn't work, comment this line `restart_jupyter_notebook()` in the following cell and manually restart the jupyter notebook before running the cell. Repeat the same for other sections in this tutorial.\n", + "\n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "298 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_baseline_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# [Improvement 1] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", + "\n", + "We replace *GemmaDecoderLayer* with the highly tuned *TransformerLayer*, similarly to our approach in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb). Let's observe the impact this change has on the model's speed." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "257 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"bf16\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `GemmaDecoderLayer` gives a speedup of **16%** even when using only BF16 precision!\n", + "\n", + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 257 | 1.16 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# [Improvement 2] Replace HF's `GemmaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", + "\n", + "The last improvement is about enabling FP8 precision. Let's see how it works." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10 finetuning steps complete!\n", + "\n", + "Average time taken per step: \n", + "214 \n", + "milliseconds\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "#restart_jupyter_notebook()\n", + "\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "\n", + "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n", + "## !!! `model_name` attr must point to the location of the model weights !!!\n", + "## Weights can be downloaded from: https://huggingface.co/google/gemma-7b\n", + "hyperparams.model_name = \"../../../../gemma-7b\" # <== Add model weight location here e.g. \"/path/to/downloaded/gemma/weights\"\n", + "hyperparams.mixed_precision = \"fp8\"\n", + "\n", + "\n", + "# Init the model and accelerator wrapper\n", + "model = init_te_gemma_model(hyperparams).cuda()\n", + "accelerator, model, optimizer, train_dataloader, lr_scheduler = wrap_with_accelerator(model, hyperparams)\n", + "\n", + "\n", + "# Finetune the model\n", + "finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", + "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", + "| HF (baseline) | BF16 | 298 | 1 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 257 | 1.16 |\n", + "| TE (replace `GemmaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 214 | 1.39 |\n", + "\n", + "\n", + "After turning on FP8 precision, we get even more speedup of almost **39%**!" 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Conclusion\n", + "\n", + "As shown in the [Llama tutorial](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb), using the `TransformerLayer` module from Transformer Engine to replace Hugging Face's `GemmaDecoderLayer` results in a speedup compared to Hugging Face's native Gemma implementation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## See more\n", + "\n", + "We also prepared [tutorial](./tutorial_generation_gemma_with_te.ipynb) in which we will show how to speedup the Gemma model generation using Transformer Engine." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index ce8f301ddc..35afbd2447 100644 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -2,81 +2,79 @@ "cells": [ { "cell_type": "markdown", - "id": "8581f0e4", "metadata": {}, "source": [ - "# Accelerating Generation of the Hugging Face Gemma Model with Transformer Engine\n", + "# Accelerating token generation of the Hugging Face Gemma Model with Transformer Engine\n", "\n", "Generative AI has made remarkable strides in recent years, with Large Language Models (LLMs) like ChatGPT at the forefront. These models have revolutionized how we interact with machine-generated content, providing capabilities that range from writing assistance to complex decision support. The core functionality of these models is the generation process, which involves predicting the next token in a sequence based on the preceding text. This task is critical for applications such as automated content creation, translation, and more, emphasizing the importance of efficient implementation.\n", "\n", + "\n", + "\n", + "
\n", + "\"\"
\n", + "Animation 1. Hugging Face Gemma model token generation.\n", + "
\n", + "\n", "For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n", "\n", - "In our previous tutorials on [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) and [Gemma](./tutorial_accelerate_hf_gemma_with_te.ipynb), we demonstrated how finetuning can be accelerated using the Transformer Engine's `TransformerLayer`. Building on this foundation, our current objective is to enhance the generation speed of the Gemma model.\n", + "In the previous tutorials on [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) and [Gemma](./tutorial_accelerate_hf_gemma_with_te.ipynb), it was demonstrated how finetuning can be accelerated using the Transformer Engine's `TransformerLayer`. Building on this foundation, the current objective is to enhance the generation speed of the Gemma model.\n", "\n", "This tutorial will introduce and explain several advanced features of the Transformer Engine that contribute to this goal:\n", "\n", "##### 1. THD Attention Layout.\n", "\n", - "Addressing the challenge of computing attention for sequences with varying lengths, a common method is to pad these sequences and apply an attention mask. The Transformer Engine, however, offers a more optimized approach—by specifying the lengths and offsets of the sequences, attention can be computed directly. Instead of passing the matrix and mask with the shape `[b, s, h, d]`, one can pass a matrix of the shape `[t, h, d]` along with tensors detailing sequence lengths and offsets to run the attention optimized for this case. This specific attention layout is referred to as the **THD layout**.\n", + "Addressing the challenge of computing attention for sequences with varying lengths, a common method is to pad these sequences and apply an attention mask. The Transformer Engine, however, offers a more optimized approach—by specifying the lengths and offsets of the sequences, attention can be computed directly. Instead of passing the matrix and mask with the shape `[b, s, h, d]`, one can pass a matrix of the shape `[t, h, d]` along with tensors detailing cumulative sequence lengths and offsets to run the attention optimized for this case. This specific attention layout is referred to as the **THD layout**.\n", "\n", "
\n", - "\"\"
\n", - "Fig. 1. The sequences and the mask for standard attention layout - padding from the end.

\n", - "\"\"
\n", - "Fig. 2. The sequences and the mask for standard attention layout - padding from the beginning.

\n", - "\"\"
\n", - "Fig. 3. An attention with thd layer.

\n", + "\"\"
\n", + "Fig. 1. The difference between BSDH (default) and THD attention layouts is as follows: with BSDH, we need to provide the attention mask, while with THD, we need to provide cumulative sequence lengths and sequence offsets.

\n", "
\n", "\n", "##### 2. CUDA Graphs API.\n", "\n", - "The speed of GPUs is increasing at a rapid pace. It turns out that sometimes the runtime of kernels is shorter than the time it takes for the CPU to submit them, which can lead to significant overhead. CUDA Graphs were developed to address this issue. When certain kernels are executed repeatedly, this tool allows us to record and replay them without CPU involvement. This becomes particularly useful in applications like text generation, where a `TransformerLayer` is run for every token that needs to be generated.\n", + "The speed of GPUs is increasing at a rapid pace. It turns out that sometimes the runtime of kernels is shorter than the time it takes for the CPU to submit them, which can lead to significant overhead. CUDA Graphs can address this issue. When certain kernels are executed repeatedly, it allows us to record and replay them without less CPU involvement. This becomes particularly useful in applications like token generation, where a `TransformerLayer` is run for every token that needs to be generated.\n", "\n", "We recommend reading further about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n", "\n", - "PyTorch exposes graphs via a raw `torch.cuda.CUDAGraphclass` and two convenience wrappers, `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. More information about the cuda graphs in Pytorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).\n", - "\n", - "Transformer Engine supports cuda graphs from version 1.5.\n", - "\n", + "PyTorch exposes graphs via a raw `torch.cuda.CUDAGraph` class and two convenience wrappers, `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. More information about the cuda graphs in Pytorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).\n", "\n", "
\n", "\"\"
\n", - "Fig. 4. CUDA Graphs speedup.

\n", + "Fig. 2. CUDA Graphs allow us to reduce the overhead generated by the long time it takes to launch a single kernel. They enable the recording and replaying of subsequent launches, thus reducing the total time used by the CPU.

\n", "
\n", "\n", "\n", "##### 3. FP8 Weights Calibration.\n", "\n", - "Assuming that we have a model trained in FP32/BF16 precision and we wish to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, we can compute the FP8 saling parameters. This calibration allows the model to operate correctly in FP8 precision.\n", - "\n", - "We highly recommend familiarizing yourself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n", + "Assuming that the model is trained in FP32/BF16 precision and the goal is to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, the FP8 scaling parameters can be computed. This calibration allows the model to operate correctly in FP8 precision.\n", "\n", + "It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n", "\n", "
\n", "\"\"
\n", - "Fig. 5. The weights calibration.

\n", + "Fig. 3. \n", + "If the model is trained in BF16/FP32, it does not include the computed FP8 scaling factors. When it is run under fp8_autocast(), the value of these scaling factors will default to their initial values, which can cause numerical errors. Weight calibration involves calculating FP8 scaling factors from higher precision forward passes. Once these factors are computed, the numerical errors should be resolved.

\n", "
\n", "\n", "##### 4. FP8 Model Weights.\n", "\n", - "The typical approach is to store weights in higher precision and then cast them to fp8 before operations. This is critical during training, as it allows us to store some values in high precision to avoid performance drops. However, for inference, this level of precision is not necessary.\n", - "\n", - "The TransformerEngine includes a feature called `fp8_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast from higher precision to BF16, saving time on this casting process. Additionally, it helps reduce memory consumption, which can be used to increase the batch size, resulting in even greater speedup.\n", + "The typical approach is to store weights in higher precision and then cast them to fp8 before operations. This may prevent accuraccy drops in training. However, for inference, this level of precision is not necessary.\n", "\n", + "The TransformerEngine includes a wrapper `fp8_model_​init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast from higher precision to BF16, saving time in this casting process. \n", "\n", "
\n", "\"\"
\n", - "Fig. 6. Saving memory with fp8_model_init().

\n", + "Fig. 6. Model under fp8_autocast() stores weights in high precision by default, and casts them if needed. It can leads to slowdown and increased memory usage. Using fp8_model_init() results in storing weight in FP8.

\n", "
\n", "\n", "#### Benchmarking\n", "\n", - "We'll evaluate the generation time across one benchmark: generation with context phase max sequence length = 128, batch size = 64 and number of generated tokens = 1024 - 128.\n", + "We'll evaluate the generation time across one benchmark: generation with context phase max sequence length = 128, batch size = 64 and number of generated tokens = 896 on random texts with random lengths.\n", "\n", "
\n", "Note\n", " \n", - "This tutorial focuses on showcasing the mentioned features of Transformer Engine in the context of generation. It's important to note, however, that NVIDIA provides another library, [TensorRT](https://developer.nvidia.com/tensorrt), which is optimized for inference tasks and should be considered for such use cases.\n", + "This tutorial focuses on showcasing the mentioned features of Transformer Engine in the context of generation. It's important to note, however, that NVIDIA provides [TensorRT](https://developer.nvidia.com/tensorrt), which is optimized for inference tasks and should be considered for such use cases.\n", "
" ] }, @@ -101,10 +99,22 @@ " - This file contains logic of mapping the parameters from `GemmaDecoderLayer` into the `TransformerLayer`.\n", "3. `utils.py`\n", " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n", + "4. `requirements.txt`\n", + " - Contains necessary Python packages for this tutorial\n", "4. `media/`\n", " - This directory contains the images used in the following tutorial." ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "31390c76", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -r requirements.tst" + ] + }, { "cell_type": "markdown", "id": "e8dfabbf", @@ -118,7 +128,8 @@ "id": "59560bff", "metadata": {}, "source": [ - "HuggingFace Transformers library offers generation API. We will use HuggingFace generation for the Gemma model as our baseline." + "HuggingFace Transformers library offers generation API. \n", + "We will use HuggingFace generation for the Gemma model as our baseline." ] }, { @@ -180,14 +191,76 @@ }, { "cell_type": "markdown", - "id": "2bbf3d47", + "id": "8bb40f45", + "metadata": {}, + "source": [ + "## [Iprovement 1] Using TransformerLayer from Transformer Engine instead of GemmaDecoderLayer." + ] + }, + { + "cell_type": "markdown", + "id": "fecde0c0", + "metadata": {}, + "source": [ + "
\n", + "\"\"\n", + "Fig. Each GemmaDecoderLayer is substituted by a tuned TransformerLayer from the Transformer Engine.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "263b40f2", + "metadata": {}, + "source": [ + "As in the [Gemma](./tutorial_accelerate_hf_gemma_with_te.ipynb) finetuning tutorial, we substitute GemmaDecoderLayer by a tuned TransformerLayer from the Transformer Engine." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9dceef93", "metadata": {}, + "outputs": [], "source": [ - "## [Improvement 1] Speeding up generation by using Transformer Engine with THD attention\n", + "from utils import *\n", "\n", - "Similarly to the Gemma tutorial, we substitute `GemmaDecoderLayer` with `TransformerLayer` from Transformer Engine. \n", + "hyperparams.model_name = \"\"\n", "\n", - "Input sequences can have various lengths. The most common approach is to use the padding and attention masks in such situation. We will use more straightforward method - using the THD attention layout with offests. \n", + "model = init_te_gemma_model(hyperparams)\n", + "\n", + "print_sample_of_generated_texts(model)\n", + "benchmark_generation(model)" + ] + }, + { + "cell_type": "markdown", + "id": "b5d40836", + "metadata": {}, + "source": [ + "We have obtained speedup of **x%**." + ] + }, + { + "cell_type": "markdown", + "id": "006d18e8", + "metadata": {}, + "source": [ + "\n", + "| Models | Time (s) | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 82.04 | 1 |\n", + "| TE | | | " + ] + }, + { + "cell_type": "markdown", + "id": "2bbf3d47", + "metadata": {}, + "source": [ + "## [Improvement 2] Use of THD attention layout.\n", + "\n", + "Input sequences can have various lengths. Hugging Face generation - as can be seen in Animation 1 - pads the sequences and then uses attention mask. The THD attention layout is faster, but less flexible. Instead of attention mask, cumulative sequence lengths and offsets need to be provided.\n", "\n", "
\n", "\n", @@ -205,10 +278,12 @@ "seq_offsets_q = [0, 5, 10, 15, 20, 25] * h * d
\n", "seq_offsets_k = [0, 7, 14, 21, 28, 35] * h * d
\n", "seq_offsets_v = [0, 7, 14, 21, 28, 35] * h * d
\n", + "

\n", + "Fig. Example of arguments related to THD attention layout that need to be passed to transformer_engine.pytorch.DotProductAttention().\n", "
\n", "\n", - "The class `transformer_engine.DotProductAttention` supports this format. One need to pass the following things as the arguments to the forward:\n", - "- `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` – which represents the offsets of the beginnings of the next sequences,\n", + "The class `transformer_engine.pytorch.DotProductAttention` supports this format. One need to pass the following things as the arguments to the forward:\n", + "- `seq_offsets_q`, `seq_offsets_k`, `seq_offsets_v` – which represent the offsets of the beginnings of the next sequences,\n", "- `cu_seqlens_q`, `cu_seqlens_kv` – cumulative sum of the lengths of the sequences of query and values,\n", "- `max_seqlen_q` – maximum sequence length in query layer,\n", "- `max_seqlen_kv` – maximum sequence length in key-value layer.\n", @@ -216,10 +291,10 @@ "
\n", "\n", "Note\n", - "Currently, the THD attention for `TransformerLayer` is supported only for inference.\n", + "Currently, the THD attention for `TransformerLayer` is supported only for token generation.\n", "
\n", "\n", - "Let's look how using TransformerEngine with THD attention impacts the speed of generation:" + "Let's look how using TransformerEngine with THD attention impacts the speed of token generation:" ] }, { @@ -265,14 +340,12 @@ "from utils import restart_jupyter_notebook\n", "restart_jupyter_notebook()\n", "\n", - "# Import necessary packages and methods\n", "from utils import *\n", "\n", - "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n", + "hyperparams.model_name = \"\" # <== Add model weight location here.\n", "hyperparams.qkv_format = \"thd\"\n", "\n", - "# Init the model and accelerator wrapper\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", + "model = init_te_gemma_model(hyperparams)\n", "\n", "print_sample_of_generated_texts(model)\n", "benchmark_generation(model)" @@ -296,7 +369,7 @@ "id": "21a89d9c", "metadata": {}, "source": [ - "## [Improvement 2] Speeding up generation with CUDA Graphs" + "## [Improvement 3] Speeding up generation with CUDA Graphs" ] }, { @@ -337,7 +410,7 @@ " return graphed_function\n", "```\n", "\n", - "We strongly recommend reviewing the entire code of the class `TEGemmaForCausalLMCudaGraphs`. Let us now proceed to evaluate the performance improvement offered by CUDA Graphs." + "It is strongly reccomended to review the entire code of the class `TEGemmaForCausalLMCudaGraphs`. Let's now proceed to evaluate the performance improvement offered by CUDA Graphs." ] }, { @@ -390,7 +463,7 @@ "hyperparams.cuda_graphs_static_batch_size = 64\n", "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", "hyperparams.cuda_graphs_static_max_context_len = 128\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", + "model = init_te_gemma_model(hyperparams)\n", "\n", "print_sample_of_generated_texts(model)\n", "benchmark_generation(model)" @@ -421,10 +494,10 @@ "
\n", " \n", "\"\"
\n", - " Fig. 7. Without CUDA Graphs. We can see that GPU(blue) is idle for most of the time.\n", + " Fig. 7. Without CUDA Graphs. We can see that GPU (blue) is idle for most of the time.\n", "


\n", "\"\"
\n", - " Fig. 8. With CUDA Graphs. We can see that GPU(orange) is utilized.\n", + " Fig. 8. With CUDA Graphs. We can see that GPU (orange) is utilized.\n", "
\n", "
" ] @@ -434,7 +507,7 @@ "id": "e6b171a0", "metadata": {}, "source": [ - "## [Improvement 3] Running generation in FP8 of the model trained in higher precision " + "## [Improvement 4] Running generation in FP8 of the model trained in higher precision " ] }, { @@ -442,15 +515,15 @@ "id": "1a80288b", "metadata": {}, "source": [ - "We are now preparing to execute FP8 generation using the Gemma model. However, this process is not straightforward. Since the model was originally trained with BF16 precision, the FP8 scaling factors have not been computed. Operating the model at such low precision without the correct scaling could result in significant numerical errors, which in turn would produce incorrect results.\n", + "Implementing FP8 generation with the Gemma model is not straightforward, because it was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing. Running the model at this lower precision without proper scaling could lead to significant errors and incorrect results.\n", "\n", - "We highly recommend familiarizing yourself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the necessity of scaling.\n", + "It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the necessity of scaling.\n", "\n", "##### Weight Calibration\n", "\n", - "To address the issue outlined above, we will implement weight calibration. This involves running several forward iterations at BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while we simultaneously collect `amax_history` and other parameters related to the FP8 precision, which is essential for calculating the FP8 scaling factors.\n", + "To address the issue outlined above, weight calibration will be used. This involves running several forward iterations at BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the FP8 scaling well.\n", "\n", - "The code below outlines the steps to initialize the BF16 model and conduct several forward iterations within the specified context. After these iterations, we save the model, and these weights will be utilized in subsequent chapters." + "The code below outlines the steps to initialize the BF16 model and conduct several forward iterations within the specified context. After these iterations, the model is saved, and these weights will be utilized in subsequent chapters." ] }, { @@ -466,7 +539,7 @@ "hyperparams.model_name = \"\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/llama/weights\"\n", "hyperparams.fuse_qkv_params = True # This is needed by the last improvement.\n", "\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", + "model = init_te_gemma_model(hyperparams)\n", "\n", "# Calibration\n", "with te.fp8_autocast(enabled=False, calibrating=True), \\\n", @@ -543,10 +616,10 @@ "hyperparams.cuda_graphs_static_batch_size = 64\n", "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", "hyperparams.cuda_graphs_static_max_context_len = 128\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", + "model = init_te_gemma_model(hyperparams)\n", "\n", "print_sample_of_generated_texts(model)\n", - "benchmark_generation(model, measure_memory=True)" + "benchmark_generation(model)" ] }, { @@ -554,11 +627,11 @@ "id": "8cdbb56c", "metadata": {}, "source": [ - "We can observe that the outputs are coherent; however, the generation time has increased. Why is this the case? \n", + "One can observe that the outputs are coherent; however, the generation time has increased. Why is this the case?\n", "\n", "Running the model in FP8 does not imply that all weights are stored in FP8. By default, they are stored in higher precision and are cast to FP8, using saved scaling factors, before operations such as GEMMs.\n", "\n", - "This approach is beneficial during training: we can perform one cast for both backward and forward passes, leading to speedups. However, performing a single cast for each forward pass introduces too much overhead to achieve a speedup. We will address this issue in the next section of the tutorial.\n" + "This approach is beneficial during training: one can perform one cast for both backward and forward passes, leading to speedups. However, performing a single cast for each forward pass introduces too much overhead to achieve a speedup. This issue will be addressed in the next section of the tutorial." ] }, { @@ -566,7 +639,7 @@ "id": "8d3945e3", "metadata": {}, "source": [ - "## [Improvement 4] Reducing memory usage with the fp8_model_init()" + "### Use of only FP8 model weights" ] }, { @@ -574,15 +647,15 @@ "id": "2dd0cba9", "metadata": {}, "source": [ - "TransformerEngine stores parameters in higher precision and only casts them to FP8. It is also true with the optimizer state. It is needed to maintain accucacy during training. However, we can get rid of high precision weights when doing inference. \n", + "TransformerEngine stores parameters in higher precision and only casts them to FP8. It may be necessary to maintain accucacy during training. However, we can get rid of high precision weights when doing inference. \n", "\n", - "Transformer Engine supports maintaining only FP8 copy of weights with `fp8_model_init` decorator. Let's see an example\n", + "Transformer Engine supports maintaining only FP8 weights with `fp8_model_init` decorator. Let's see an example\n", "```\n", "with te.fp8_model_init(enabled=True):\n", " linear = te.Linear((1024, 1024)) # this module is initialized only with fp8 weights\n", "```\n", "\n", - "Now we can try to use `fp8_model_init` in out code and look at the memory usage." 
+ "Let's run the code with `fp8_model_init`:" ] }, { @@ -634,10 +707,10 @@ "hyperparams.cuda_graphs_static_batch_size = 64\n", "hyperparams.cuda_graphs_static_max_seq_len = 1024\n", "hyperparams.cuda_graphs_static_max_context_len = 128\n", - "model = init_te_gemma_model(hyperparams).cuda()\n", + "model = init_te_gemma_model(hyperparams)\n", "\n", "print_sample_of_generated_texts(model)\n", - "benchmark_generation(model, measure_memory=True)" + "benchmark_generation(model)" ] }, { @@ -645,8 +718,6 @@ "id": "3e30ca5a", "metadata": {}, "source": [ - "We finally obtained the **6.74x** speedup.\n", - "\n", "| Models | Time | Speedup | \n", "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", "| HF (baseline) | 82.04 | 1 |\n", @@ -654,7 +725,7 @@ "| THD attention + Cuda Graphs with TE | 16.81 | 4.88 | \n", "| THD attention + FP8 with TE + fp8_model_init() | 12.18 | 6.74 | \n", "\n", - "Moreover the memory usage dropped from *63.82 GB* to the *56.60 GB*. We can potentially use that to increase batch size to obtain even larger speedup." + "We finally obtained the **6.74x** speedup." ] }, { @@ -665,17 +736,6 @@ "## Conclusions" ] }, - { - "cell_type": "markdown", - "id": "824129be", - "metadata": {}, - "source": [ - "
\n", - "\n", - "\"\"\n", - "
" - ] - }, { "cell_type": "markdown", "id": "7bb2452d", @@ -687,7 +747,7 @@ "3. FP8 weights calibration,\n", "4. Models containing only FP8 version of their parameters.\n", "\n", - "Each of these features can be applied in various contexts, and here we demonstrated their use for achieving fast inference. It's important to note that the fastest possible inference speeds can be achieved using NVIDIA's inference-optimized [TensorRT](https://developer.nvidia.com/tensorrt) library." + "Each of these features can be applied in various contexts, and here we demonstrated their use for achieving fast token generation. It's important to note that the fastest possible inference speeds can be achieved using NVIDIA's inference-optimized [TensorRT](https://developer.nvidia.com/tensorrt) library." ] } ], From 3222fde4141eb17d1521aacda4eca5c0b647f126 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 29 May 2024 10:31:29 -0700 Subject: [PATCH 161/244] Tutorial fixes Signed-off-by: Pawel Gadzinski --- .../te_gemma/media/generation_animation.gif | Bin 0 -> 140610 bytes docs/examples/te_gemma/media/substitution.png | Bin 0 -> 78210 bytes docs/examples/te_gemma/media/thd_bshd.png | Bin 0 -> 122620 bytes 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 docs/examples/te_gemma/media/generation_animation.gif create mode 100644 docs/examples/te_gemma/media/substitution.png create mode 100644 docs/examples/te_gemma/media/thd_bshd.png diff --git a/docs/examples/te_gemma/media/generation_animation.gif b/docs/examples/te_gemma/media/generation_animation.gif new file mode 100644 index 0000000000000000000000000000000000000000..d6bf22b8e59ade65dfae46d55c80e1c9f47ecffc GIT binary patch literal 140610 zcmcGVRZ|^IvxPVAuyHmT+}$lW1b26LcLZ+cL zRn@DjpOKQ5|D}a| zeSJYwu+PlQq@<*rpP&C1_VMv)Y-~hFM-K}N>+kO`E-u#8)RdBv+T7eECnx9N;IOu~ zzP-KmXaxiZ2aAb`wYIj-&CSKc#E_Dbva+(Os;XvXW$oclb|Vy?q4+1^@aT@+b6fSa`(0$f)R;*tqzF#H8eu)U@=B%&hF3+`RmP!lL4m z(z5c3%Bt#`+PeCN#-`?$*0%PJ&aUpB-oE~U!J*-i(XsJ~$*Jj?*}3_J#ieDDmDRQN zjm@p?o!!0tgTtfalhZTs`Niec_08?w{lnwa^ULeo`^V=O01A=tM@>%eFL+Eky^)&S z{!k#fOvaDeJg&hAJa(Ink=p#>=x<`-gwk~dqwzGVWqPA^h2zOArehh>^+l8EJg#RO zqxHqp*}}gNiDVi|m}m2U#M9}IHI&X5E9J{%$~2ZOmTT49Y>qXSFIOA&hZD&*Rjk%q z&6nwqH&w1TQ#H<)Z|1TWW^j%OLmw>Mp^x4440Cfl2@wtIdd zktlSu-0Y7m!fH)*wB8-hXcrKLC9PuJHLQua7MqeG^}<*@c4>% z{;0BsriUg5TBZLbOd7!oYske53$LrS+KUixg-`z%IdhU0$#OJ;8C7wU@;AB+qI5ro z10gIeHsqUiTAa8n*C1dlP);szcT00c@9{-4s|+_TfNgA$t%0L*|5fD&gQyoV8CAT5xh%461+ zxEALd0E2Ah3;zR^$423AO-`0h;(PKnkW7m0+^^KXs!ttFRD&p%M=8%b}W5 z+QH%D@;tINIE2c!6-!gPfKu6S!j72GUcmG2FVr2_CYXVC`}Qp=2J{x?6M~ASA3Zuf zZIUTSH~1K;a=)SGrqZgaI>_V@`rQTWehD(IcYBx4-812 zf5MOgP&1A-ruiZ@^V3*@_Zwp)oxhh+PkiZDWi>M@2|!5fobBqNG;9o1n6a)9dYCtL z7^wY`1AFi}vJB(4&kT{9Zk`Ir$3K8TdT7Ak^*VGV^4SK0QaE-ViX16)Uji+(T~26g zD3-fZiIs8j+5q}p&Q_a_&X?UR>aXc@2>^2q)My^+b@x;n|HnXyFU#F>2U>0`6s+A< z(4puGe`oz6cLelv1g=tV$q#7y-+;AfTj9%_VLqX{d|SD%2{d6{j2@{PNP$3zO#(;= zaiuSaP$DH4QLjy44sHhpX|&T-Sxzf)ppbK8rERSZ;eqG~G~(7cJT#s#lr?NOgins& zRsZ%=A8y4BtqAlxdk&mG719yyRcQ0|E)d7Eb+8b{{1__mp{SS@@&IUrwzIn)$SwGn z6y4|qDsZ6g2eh+3bI|SGUboZoZx#bIKa&ic&!oSD8>8nwm$K>5jVCqG3a7y}{p9pE za01ZRPX4`pBmlz4;YgT>pE4zZWatQhJCWKL>vzqkB6ulyn_55niV$GX$f%ZGT(A*= za0PTzEi^@sEjEieWCbNa6=TApr~|K;_&>R735PVvtkb_G1E6@Ef{Vlp>2y6f$nCr$ zoymYb0GE|cd@A}BP6ku>?9@QWy<6Lv^L1d!2?Wv#8wBbd9=e4_klB`tg?)Su!iCpx zU!+`gISXhEk@olIsI8vd%1$| zFH96dzsW%1L%0&SrpV_f&!y_0yF+W~Eijx+N!wwdTaceZ~N<}6k) zMtQO`1}{Gzbq44&IM&8xEPEv{_)YGW;zWXJxk|aSZlYsP`-$i0e15kP8(RpFYaRak zR#no7TEt8#O7t|=LRK;hACZ3yT=~;yg&!0n9;2n+3ERvInNy5Upr=bv4&Or6+h87; 
ztRjmy@<)RO0a}+1%MM#b8{4>6lBG;KQbNmKB_;xskcd%IptpF!<6{t&JGF_W?cA+j9s1WQ4S`mo{+3Tz0!^mP%m< zu_lNe>;cSnq7VdS3xHZ15*oYlKYd-mZP=xrJdG1h*Xh`x^K8A!C96lP*Bo#8TD^eG9i@~LTFCde}9cRM27jwyBxT~rbycIJ72l^86TQnN)XR?%Z92<;lVzJtr{VPIbV5cZThHD2}77Y+0v)^!=)3n!~juIYjN zqoJ8kxlg9Hb+57zH6E(gSIl48bj=1~9{p(za~u*_?x}w=GsV;IP5At6$?r5q3f2XE zR`uA2XV)8t4UL|pEhGH{UN_@%1$#;Yum7>pyqa&%TJMHzgnxW}s~+>qXKlCC`eIDT zr-L-ED#Sh-2O=IPxhGI_k>cd_$e)=zILzpkny}+U(ZlwqeVB{d_4nHqq4>+IZwu&I zHG}z@WV4h@#$*8I742;o@`z6J+h0!|Q~qAhVKF|W0wnti`8n}}Rkg2B-_C#KTe~*< zB}fM5IC%mBHQ?RF`WDTVge6QR&BS(Cp7Y%=J)v_3O#`reG$dWViMv-j`K@vXyRqst1nkeT^@FD~H(rA1?ABN6z^PSX9U*xmQZ+lqbOtL$4aMamyX z7XFY%=)1ZAYA;?xSH&C)|4(%*H+Vhg7Ms6=o*q0-{4G9To^EqVUQWNXlZ)l|Tnw)Z zecU7^IG*5MEO_Ks5aHond;2U+Lwpqi*(~F^P@Uv-51poA#QbPfJ*@(8hl54GF}X;o zAUq+)m;Cu^&_b@q^!+j8Cnl^=%|IZqVc)tp0a=Ib_mA24W~&^jkQTlL^d% z+;1(HATGD?;&70qAi|D$%Wr^viSWCc?3SCH)S)u)FT9&(n6I0&fti%zaj0?+>%zQi zxP`XE5YOiU=aI^<;RV;C2Eh)@KWE%JFh+n+3auv#KUl2@WKb{?2<}eB|3Fed<%gg| zjHEKAbQ7EODQ-kmf`kq?G`{z5BaX;D?f^k6=Wl-_ox}i@TzsfKCXQ}`@ICxl?s66@ zlxPU>R#rxAe?8s3w0p^76>GR$HN1Z0#KHra&z8f#hGLg_ereyu?Z?F+mPVnZCPsw@ z2lfahCu!12`0 zKo{@LG7MP96k~i2>1p=me}NO!hG{7gzNnY%y-idCFs?&Ez;jp|<3v5rC2VmfU+!D| zNeM+)6_Waql9rI7lbV)alcF6PMG~A=21?KBNz>X*YZghbOHEe^OxIwIu=q zz?*>$n{H^5;mhLP9F{S#ni1WTu5O%8-j^|@omq$t&;t?7X%oy?XD+Q~5{YJd0kTv? zvbLl%ju5cdQ_1$cvo1&DFTy~5r@xz6v#%7qvlE`vdc4h6O0%C|vXf)76l=5UNpk>e zF=M=$Z@hs>YvecHIY2}>*flKZH8M2(T)Z||Odm2_J~)ClERwX`72G^&}#QAU?HnYfzn#`_-LU~T9LuGA~DuN zZqXu}*FvGOEUWMocOQH^{Nk+mVv*gV%HCohKJLI*O#ibSPyCX8@7+b}cnDN54HfNf=b zvt^~Ed_p>^$*gCp50=5VYYF$be{`wfr zdTQtT+}?V=a)SRnC+_%f5PglX$atPuIwh|x48vQ|kMstzwg&RudfmAO6?$96w_>gK zbbWqYwe-e0nVN3tCKG~UV*(6~H;}b&lag4c*o^ z!>7IZskQMBezP}yi_Kay@lA7PZp+{C>L8?+KeENqNJUZO=n~^@vFo7J@4Bte-mL^d z%?aREMZ`Az-WGF{=7RQO3!CZ}>xy#Uf}(Wvl6Fu%L3>kqo1=2ObzXZ5Jx8-IdM1BO zGi!%fH#no$7Cn!@EZw)GHnju0w>=f2bLuUx(zm4=+_u=Bx4@4+_eRy#-dR4^iGrFl zzR}bHqul{9nGQEOWFQ@3jp})(dWV3JK;I5{P}|OebECpYFIedvTbFR?Nq2~KW6+{} zs5N!y!)ah?e8`DV(`5oprNYlc0OTwFa~Lma_|0jM3~8i|cH~uf*w1C?j~=0^U6mh0 z?!OJ}zZ)Y~0;77lBRt`wNyvnW|IpF|XfifNyKY9Gu7^F1#vTL5gbFsK%k&$Wt1%6Oj;;XB7qKgqeH)I&TE> z?&PqS&U0S`u-^oJPFFxq<2Ft@BG0^1&0OqEov%;8X3`v0)I12Zp;WrTf6$_TjAP2r ze&}W4W=??wZLvF1zdEQ1h_J~;K4t?k=FWiQoB!r?bmz`iXSL4fm?tOo#%Jtu=1z0w zIqWN01asK*M+AvRgf6fUJMjbr7Y=k5;%#i_WxGtgQ&Fg+@-XqpYtmEWOdMr@E|3^sY7ep?6iH{L$~} z_*f4>-&keYNLSey+Qc1+L@B*kX|uiC5X;k=-%+wj-XugNllaMa+$ywT??%hNDnI;J-suMFYA-{_0%EAjYHS z|GtOUxp{7~!e;_cD~<&c#uJN5F>ZsuCg~- zvuE0KFtKwWATzbg{kaDtc(`hFz#n{Y{O=I^XD8>Uz53;$806u&@b)&{z8UeZ<|Vv^ z0k=-qQ3ByHAL-#{*^$5Tv6zd)me=LRn*Ne+7P(^*r^9}yDO0omx^R6yVbTnc^Z7qdTM5oM-j9u=&Jb13vilE# zm1k95BP`z!%r?&&6-a8nBWGOVc6NbBFHSjc!2Sy0zNrbPba1`%nZRh&z{pg8#n#^R zR?R?Hb(-pJv zm9ElFXU>f%4$0Ql_rRU4Lz3!i(tu1Pv?e^m+d+ft_be$|W?aJQ+dh`Nc9xrR;yVsU zT&`*)o@`vf?z>Lleap^`w(xzk(Y>f>z&UlcFS!s{mbrocs|KZepO56h6MWlfzYib% zpj!>5_IYUMe=u;YGjn`2nTG!1h-7Q{=;;1PmwV5g{%BA7#Oq*)YcHht_W9u1y`+D2 zkIwNoD~R`azISo@&!i2s zo&&?@Eko$=tk+Y+Pu9&(^P4+X-_OhL-ZHc`IK$Tmq0W1uRbe0iIZR?7Jyne-x8m~?fYai2-w5cY)m8~z1bY@pV7^J z#E$Ysad=bVl*dtP`%AB4SXj%YvRgB?YhhBSB%VVOj=!iL%jbxN;c(ix9V=%@qp+&; z&7Z24s+MXnE>GDkmHcY5Kv`Y1)v7feO@q|s+FSi@Gs|bTj^m)$>I&{C1Wz+b<$9VP z9lqf@8TBeYqT;-$s2lfxLud9dddG7yqwuFRRPR8rHIGT)^n84OA-9}a6rZOR)LFA$ zDA#T_NM7X>TXeP<#?e^!aM*1BwR@B~?wP*RZI43jPw4I9ot&7@UUlSsH&9og+3c^c z>2{upZ)(=1uj_feZ!ntNT(J56a<$}B#2vcj_w(g)ce+`6li&Ys;q2<^>x*Xv@UN&l z=bm6D6%t)9Mr0bxPN^D7&>PstPg4QxM?P&+5?fnp|J{$e4X-?aA)gRFfH@Cz{Dz=j z5Hg6XXP!~ph-J;8RQTw?aM@4f_a4hk5VRmVLKjYWi-sGk7dKi5ZC5u)nE2bJmA144 zLpCpJLM4Nkel>1mjr1w 
zY5hcXS~ZE=dCk!Ol%->@KEI%AJeo*N@2+k-ZP*n;goB|I@K0UJWS2!_)#Ag7q&vmZU~VB-d-qWRLhnZ<7HE+CudfQX*mDq%6^OpvGzfn z^!L?+1l}e*1NU`5kcH-GO6pOXeHZ9B!+k6DILq%7bdnQ7oOY5I`5phXAhAk&tT59b z|E#1iEA6bTvg`FYMWAuZhoGYK6aTzskU0IkI-QB&qG8b>{h}%P2>-I>Ff0AC?V^j| zs^eiR{i^HZli<1sh9u*<4@HRZW`GjgFym&3D1h*Ggfcthc8sx`@NR;0JL7Ij@Qd(% zMuH^seokJ9jObxO-7xcENk4$-am744^Ks3-o9Jo7eLM4M%kPWmc_)N~B+Q1hi}?L|kR<#4c2bDs<9^XF`{QvlfaLS} zFgu&%^Yx;e%mJNxVN_F3Mi69vO ze}NWuV0l+W(ENeFNtv;B;ia`iF|&Xnv}(JE+AE^CUBExAZW-%VT4DrSz)<7(UDU*6 zF_KT@&Y=klN zK3=z$Bx_piKT9=J{MEmb+*LP`_P_TD54DDQwqm2a8cc~_T80Iaz+piY8dz^7Qerq5 zvHk)DWSFf#q_{BR{v{lcQ-h>sh*iP^u@0ylu?Oi@&?9o*&2S|UWK^=)5=sco>9kvA zG}G!5Ykn5eT58MajA0}Ll_dEbRt{&%0TnuTUoRCvBvJGdUcv>gz-b3)?nLIoyT2; zJ`q@FGEJeOO})}A1gyD^ui4bKSY=&EZ_r(~+`iUW<(COId_vUfs%lKfNU>1|5frEN zB_ESD6uO%CagD{%h2=7O!_P(<#gIvdP=WOCvny1N?cgTm#>W!c{CLJG1?lB zYpw)o~+C~R2 zJsQ_HPb{9=H&CiQ%Wd_%d>?U4d(@eCiDj3^!!_!!7R`TZB5oUqYX0{+2EQBK?Uyaw z<`?XmfCpQ{m0Ru}Xkw=TP{t0drblnrK$*#Wy8+S{V-Z^-t(Bwg&Icr3A0p1pFLXi^ zj9)DSw30XBXcK#A&3Hr1NwuNY@7s9GEj?eId%tPVzv_^_R-Cf#se{I44ai7R+kRM7 z#bvi2PzbG#Dm&aJ6w@ElscJVW4b{2T+Z{4FYDW-I1v&=l9T5m>^b=0q#96klFAZsQ z7ues~Z0fCZuX9Z`C$gu%bXW?3K{MVs92p3a$D-SyS^Wi;)X5DiZ1a_Z{X=S(dAv;* z6`ig`hs(mSbZf;o&3T)Kdu~>gGsR^5#Z0b;GF3rq^=6-JrNX=v(1n$*Cr!Wg<0%9E zg^mGQd|jS@ezhy;+yIGSby(xkDt_|Z+_hkU#>vpCC{8dMi;4yEP_U9ctZ05=LXy%y!TDJl*jnit% z{dp;Q_o6G8mi|ez5c91E3lYBNII35{Y0ObAKDD?c-UEuCH}EJ&O`w36 zRSe&~^y;6AypAu2CZ9>#!xlSm+dg37wcWa`pP{XGu3YWhqk)2Tm#)!o>-FwWUCH&= z=CVI$2)dqot~zf69N%`M*AI)#I>Dvc8F~?`FE5#I|KJ#R`01bDo>97|Y>_^1Bfj1h z4z?>^mZLu(%LAIvqr2ZE%6)#NyS^?i3prLfI#uFw?ID7X^y~q=m4KTI037%Z*{Kn} z1_I;+fqo4^$pVS3^x-yr!vcJO9^WXV`TBZ9gEl2!&Wk<++I-Tm9DiehQC2!cE{B1b z+19(h)^dWyc7n0LhDoG{HICV?y4(OpC!aHCerJJ>0CM6x^cj^N4IvG4C=O-J^ykbB zAu|sZD9y+qPt_=`3o{_KIw8$p_;l*bJR;4A)qd zEtc1@n$xkXkAZRA!As;Y<;j1nnCcP8J-=)55}olfYe7;Bpu$+#U$gC5=+9X)1R|wJ z6w0{W=q!>k1OdvE4HbDFN@p(DaMUx9>@%Gk=+8Shu#pbLmWo6~HISSvkYH9K^;$w{ zrKOH9bV4{6roKwtEBnL}W$acvQui2~Nibx=(umH7>tri{g+pGimXep;7L3uD#{uv^V^}E>Y*lkYSq}qE--aDJq=1E zm1`Vzu~N-F3MOXxu-Y6gO5vB?4Pmt6@Ga}?PBhWi1l>giB&HHQ^EDmOG(8jwLn{V- z^X<$`T*B{bisftCZ54?a7KAG_WY87%q6!}CH0UAjqY0UKFp7C|nwk3=AF7@Kbj$qY zG6l_kkQBGRfx$$Mx%_*4*G?I>xNa~I_<+w0#p}w>d3Ptk&LJ}23sfBhu-`$ea>$Bf zb{NR9im7A{*bxvphbyVDIo8uuDLxETkg+Rn{$<34qS(b>0B!$)gbPMiBTJ!IdxlO!$O+Kdi<~AgyKTNya8y< zA*tUW>3lEwD>~7R1J(Ur${Py_eBD1|jqt$E|DRzU7vs)8i}$({{HQ?URbddiVE~`W z@W(LYR?D=y$<{W=_TI}js$mW`$c=LF4l}*rqs!j73MFGHOdkkwS5QV)^Jv}!7cb?E zrSD!N5I-6eG0zp%VK|j8*zN%jN(h|FD4cS-_b&`ivhvlE} zS#By4h|`(zH{nSo9`NF;)v}q@mEF}S9r2|b)ddyh6|dBv95nQ?Bitr`(0^g8`l^ey zu~%d)yko`QeN#t^NKiC)mQ3W7+IEyuyw(n}*3RVok=v+K*r;9Ec&DwaO!R_Uu$slVDOLtIXk8G=6onRE$3;6 zz!{-Aifs%xAkL^tV_c^{8qE?ErcCZ=LYZU&qoEAKPRw92WnMDjY?5MaGI@*A3aF*K zo7|a*l9tiE@*eBq=$t-9A>wGwl-GOV>i16u{% za%*BKjwYz8CaKJF3UA+=mp>8(s%xz;Tn$A__}x-v$2_)O4LqvZX)E0wDa23bvDocn z+EY&1hS6U2VZH&}%+Z!?uA20|<|&AnY+;f_l9eTjW7vmPfzhse@C{C>n1Y4>oMe+m z@s};Q6KnNxH5~JYN39bk|&V z+fQ2gkDh~;zf+E|lh17S&~@NU z=5Ei@^pTGDL2ur{i@h3<&+zedEp;?yf;AZP{5grM8ou*$S<_%G`RDc^!UnG2wx<6y zuGi|0YiIluU`eBuYe}4I1@=C#ECGuVm&lnLLMS{?do~cJ#kO_Y1p&p^FizFCnf59f zqwJ4-36*~_?62HkY_hjrS9jyrEPSkP_m{V-5w+~JTJEZDA-IY9x;&WHcz@6vL=mDc z6PiTs)LrB6f6YtSaIO3b4dby1#zVQq1)Mb~$gTvXdjxUbq~PEdw7Jt{wh*;~VRM^p zRT=efV(A2xHQb+!B~)+%TFhfxG^qk3*H z%!!ubjZQUDAAgB@f=-xT6?J|jLTdOjeF&?{J1cG zUy{Pw>xyh_mxy=i^o3sC%hP0NW z|KcAq!D@IKK#xOC%^S+E#P%p4KZVZh94o&ZPRik7I~Nh}?nJ4Um&>Z0LvZd$@}hy# z8Z*mGrS9sTtW$L$t@Wmz>%689-&RCzl+(Ue#Gg`bzLw1$mJ1PQrT7>Tbi+iz^(KsQ zn}Tl%=N=HjTNb%kMXQ*%`xp5*Mx!z*zZ|~xwdWpuH4+^W?@$0ol zjI@Z9plCR)q8lnlpRW$;G!4HkEY)4c8aG>3Gk@~_n^?Pk{F1OE?rYk*ENdI)s#+dw 
z)D&a>A=0*8&27EeRkybDl^>s=*;=>5SpxpVCGn`v2UEV~KBwcrg_$o@-A?W~BzY*sy*=Yx+DxlaE0DtQ>6dHQoZN_zi8 zPV@M>Ot4;`cY9yR`*3`LCtNFqIx|)Eg#-kTEH@uX|J0k!N*?EBm9I)lX>>=)scJ7L zPJQcKf$w9so4ZXt->5lNKj%gV(fY`GM_RyJn1J5n4By$eN7IYz77{EjXCCH57IsFu zF~j%3?t9|Y>oJtOx2}xetx4HODoLjTM}IqVy78y{2IuNi=ORb+vOS7r=ZtON$2&JH zGzqg$`R8=vJ=WLEnB&_{b$a5x^Q>n=utBYz2;zcDlV8gjf2RrNdtYlAJC=Eyr9%n6 zU9jgDrkmlSN3g#&xNhW0{bYCM{9~9HA{FOtnZ&D1fYpi{472dbr@y=4+nL_85vD&M z-7&J--pRQbEa8{*q-IH}?w!j2W0X+sMOPg~kW*2ggZa{}{eD0-Wz_B75I1uySw-B> z4}~jpep_#4UtpfdA2CYMjp**FBeN%eAq+dky0FtLpl#9qr~W*aIc;2!1oo=keVB3V zzkH`Gi~3Z?EwO{2Qw* z?oZ3Yi@fG7nf1$}(x{kz8%NBA$9~^C2BDAhz5Bfi>CVah&%23LOt=sTd?rt{&(zP^ zEtSrqACz0O24e__NZDq#@W)X+!Q4|r%lAPaK0}jXjGb8*QWO{D{v#wp$IG4P%BdUP zG$*TZZQ8o)#TZL*1nluT_=YZnwM2h{l}_;%sL-=|05XQ_EO_*?Xc@UaSpgO+OqZviJQS z*AlgtY(Aez3~nb4ju-VSyoKJ^zdwf<&Rw7_7rTZQziQ2LxL^|RpSLEpJ{a@YB})*) z!$Su;SutQlN*BhY=rKhPTh@jp@KvhObGMnB6w2ilFd*bWoN=~a`1kMM@^(6O@87)& zc3*u#!6Kq!;NTOJQPDB6aUJEt$o)`I(at_h4Ts;61e~?DwSNr0x_f&2fn%}%WiLz>W0@2A_2__Y_k5b$ z6OO}dGEw48KJ1U=9>hQPkB);1ou_GtH;ytH7ac+#jMAt!nJjasU#`l-lgUU?LM^T_fm)!o#ZgtScG&KV zr1xaf(EH|kQ#N{AG;4o3An-9>mC-@x-%sPd=0-f0wx&TGC53vzH?jkJH}T)xu#<7! z13FSSd^^>jq%HC*uS_;gnY1ls2X@u$rem+H;+aE!mCW8|5~1M7v$dj~0Ri?F7MJKK}{)m@OWPZK4Df z#4Cz|D232v{$~7KYb@3a+o^3KG~V#j){dC;9@Y+LX~LcMmJw{i^_B~xPDsZ#sNYmH zf;=pe(LVW`ln?-kUy`H4IBFK_fQ7FcAC|UwkP|J6Us~X#1d7A-3LDXnoTyzkH&%v# z%{Di|NW@c18HC4k&(8`2DIUd~QI=4+;E|Nil-92p7d-HmR(C@ZRFMAT^4YZ_82Hn8 z64TC0S5ZGLO0QQ)s!!VM;yY%VpXPIfprb$%?wwh&Wg=Upn3jbLj;vuTvUj`k54HTa zOlodoxt^q!mOXpBUfmr{Mz0(P!Ckh^9)(l63s}!EMxh8kb|v;RV(G-1`~zPHm*f;eEE(8?JqCag7zZ z;^2&iIoaXcSm%@@o|_1;d02U?qBFlsyLKHMT@*pv51U1?Fsm)EjV8POQKn*jn^Bf- zFe%=Wr2OKv++fw9^u!skC`q^UCFTf#mM-KyVfX6{lYFh+>JJ zD46vhN_{Sh`47pj#p{1BnB)ZOCEtko97Juw6DK*AbdAon_J|Z@#H?CUsF1g!dVBjP zQ-`IR*lfYSk#)xmmIMk*m^~dbC@ATzM`caw44o=OU53OmL8+5J&EwapN#W|oQzl~E z?sJ~{?!_Bw9c%&+K=(b(Y`Tdv-iPCQ6iCUVhF_5EDSd=gtl)kz-e8(oGuthihrdPi zX%~M(?Tk)=+Eq8DsP?9i${gvXoIlNH|IWXyv=~v3n&h!(p6OBZudi}&2%s>3@0eRJ zd*O|wtntvSeW)ua2q~Zum;gvmY`>?jrZs;60C@-exw!gd-oH zQ3tRGs(n6S7s!1z<;wgqd0>B+1XGC(|I11drOS;ba$$u;pS#+rEmXfq$H4GkPx&0Q zQ@$t>opOZD0%6|`&ppE)nX@!$v5M+1?1ZzYorakhzp|gCDl{GyGBUN^oXL|ShtfiK zBQ$jkSzvxT-SxAj=Dr6_X~7&-ZS5cQd{kzMXa$5Y^=HDTscaR6Y9QGUZB;&m8)qI7 ztLD#F+G=B4Du5E0ci(1p+=Qz>t#Mip&5E)<+@;n4@{|Xc;5&R5f~;Btx#()bjHFML zubg|4y!k7umo}8^b~*KA%XLGe16Rve9}$yt1bxg@I9K$pv!L4gN7sOnJxou{!kqs$>^`HrBuJ`}+|X>X5y*gI zZ!g@;1|4DLYGu{ZbE1$b5em7$xM%i>tC^QX*GEOtA4_9&hCNf%D9CkFU5%XX$jbbYmzi`}Ta zT`%Z$Rz^sjw01nR#mJXz{>XrvZ@XP>DVV9T&gAx=eE4QUpI0;g6sy&R?$I%8dc`lS z+OQkWUKa9}Fl!C=c7iDmH&4+`)NG#LyQyYvT)w{y3<_Lkv2AQraXY*v8HYQ{Eax-P z!V$W03{*ocTQ=QuijPz$DK?G5_X)LtmgAGIm##8x{DWxn_ncBE7lDohM^PGN-;rdr z$}k#6M>${1YbS2qf>3q>{mT{tjJL9X>Bk4Es!jJhcogWo?>Vd4l}HKd|DNZ&9LF)l zPTV?BQ444vxF^0BY)9q$mxw3fQ_ZzZG|3q)^L`$F6?^5c`8X)n<@ZD5EzzTO`Ocj+ zyfo`k-*mOtfs$Y>^3T!88|u@~@suVb)Q{xqE|+Y!mgU&5?9Ow+0~F z*=^AkLfHCDTqtr&9Td+IFV(okUi$|>H+l#%S=4^`*m~^9u$x>HSzT5Amp%z&1X=dY z)@|GV=Qd*Bw+3i^H`2Ji3Yir2FX@LicS zjD*#LMp&PcNUy_4GP~F^V;PK~K6AvaTB4H+?IyuAt%Q)6^~D?*^Nq5Uf2ECEaItFc zFj*kUZ|anA0Fzy!JXq{kImA6<5(Sy!#Grdxy?OKGGmfP$KhPKFPUIxIt-bo=dHiRP zUh^PM9VvP$^e#{oKYJjH?Gnb=1E|+t`3GhEhrOxQ z15OHMvPq)pe}E7rTl9;&`mr@=_;5`-^H(TQ4nLn6V}g-sjG)ggQue3C)i8WvGDb!!^Cig{itDX9th^F2XFYi>!T#-K2^Uu3H6=wCyX z1RFZ)-&jVqZigg;SdBxEBcIsgbBPl8HPXNu5za zp=6orWCxoVk#Q}_vaQ8m=uAqeD0y4Cxtv5Jo(jqpTb3TOClyp)EYxckACFp$juLF6 z&?s0dvh5yvvy})-mo#AOaMlugm>kA@iBc=PmZYSW(qonGe^cy}E-GG4H*ltf$=i
w%c?3E>$M7qO30kl zx)&^6IyB`@ByCpCywzLg3dbfJ%BGXw2{*6;{W!>@GR>=I2Hb+5`0`#a*!&0pk)YO{c!$<>!6(Cj$z$rW88LDF9WvbD$Ccfz7aW9y;Tp| z!-pib4F+Wh6UXUqZL(SCo6_Qml`gQSB0zDSmP{O!DU?%7f;-meNk(j5QZ;f?3xg`P zI#y|a>!=e3V>ZK5YH*hLj-|cJqQtgw%>ex=8Xv%B3%iC}Sb1B^4r+N}g6mM3=FK`Ka~dT580CvihVPqaWy&^!aZ#|h!Gmdxd~KU~1(gN(>zvWz@~T>JB9ejH9pR*AfnX(%b%<=W0%5f7%@Q zsxI+4uXc(~B`!-j=#mA)ZP+XA)%if=f51kmU^SKe%~l-Urpf}|JT;5Dh*&ng`?jmM zvZ6h*;ch0|#VizA5cQ@H2qntE>DAJDlflWr!M=+0$OOBTl9_RD?KX?){fqA!3hwCM z=i=g);4BQQ%wnqqu}t=Jz15`jXD0LOt3%IcHHw0bSs$%Z&$90rJ#_muQh(0yW${lh z#eaCh5Lm)6JqUom090@d&W;lh=?3`a7YRG)Zh)78=u&9JTf4c-{3nMcEO%;QUFlJ03zPb5EIT=$j-!Y z&Mh84nJCmFai_j}u zPhNQhh%$q!m(~?mS<3~2!8n_d9g6#!TFjNBJZDb<=SD`1@R-FcvyKJ27A9sXiZX~{ zp9PXa827oBPWQNJk+!R{@SrH`zb{$+Hf5%f6I_mZ$rS9ubJo05s&6uPLQwbTqs*Hf zh1^E$B;66LvLf7;l_gtjknqA5$ihH0VbV~?xJx!~zr``RaD~NX|G@2BLv4^i#P1&%STeu4 zP(5vqN|Lu2|B|P-;5n#g8n#!cs1a&+)39x94Pmp^tK^<&Sfg>nX5k<-fPma_f~1J$ z!3AJ_1q1MU7>}KJ1E6T<-SyyDp`IIn6wU^^!gv(*+(?0p(e-dzK~OuKsJ?-^p+Qg- z!fbEXn8nNxE1VGDn(+2(3^aR9-x`a18}&4WA%$zxC_Es%0=c6baI^r%xBy5VXKACR zs5BB!-aAmijmEx*18y>br|C0sgS+%Co*Vm@3xqB^7c@v15-$*t0Zj|8g|~yCWp36( zx-tAh0=TPk;6Xo%gITejYKX5IvHt>ez&B6_4Ohl@Zs1yXQ*Kl=02j{_ki5q?6has6 z2C#CEA^Zpw>+$C9Rck)hk2^s4UhhEg9E6*StXmGxB&HQTrbQs4Ew+ytb4x)^!)H|g zO*aq*8KAU(k8N~Eg4+W{(d_U8n;?eR=*5OhB{0 zDmMXUKR^QXIA$Y&0|bLf8`h}(=Y_T~g0p#>8&$rlsd{b!He0w(Ylyed|0C}hN6F9&vzxNYBHUj*$6MT1`3jtd*cNK`ZVJm?H zU^acv$f1k6We>4h9|5EX!=-P!r3W!%4?!{uwH4Smjt>EM8@5zvc_m-+Q19Rw_H)>} z`LGi^al!eT(q7U*~sEOC=3!!67E7%wsz&Z0A~d9nX8rAC4szBb?y+vH_GqvNf>@MGX-Y7aHhVh~*mD$=f2@CaH;Deg<2=?_fCyxJvm5<-h<_K)JQke) z`fGvc2li6{fu&G@5C8;1Vz^{B4M$^Th(tn-*DBST)l#un2_Y;Ymn8u-K)q(S-EVkY zKBw31cYI#I=lA`8z(B!4!a~DC#6-nK#>PA}$Vkaa%E|yb%S_G5x=qi_&QDQJ$5PW% z)Kt|~)>hY7*x1kjvs&9*TdyOiEiM3qp$euZnT zU8Sb)U@YK*=_@E>Z+czW)f*UPNWfiG#wCD) zUryWvFDwy39=Qcc)=88pRjy>&(&bB-F%3bQP}Am3oH=#Y zG&zUnPM|?~ron_p=+T@#eHumDbm%3SQKe3$TGi@Rtbicl;>wj{#IIbuRiL)fSdov* zK6VlLZ^#3%VMT;90&wl!Dt6nQPw7pqPi3xe7pFfVBYJ=!gUNqHBwSemO|Kg?0&SwPGf+ zuz&&L5%G$Q3TUf6($4!LERsr^PCFQ5lyUz?8f(l7y6ALrt_>V>#A!z!m-F!^Alc|~ z$fs;na!Dqebdr^=zWXk$*xG}_nEouK>8nT*m?d zun286zI@@#904D^YzqM9L}H&1SFCfRF9^hpugxao#TyUY4DcZ;+zFH*F8dslPyif& zEWRq-Ars7csw^`++G+~`fhSW{byZedo$*K{i&O*FA!VIaNF*n)6#!a&+|^ZJgB5nz z+oF`xJK0uTF{1LU)WSc}9^$j3X)S}VOe-W~bE733hzEc*A>f6-|0cqsLIO$@tRua? 
zAeR7#3UELq2ib~HAP=cXZ%tJDa-jd)UB>YMQv$R_<_Lb<($+92JRMb8(yEG`R(WNXTh_zaW4$7HB+SN&7~j9T@KVcBc`31g zIGqt7-g2xh06#*tZS=|uI~w%aaCi~Wmw6&?7l1RpCG?6($F(Pjm zuwqSl-HVqjFG^(a8T#g?NNe^S?HMyMWUJU^zylY2a2QE0`9#tXbmVE7(pm3$`JW=bi`AMdIm4 z-+A!%`s#Vtoubh@^O(5IdDs8Gd^<#5iv$sf6ANvD=E5Dgp0S&mD&2bJqzoU|# z6!kFlRLTfDo7lJFF*GJX00I@L!10KHH|V|YF%~dDywoPAdpK<}5kpj`AXN`Tp{;2V z9D)ZRGXVO8PCkS|h`zWqx|~^oJ_YbV0xXaKcYugAY(oGx46qOSz|9Lo^neH~Ac7&z z4GP{ffC0v*z7!3nT74@O`wnvZ>dg#F&WTcqo7-Gu11IV=sa6%3B2p?oXl6e7 zP&GdUbB*~dc%s&*Lt+mjtKPOLhF}VFsB@)jT~&h2;qVl%$N}nJ|AW#@X#&-*gC(rS zuxZq&y7hC3W$Rc+Br11W(4AHF>iJGr&cva{#+joYq#j!&TK`vzY~ys_3ha&FanemejMh zma(FBENHQ6+Fn={v(6e%Ysx7`h%I%kUh6JqJKIZcf)PcD-D*NqN?iEHSH8EEPH^2o z-<|ySj`qE8Z+$CR`36|Px-qT-jW8SV`Vupw?W{?+%TzGtjk}I=aB3|~S@V7gy0Fn3 zhnw2rU2Y#zO{N+u~Cj{T+J{| zAgBLX&6kTY** zp-tN4Lz8upC?fQFMa(nwUYE3+w(O;y?CHglBw1sS>z+-0>cGVray{nNj_b!;8vm5l zt&X*Qirtx&H@eSGN(?b*3}dss8m_?>c5yU9!x(n&^F}7rL50GSVCV z>2D5Q(*2hEw2waXrehk#b_X@G7k=<+m(f&s27BQRf7h%h6Ww0NdaXsab;ciF@;(;( z*f|Sf4JSOhDNgXjLrd{|lUeONuQpgOe{60mx5#Z>^uhc7nl3~7!bd-QN--Yw$>;xG z_qI)ZURiDRo7#E5!XN(bjX&S+xO{=m&O5V3+xfM3di0E@xqi~l-qN={`T}>B-SuAc zos*Z$R@Qmu(faG|ABNOWpY`$G|Nckxp0B@`zfjaI(^&&Q zul%b&>0rVBdPE*uL`+;G8mvJ;!;4^t$xFkjm_=fgXqyF)0XyXGsz zAf!EP3p~CXz}DM>*Sf+o$ihtg#c?}1`~x{jtgku{MmRFLV$82Z1V(4%n3m&2GDsF- z!^O>;wz(6LeLsB6ld7MW|6iAfBHGDL|F>E{+Y{8dINtwhLf<%^)lmJwuvYZ4# zRs>0%ODm0xy-{pKg`+lu)4?Z{Nbdu`FN3>mM6Hd?$iBcy7AeVIq{o@e%J^DImyF39 zYsom^O067AVW~-(X)ckhy(Rp+Aw0b@gUaiJHaUd5DO@@tq{z*RxdgPpZ5%{eG|J6; zD=gSWUc5?xBum63F0p(;uw=~QV@Z8XOvz*wv-FApgexhlMUAY(rR1^&1hY1@b3mRH z$D-58(QCAIbVrB8IX3*kRcy<>D@xSl2oao2+uX`!l(CasgD%`n-poxo^G6uGP2yCN z%Ct96{Kh@}#)4zWO!G|f)2>oX#k-uvdt*MgR7ktLxuN8|aGc9-guu+~0>ZRP!z@np z9InNL$;M>Q$Lz}XOi%fwj^li}{6HzQdcD)c&Y4TdzNF3~Bg(nMJfADXgabnBk~nQt zvO$baSB%2}<-;WWws}0v`Tx97{{zn8%*Xc(OAS3Y-@Fxm>_-eO(ch@gT=2%J&G`+P{P--;K^E6QVjnYpzMk*bJCC$>?h(>9g#@DJ&pc^xzQ%b0M$GAH{T69oB#4>|Z%1WcarX0^B zCE5%a$3DixP1wsV{06J7e{ZmGj2^3wvbHvVxWVpM- zz9P-jOQkqasnIqSProcX3=_DZTgcVgOQG9C62r^axVXVg)~yS%r^+dNRo2VP&2!~wA7~rO4EzeC=1QLG}dWt(0^Od zk3`ev+*3a7Rd5YgaUIti*+glil1>D?51Ujabj@S!$cjr%1sqcvHN=P$RlOt70Bu(5 zTr#E%$r-+)ScFYjg`Ej=Jy*#}*Gj#?EsH`<>q&NO$h%`xz`VtPo7hdA zz!BrjVVlwY1kw*o)-|Qg+FbuwmTg&=ecAYRYLm-{9`~=XSV^u7~N|>Ems;yc^{Z;UwjTluh@|(Vt^*Nn|R-^^L zdvi@GoGxd(IWetBq=nabEkM9@P%6YS>~pVDvs%2(TfKEvthK>Z<=Qhvwj!i1kwm%; z!`SRn+X#$P{iN15G}Bjv&6HKk0mZHgax7UJbxyfO!sR7Xy2L{vwSweAQZk4w;yqvVP2VjvUK_m0 zwDryzjosIjO()D(!9+@AJz0;%L+5o{o$b^EB{sRtO9F=5osGwpRsUZFUSJ01!iGIp zgS1)GvfrUyLIFOlBRkd(ltsK`+G^B3!If9Hlg+oS*LJ1fY$Z+5nvKjY%rlVM299AF zp5Y&3-)S7lpF>d6L)C|4uIJ0X`IXKaRAJg}+Y|%U$bCa34At+YJ_C+E9X-{zwA*i` zVJp63EY>p{PS;6OLy`T?&s5#>Yg~{G*d67_B?i!^iozV6%`%QK4+K-X#aW3xt@2FZ zEY4#+-ecL?;$4JT6yDd-WV;OnTrfss(qv-5BVi2mSE3}@{d~>}>O-M4;hdFDDz0Ka z-egYhz>}_H4HD7s-Xb;s@Qjlnit_nYPXm7*a$t7G0jzbkbT;&~A z5Pn>ABw}i=TyV7KYP43mrRLIfR==ylGRWi#y=a`4BZ}VDobG9u$Y>$HWmOGm7=@th zbz*ugWgU)KP<-S+)JBBV*IY(OIW}cX&F2GCF+BEZVqQ*5lY|7Fs2a zQ2A9&DOO@kU23&H;PCCi1Z7BwjO7U3UFqA}bOm0pmT0g3Yr$3pppM7>{5zsO99?`)Gc9s-rvAO>O_ua(QLcfj%;fK1Dn2S)Bhf3(mrnGW>`=T<>M8v z<-O)WHe-dZ%ib35NApwZih2nj;2%&rrtmv*`;FPUk2vo-sI$7Z}yVG`@Q)@Nqmu6(y3*jW5u>PfQ&P-FsqEO?0Z$55s2aj+(5Nj^3 zW>PfU!qr{=#mLHT-2g7>WNuMiezIjg=#)0;%`W0z{neRQAU5R%&eAY}-Czv?}ffw{a_ma4Ns@-)wIA z7G(g)+}a~(0=L|{W=`6kTerkP_;u+z)Xo<@Th`X_{X9nH#umB}(df{|avHXBJg0IT zhuX(J>=7pO%kD{Q{M0XZ^R)c%ndW7>j_K&8>z{RUY$aelC0}yp+&%x-+uqD{PM4!C zN6zEDPUX~QABJp?=F3T))A?m%1MY3CUUOUg@$G)?B=%1Ob>!9^-|z6nPN!Nu?{x?7 zb2ap084XHN&0TWD(QkJ2iiOP_weYD1;%84#)Gfzou3%CGa?sRYmECG!2i;%)_VpI_ z8@56u-_DrU;J35(ZFk&C6=}D%NICpSSmx*1ZE>wzdWH#2wY;8s|$Nc=Q!`kr!=pPp+mt 
z^}0-K4yWbh)mWx8QdNv}06t;Olxi|PcFjxht7gC*uJiOZ$&xcC-Wj*Gp||mqr^;%- zS>8rjAkSTHyXJSDYd2Y$_oM&!^$$HaOfUPMPWpS# z_6=|OoHtNe5AJqv?4t$gM(%m222Y`u>yFLe2-JBOzjL+sT#_$*pH6zJTy~33cvT1Q zi!E_N&e#-=cAiIf z_uuH{&e-hmXI(`$5Au!O*bOf7&$s&UR&BypeY`z<Ms0(aV6^>nI38+d&U|(KUxB{g ztCzjUU&Cs)P9Ojf2twlUSS%WcNhI^QR2qd!Dzxg%K9`YO3BfA~i^r0H8Guf!*X*|Y z4Ufy`^t%0y&+GU6zW)yxC^$%1Xn2U2sJM7%7BB%BDLF}53F*j~skzD7>G=s7DmqG9 zYI=%_U`EE;>iYWHwB=+q6{WRhC3}loyZd`ZJ1dOKZT)eVlbH1RI8ls=m(N?*0xR|1UpJzvgI(pC6fC-|znqFrdJJ1Pc~)bS|Mn zGF-geN|x(cH);~CO^k@~;YEv5%59X`aoRSG!A4?(A}u18k=Zh4)CMvbGL0dX*-TmU zOF}aU4Ep>DG^o&_M2o6=q%SEYenp!)eF`Es zNv>Jgk{idRoOThKO12bNs;wxOt3`==j=<@qRf}L(}H#V8#hD&UBJnFv?a@TR754A6&?Rr;~w|T}Wh+ zMjjcDgAd|!Ad*f#31yUvWvJnI?~zxZFe0uPTz^{;t+`Zb)4XWBC7H8oip zCOBvU=OURsmc`|Ifbj;VfAZ-E7FV~47u$zQjG>i~0Vqmqt+s9_=%AAndTXz~{@UwN zjc!P(mR)H|=Xv*`1(9B2!q#ey^NAW_wb8OE6M)fv)akRD&L&@pDB8HhUbgsBz4j)^=&)G<0EV@4nhIPgGM z!g0>|rg~(JR&J~g3E9)VB9Bb+DM!W2jCIypzmY4-6I{J@*kX?j*(xzR zn#^cW3$1LIIj?H7#VsOD_tQE4oaMozPV8xJY?c^qZAJSWrxC}g`ElZ|fK7Jejx}|w4#%-St6W^zT zxIOBR(p4$vg?}1m`Qj*k1jj)xu5$DC-(Ml}`tM&n{{H_@x#vL-R&1l*w30@!hKcDm z1>B2UOqU*M)v0|z%U%S<6RA)k>n0u&9n&r*J}a23Sg`ON08fa*m+^0f`b(h;UpO)W z4iHwO0g~$2bEOXwu6E=*nBF`xJAYwNd_qf)m-a@Z)K!ptL<|_{ibxE#X(fK~D`EX$ zh{Y^waf@8+VmyR3L##LhRCg<%q|zj@969i4rz>LA6!yRrMNoS>!V84=agc;8q#+N9NIA$bj2(+xYPM&)owbielfqVgcy=sD#%?%0ykM%Xl{D66 za8!}ZRNqE-tbR?=WAqDIB3}v1Sjuvi&0FLQwKBER=@E%=TH`L|)FmO)@XgVUyPF>?Oi0UY|7vra-0Y^GY&k&DP!UaVYofT6_P}v5t&)$# zXQ&dQqvEvP{cickrGGn`P&p(I_U$pD#B zm_=h9K11h2Ieu?0;6o$ck_gekee#L0>ri?ADA0tebfqk9sY|1>P|b;xV-8&=+!DB- z39gNBtm4Qzho(=uU9_Vfg(yz3^-*`C3Z3bpDL|_j$br6es#LA2Rj*o#n4WEB(*z|rV)k_1Q(2iVLH>2Lq%Ey!3kt?C z!h)4Gg{0t6X**0x^r5Y**kO0*G#+M>iA{^+NabceUMe%AP&k}9QYXTwo_4wa%x$i7 z4H?)*4K|OxU8+z+SVv}3kai3+(=V4APfHe7hvZbxv~bETH|=u@;OUq2Y6#Uo!vJZgiu8<6b=JYDfQ0I6Pn=?0d&) z&1;JG!$dB!k&ld|5R15>{>qugZZ=UTHZy@pbuF`ON@6qKhk?O-B*xyB;u^QMS)mCM zN|lS`G^=^dY}SsF_4{K|-j<|(9Iv*%rKXydII}GVZ#)0?uNCi-v84(`u#W9mO7_Bv zUxZy{kgYc4HcNWal&&;U-mD?6=J_py_HJEW$I&R$Y?O0!?28pEXr{vW%8iw=m+fOc zCIg8KQ_HKR|8I@!Tf8I8EsOpa_+L*8uOZnmeT{uR?6#p^#;k6Ms|i&gPYW2 zG(p(ujsd%okA}2VcMa}vi+kLfNUg73!AeI{x3<={Zr+sEo7W|Kz%%Z%-FmuH$c~A? 
zgZ?caPkgl88BTY}MKie%j_`yl9I(AUcW8iZ+>^RmJMq~ned!x@gW%&(61tbds1a1JW!;At(r?c2}({dt178}=kRv6RG> zI4DC7L2|TG5Zm>vqE1H4a?pnM+gk6t-~T>euIqP-H(ocD<^1Ag#u(42^yK0B&Fl^y zeYpNMU|fahbuMe9(KM{J!2$md`p}DhuY&gulb?LXc%w&qat*FBX^S3*!Q07rzPR_G z?8&n$u-Ou9B6uLq?ZfbW^u#Z|@e5k|ny>Plxh*ip8yIHHJM;BAt$Xipk8)z0r1#CI z`_aknXSABMy`HDv@Q;uF^s67b4R`p%Pe1Q-t{3Nh_9>x34P$V5aO62&UH(sLL0$&Ji%=zv&cZtlkrHQhjZ7Sq}zOvIR}bR(9BhXwUR%`)6u% zW`RS*ZXfq`ZZ~z^R(0q1fH5e8GZ;YaCxJf3WVIJm7GzNX19Tj>X|EQ2WrSnJMs2vq zU3JDidjo2~Wo(ZWP&Pl5cP?0Uu0(@Zh=p0`GRcP(K=y{Z)Qkw za#vnvhbl?;M#xrZ6{U2z$6sdzc6(+RxCD#^XMUqdjKye-7E+1{F^EPdebl#e)YF7P zCw)GMWof5iKxK#~XM#k*)YcR~h($B2&UsE(SEjLLY5pVWa8g+sb1 zULyj4X(VN%7GnaqVzAX#?UpfcREK$_cqWEdu_73EXpZa1kPYdODii=WxH*5+4JMUm zIMZriIExB~iZ#}KUnY4?hI9B=i7>{aLYO_aIx#)GeXI9Ubi^&x~35i#>Mw9XdQv$cVP6I?4c0;~0D(gOzuQmw71)TA3Bg zS6p{TaSO(DBv*lJ6^c)NH%zr z@^cgICX0(DlwY}$Vn%lHc$9nOVnK9Ppwo*4sFP3Um5wNto~Ju`NOkC7b-*c}4^(To&!ptn`xft z=3Z-+og?^mvS*D2sc%RHgr}K-Amp5eg^H0RncR1hhbS$z)}1XVD+T{bq9t0D2Rf13 zHkvyIaNTG@ad@5U)uL*}Z&ikGiwBR(0%)eGSk;x7j|F7Emjy4WnI>waN1Bc{Ih+H@ zcBqMhDoKmKp zp_>|%+Sqf3Ni;tvRzkOt-4|+3mxKUVnH+Me-&qFX>8$_l>aODktwP$A>Lpy=a&Mbf zfFC-WO{Zmqd7sv}h`Y+6acZNenO2>%3!4a_<>+ni>aY)+aPcaI+{Ul6*{hldi!vIp z83bX^Nv{6)w>2C=7A0R>P13GlKt%QeEeqzHJfP8v$a zma)3$HW(<0PS>EZhjcpTd===L5n7lhr&Pi4F*jGJ0ot%OTVydivsSCMX`rOT*<(^5 zMNQ+DUDaP<39>ooY53Vj`l>!_IjSGmUIDi*-dWj}zvuae+2#ZM7^>~*XYiUD8Q_HZu699-Rxb@Yy z^#Hh|`-g%Pym zzG#58761SaU;r8*vjyM*0?-B(j0OVW0TuiJ13(09fVFO*0PPC^6^z0fEC&|A00>aQ z6|evy+y(?N09ac94`2We00$1>0WMqs2w(t2kN`64w`pL)YcRxF+Xf=Q02Pb?9st1^ zEVCQj!!irRGQ7k#{KQf01~hELIIP2N-~b8m#TUH5^a8)HI!Yb~x!-DmDLbS8ZgpeP zMy|L^l23V&Ws7!*Rj%V&Q|3E`3wXxP)xZFNwF1BaDO|x}tOgov0E+y>jGP7|+y)k4 z!jP=OZ$JPLfXN6T$!Q=0n5@KX5Xz7|#zedZN2~@(T*PT001tr441mRG;K-1y0kgcp zKCH+UfXe_-%8=a3X|MpIYyg7{IXLUJ@%y-LJd_`xWjI|wH01%J>319#|EV>o^$WNTk>ukldOv!8@04AKyN&Eo$9L#KB z&L6A*0w4ldivY8H&ob))A|L=3AkU)=${w7~7=X1+?9eB?yl%kCYd`@19{{)tEdU7s z%NL!>2raV=kk008!S?LKIGoNAu+sKi!3n+44IRu55YrZ50Rqs_10c?Th|I~X#&RpW zl?u&abv-<2gz7|F)y%Kqt6NvfmO2V|b6dZ!(z!!QwV=DKJ-t`t?8F$H1|8tTU~R+| z3;~kt06rYd^}Gg_EYgkq)^R`pGW!6tOwl~81}^N$YajqVEXi|?&};C^6|Kr?kkM*j z&LW`EoLm5k4c9Vz)@tC^1;7SA>;Y?F(=zMWTpP)J?bmAX$OVAcX>b9V%+xC2O*6BBB0}(jQ|Tx=@y^>IUUlEt>cGX z){Kq@oc`n^p3|02>HzJ?a+%{2YOxU3w`OYJjD-<(1GsfiZbBQ!+T&zM!@RE zCy9>^8c}lEiY5!BI8(W&R-I=k=hf!63b?Szx~{~1>n(NMKMnwHJ>-N=2B0qFs@}>rtnLMn+}!@@A#U$)Fz@T0LeeBtqytHRgR+{+BnwsUkZ|i~cT!Fw-iM|3qRe&DjgFfqR0P5ck z0E&*@b>IN;7!WuNzHZ}xiM^?Wb)de8RzWAs>c?UgL@(M|FpZt~&2 z=y>1&HJrlEE%R}m-7_ELH6P+Wdm zt3Ld3!2Ej<`KDj=yI&-^@BQBo{^2kF<4^wOZ~nT^{U0CX4~+P~U;SwS; zkl9vV*7zUe3!UoC4+H=qpiqDek$^-ofn*dd$)Q7FAs~wdsc>0n4zrH!ID6d+s$=Q1 zU{Jr|arvBHx8L!3{hr_V{{aI92MG%e4-pd;7a1EJA0Z%_g15T9ga(Q{h%?^?eFpP_4oPv^+}fr874)XIaSWK zh$!+P37Y{@2s{`k4O|=p*S3%n6K>%a2OWnrGblx&F^al|l=FgwK$tBO3v3jkEtyFG zB^gTjXaq^YoJn^6Y{`>g5eEK{CRMtWX;Y_9p+=QDm1Q;pa}RO+z_(v1q#g{gP^AlEAu6L+830qR~oYG%H7O?bX*` zfelvJVTmo)Sgdr#iq~VAZPwXmp^aABX{oI>*{hbVR@-g4?bh3G!3|g3Nv;gFD%)|< zO;_D@*=^TL-FL$*x2kjDt=Har@y%D?efiZyUaIKrSKxsOF4*9M5l&d)g&A(x;fEoP zSmKE(uGr#>G0s@yjXCbv%yPF4^RhQBGMm<&{}(+2xmEj#=iJX|CDkn{m!r z=bd@(+2@~u4qE7;i7wjcqmfQp>7|)&+UcjEj#}!esjk{E>Z`HNTI;R3?%M0G!46yO zvB@sm?6c8MTkW;kZrkm*;f`DGx#_Ok?z{2MTkpO3D(>6wzX1#@%wUi{Wo!PG|A$q|lq zw4;US$hrgGk&k_pPyx$0KtB$Wkgmz&>JC{*MJ{qGiAo>7F_}qv zT(WeW)Z`~YnGR2mZjYffe%m92DTBUl-_SiX{$wG_lGKZi?OckYsxchF_$ ze%VW54s#2^WDNmCumKGWCjs;Mm@$bNO<_VaH5f3`cbv(wQc9DX-Hd}OGem%8W&r{W z*pG~gkWD?1!V>JP6*avXPgBB^G$ZnWGp&?`x9q{4l<21#1u0K~%9Ea1(GjAipce-H z;h&OdXsil4_)v*1GNM+wKo&-$h7R&6hHxFmEgz$N;6PQ%$;(jh9{%CCHSKsbrxglaiFEZyc&rzJXAPWFe@AVW~?i0@K3;H5Mg7 
z5u&=#)MG^Tnn7*Rlma!OZ5EZRtpsO<6yOqExTjM&O;SZf5FREpM<%3Dj5+06gbS>6 zpOArTUYpa*A$+3(A{dZSgc{JYE_R-Zt=fo0x-2tkt`i>($>C)$g-jhEdrcLfN$(|Ngn!aZE@-xkvNA!t=3)b zarI~1ro46zDru`$DFWP#hz1(g@$6frYZk-;r2>UTtz(fJUP3CDDa@6_T4Y;Uu8d)( z3E3_ee2d-n>f}4n-4c1lo8R~0S1IO=LrF65T=mw*wrDDVeF0F8$);(m+12i07%%{J z;*o&*wJ;ygN+HP3!M|OYt*$_d+4$aOA|4)aXyNjtg($VS;JxsR&&c1TsAR%iaYJtp zd}0H$m%;d1aE^aFUo|+u0Di(thGYz5C4(=?NXY_=g_9-|2bsV;b}NY=GiCW!Sq?tU z0&&Ia|KcT&IfhO4hhq?s(pBJCsUW`WK%^iNvB0-1R7P1c z1p_MuvzrSnVj!=9x;Q?xl0bT`e?92WeYW)TD!q?`qDrBR5lLy2;27Jsmn#ow5Lwdl z=`D!5w)@n}rrh-1G+&z5`pa;IMs(9EWdXOdU@Cx~Bn8I^I<$*1?UAqZ>y(UBNyCo$9Zfx`yB^^$J&=>*BjLxT3DEt9PC3 zofPx8vM%-^?HU$@V%ZSPp{w*_$o)Xsi9~aSyiJ(?0bvx%a)@dvEro|DO1=3_j_0Uwq|DE%{_S{PLl1xz|7P z_|dmM;+j9U>|dYzL)(7XzB1o^@%JqJUn~Fk(SNh^cWwRYcYn; zhA#s>(D|xQbSm%zNw8l=@M%tWumn}`UQ)1WT5ttnuw7noX=X46ZE#$U>V;Zp_0&KI zbx;meCHy--@L zFaRcC49Tzz&F~D-a11M{hSIPN-H;3~i4EUy4$ZI&Ps0wGzc3G(B@F4X58*KN5T>je+>a1n=v4-Ih+*RT*D@eK)43@5P-6_GSDkr6quR~+#Y)36AKAP_u%jaTs6a6Mb

)8@r?$b@3XT(HUH9)?F}+gBUuq03sN0R5@9;> zBVjT%1acxT5(!T75_xeXOYtS0f+vMz@+Y%mCU=q|>G2?S5+!R=V0;oNnUX4m5-H>0 zCMQuF4KgWBFeIBYE7=h%d+;csQY2SW9j#Ibq|y+Vaw!;6E7ek8ATkf^u_~J|EC+Hd z0iY}e@hp`BFV}J}`Qa&XG9@w6EB~@3yHYNPk}wTMFZpsY@4+vpk|^&{9tHCe71JLo zc#|iRTDH* zQwxidF7+iXXY)6!r7iK`Ev+&2h*KF zJL~g5pXEL&F+Y*>Gu<;cmlH7obU}kfIPc&%=Q1uUk}CDnKN<8wg+)N!P&#pwH5Jr4 zi}FJU<~%X9MDsvHF_AzcR6>9b zC$u%gGe>#!NQ-eklOPMhatoW(ilR_Nk#tJ6#6j&~39l$Fozx_u&`7EDOT8ovuSiPC zGK2py{#!);Ol{+DoWF|FIo3m0y22(9{Hmz_6S%~zib5liiDh8DfuoOzw zMpH*MGsQGZQ#DanHCESQM$u4JiN#h(WpyuU)f8bhX>c`H)pA!?@l6S^SAq3dGeKB? z(O8e>Pm%Rm;h|LL07&tZSgqz)q4g(~^$*X~{FpUc!4;Nzm0AbXS{3kHM#D8-5n)@+ z5me1pUEy^S*wrH4HD2x29p)8VfyG(xwO=DaTE%ln-xUJU)n5(Pm&Ub2sdHeNuV4@M zVYP5w74|zFR%rA&6=GpAUqe-FE*4`uQDYg_0xQ;INfs4DRwV@%RJ*lgSvC*=mPK3C z_C~g4F&1R&m1c4F5N!5UPdGMbeb!_p_CP~(W!Z0Lf3{zDHf=igXm2oSWtM1}HfmX+ zS0U6!c~)w%77v{EY?d}_OK@v>uV*y9mSi;+YUR&p$2MERwrsjKZ8LCfLC|d7_Gn)g z4ytx2#nuJUc5aUqZnbZ3`8HSmmSX`hH*jOMY;V?Z`F3!@uW%JtR9h8Qm3Ctvmuc(v z4NkR+61Q^4HgYxhYB@J;+tvXcEjM%zb#y5PbxrsER99kJcXiS4bpw)eVYg*37jtB{ zcDXNi9fo&xH~M^+VS+06ca?8=y_9%=)_7eMd5>0kU6OX0_hE1Ma-27MWv_Y@vvjSO zO0lRC~eqNRP@>Unnlo_k_fAeK)9meUN>NMSRP*@;ujx&Qy!K^$MS~ zisUw=ptpYGb$au$OS#m557<@v_j7NRd9Sm!Pp@8*Nu6#jk%a&m)JhXSdJytjHTFx2bqwq z*pLrdjDZ-C*K~`|7>#kXjw2a{`S!R~c{N_>XN}nNtNhmeKT)5gC-Lc#<;|mvh-PVL6f2Sd`1P zmVJ1b7rB^ul#UO0Nbgvgi`asWTbbFDC3&8e7^8vpn+L0-+ZCNBnxX@imxUml5!8f9 zT2={~l?{3#tJ$IfdZA|;C~Ml5>3N@jS*R~)?44Cq96;Ef7k3s&f?FWLH6*x2(LfZJ zkPzHLLU4C?cbCOwmo*l5C%EeZ3oPzZ+W!4}+MD*er*kuRb1@e)XTEv9=Y19EUbhQ3V*`g(A1ok-AExv2!^_iPrgIgxnP3DNP;N~LtaayJ2)10q(`aOjD z<^HhJ_=^hrw>5NE9eb0Z^Ds^nhqQs%=rBj8zb zSp9K5WT_!tS@y+M7>;|x&-mJ-K!}qJ!XuqvSMU-ND%!Z)(HNRu6T`$9ECV(E1%-?H zWWwL*U;WB)hpBACepW+k+^HMV4VwhvZ~{KWxDccRo;%?FtNRK8w?ry^hE2oeL9@T+ zWST-tn>q+kyYS>=rhC}T9c5ekdDS^UoUU0dA&5G~8amr-yI*U4^|ddpMYOKrYf9lu z?beIuWslPx#+cd;^xABtZKUL6+p?jpmd{(6z->y`EOJx&Up+cnYC4b`t6v^nyyM*>fe4(R2n&DNp6uDK_%z$txgLsP&?DxKQf4lD z*~<0H*1GFYxdv6p>pt(_Jy5PLp&oGZfu0jvIRl^2^y-8w$Q)sTo3y{hmM5 zxzY2ynfxU@H8egh5_$t*$p|Htorb4%KQL&|G}M1P8zwYwY2h)$@T-X_c%&h-A?U_F z_GT6{Gz;pO6YR!z^AnF!VY{%CsRDS+I*| zB1?P(tOvxmIyARSljh5JY0E$6mZWZG9<$<|7$ToGt)L}WFFfX6uwF7774ay@|B%Y+ zQSzdH-IS<#`$Mg1lGke;Z-37uKr}jR0 zggSU|cfi1Q$f9tlSrU|Fv>zS~Mp_(6p^oJ4j$SGp+cO<;e>i@Bcl=S|TH;^MZpw&%~EyScf|&(A9Mml@AKnii zg+?e0jzBh{np;}i+B-VCx_f&2`UeJwhDS!n{*6yePEF6u&do0@E-kODuB~sNH@CJi zJG*<>{e#1!EgnQz7K9?Ev)E4$7i~Q)#P^l{#NSFM*J=arLJd~~IPs*ZNUow)f z9M5OiTVFa>%>5x>HB+^rY#gLhXEop30Gh1&))vO11}>kjvluHg>;qTKLL63mpJu8- z?#;d~N89s#kg7$L4;~q-dSmr+TgU@`qyEO4)vg#$wJZ*GXzhAmikS66KeP@#l=C{A z^*yY9Ypmom$Y=o8fSIiR!P=Mg9uD4}ZTyW{7=S~ti_QLIY#NBhgVpYMe&azF1oQ|! 
zoUfLxfrOo4rs}L02a)hI>{44en`RS&<>F{_3}igiguFW2U+v4*L^a)9ogZNqhft`y zJ7xeO&yo*5t;N0Tm#bvSk6e6n$)D;S&vL*+Een?AKzg&1B0pjGHUrBAq;|5zi+qRwE@hN>-y}Pd8Vi6$yEn)?(h!TCT+^bCs^esf%x| z#eaCmyPohx%W^%DLC36gJ;~tj)_U@{Al{7>(BNV3{4Eh#G7R>G;@-Ud~7@?pwB6M7rvjT>RsWU$j2{;DQIHGZ;I+N-jP z?pNkJwK#x2o9aP%jCK9!M%qq8Ka{i?Jogu~3t6(m@HLE6!e2qhXu>NSrhLKf8Q|qk z{;I|b>Dji%ZE<&dZeI)$Mze$B=!w&}_rHERD$Q@?Cp9(MX_0a8# z&2w96EVIyvtjK)*FxS61?1=prR8sxuqV!GO(DT=}{u5Hlh9^^LhLKK?Z@}FbB;Psv zordkz5aQE++|5p>k2_gT=AIQ)fESpKSYFRM(NRv%yIkEgECo1INY45tGeegX)Pl~Z z!zN>(8xfkfwdlm)!Krl}dsatuv~z^pNS-moi(yMOec^1o9Bv55G%68Z4b}oRuEt8k zhtEmc&~_e1@TxiZT1zxJ>^?Wi`}#N^5+*T(vZp}y574o>9Q_MMX&*9}vw<(r&UYTi zD-dD&3j*Ogzw4uXVTtoSBeXExwZJT@@$wwR2Cy$i;jRRFALT8cGNbVbGzNUA>dA0x zBow#gV%*8D_31QW3x1fFGM;0KG zuzd4p>J_8GC6%CfKO@;!^eb{B5l;nm4qZQs>c?4?Au}?0go(->RU*{38W}7PV|hAQ zv<|}cp&SqgomVuDBN0aU`HU4Dorr^&5c_~1ba#+W=}(1`0Cvu~mGd|9mqI~*b;#E= z_e+#ketr!1GRv{N!^;7Kcf*()TySr5xe;UG63ji;mJu0UoaXHMST+{2|kF8+_ zMQ;rGkaWFTS*D3KuL^m8O7(fP(8N^s|D?As?bRQpiHFgW7l{O`^nYhOO6gWFdY0VN zZzg?|I{YZ2=y|cqfR)}++Ejm$IHYI5&h04OVCf-|u=|FZeo|dRHF>f0uIQjq3O zf0!u-lHYJ|4@X8(=iUd!zSUE@jEH~@!Ov^Ki5lM1BYS=dP<(tZ-nW&B_3^YDc9d8dxEdBBJhpEI>SPv<%CcFXk=HODd!twKIea3$I z5ns6OhRQkCX*oj@A}aq%h|VwHj>8i?cmAq&$%O80|1G*NX}I+C87ZYepc5T}8m_Yo zm#RZtAAe?a%p~O<`674t3^ATLitp~Adn>E`$XmIkY4Kj^1bQQ#Mq znR0JUfbKIxLE<@y12#79@R@%?VDMf4?uFUpRM$I5SQpaP5;!sLB_L6S8rYNTV_FzO zLK@J9SV4@;oD_je^Z4$nm)h0EBL3owrlVcE6-LAmaj`S(sr-R6`f~ZyuCa~4%P!1$ zZsn$4=hU@W^#x_H*Aowb%^&%7?t<*GHQ_Wnq6)&tEqB-Lo* zwVO3`fv)g^c1^xN)@HgHRMxg1_LV;=2)^D@nv@)u8aytf%fJ|S!7}n$cR=oP8{d@R zQ=XMhAPKnz+hkZjx=H!;QM}A*0mr(vLO=fur04Gezx`+~6W?yy4h@E@DObvz-mgu%Am0k1d1=>o%< z2}1IDf(Jt)<_aR%3d7U8!z+j)QQbjp%G&;tUtd1;9}14VeGyF)HreUnM=$< zLQFq>M8SD9r7BdN$};M0QWS-5?ChV|c9*DwsSxJr_pCjU^Nq2Vl5t1G5!jGeOm6?rW1`W?%l;n-=&hMOm&mY?y=+4Br9l=&2*9-%|(&}L$ae( z@*mygzplwHsmX5C&}5J4WUq^4ABGe^sgwZS6sjQCl#tYvFlb7|bV}4kN(?naYMfMR zf^KS(YiderY8o_^Dq}h|>moIWAuUfTtw1-eh}t!+BsHxJnpQrYR(X+D&5#bPl}fMI zO$WQCH>Re;py^bI>GY!9EB4e1EA#+qJ^Pg_! zglpzhYUT_y6F4`Wxp0xW#E`Wjm9?gu1>A7W+Dy&bhGy+dXYE~N0rwfQ52dn?b+b=h zv(HnrftS$i>*?&k zq)*Fbgyk~N4@3umj%c6})pRP);jQ?Zs7{DbScH5Zo{L#8=2N;`?fb=5t!% z2NiVR7M9Ha|>T*j;MdX)qfxmp!j zZK$sxMud5LOVZ>Q?vykTFx5 zby=FjSe8NhtdgUw$gM1MMyv!@mSYtY%}1Q9SISETQd2Fd*DFZvbpv&jA=1h`K9?bT zL1m2PpzX3MQesG0X;NA_E_-o*8mK<`YkO~5!3^J+UU|vyiejtsu5F@zs{-(4`LY{m zSk+@Ozhd?hG!a&rCtbdCSvogU;jv%Q2rKW|EYBcuo?#a zT0BqYpmQa?d#wbh7SLBKs81x!UnfRZm!MkL$z37EBum9qvj?k< zo~(1Q^Hz>h!|vF=^mI$9f}o zlJC~IzxW}haO$gVu$_N{{#B)6z!lhD1o8yX=&KK5`~|T|hq#0{s?OGXSrhs3<3^J~ z1JbF3;Eh&lHNl`pb!%{BI8iJU^np+HS2E9c>CmWr2%R-7^NKi|ANGo=G~e32&>DK- z2eSpiK(oXZGSDAa6*b|}7HaS`I26K3#*1vMM+H*Kmm zQ40`0PzRl>wj^93I>?&uK&||UX79)9D7=1zs6ZQmN6Sun>mn11slR2FtVKA2kdwLY zkqt_`9Ho-phSf*PUbo0!x0zA3zqD@l4r}I}ZWY>T2F^8U@3aR3I$nd2L>VZUTf2Tn zyFhu1d4I=(blXz^%*KE^F}&SjtXVpv(@M6=UKZC5fuN=C`YGEbBMbL-@#yki>e%M* z_^sN(p`!bTyjxYEnjR-wX;Aa0ngMkT;fvJpv8N|RG-Q2*!^}r(>*pdzRitON@PlFo*+?yGLUma=( zZKwx^)r;U6{W~^I`3Qx>{{8_0zcGP<2)&^nOM})pLl*|Vcj~$~9q-ErUHcO8! 
z1))6)?XS${--44)EF;dyG4S<>;Z1}Z(Zr| z&3-DTfctzBcG5!k?=Woa+_*cTbvi4orbEUk$KKKQ^4P@|^jr#WMf( z!>qxtS!2U_@7>Y4`T4PV+9l-lirjRG;ez_oJUpPqSE5Ej9ci<>pafXdi(U98F^ANf zJFn8QM>2sAot4VSTUOL)j- zjNtM=)n!(VX4(i@4xzlWfjOaQO+9wd!nApQkRqoF$OtW6^;b3tO z7q*Xub{^RupclIDDEBWY4^BOIs1^EnDKRy6y+p!?i5eJ^#XYC~2^Q4u)!_bP`+ZJf z%nSRYr0NaU{@YxsC!Ed%h%j`ni zyAJE2Q_P1`Glk>d!1D(}=kmtKi+iWi-p9(oU3JtE{rJ&$g;QqX3z55{Pu?eI>ih3e z=PKC;DjH|O#wX62=dYs9qleBjfETz6n+xAAi=!`6g)ieZFY}`>v#QQ3ybp7kw#5cN zCkY>mDx9|fPlJrF?iv@*qcjh*v(G)Mj@7&=tE0=*$c{TS=S+mJd+ZOnQI`>mSCxwt zW7(Bm(btmiPbZDfXNPVwtFDK#FJY*wT{g;9`?7K43kQX(`l=fY_LiMJgi|+dEuqZX z4JIs5`5a-Q4xs-)(F z<=U@*R#=@1%+y;=*4kpkwZ{?^R%+LXPAq#MO`UP`D*pD2={_SN1p3wG3+-WaVs<;{ zqNZVg4n8%cCqHZMt|tsX#0>143{k!0KfZJVX||_n*V*k}zSxeFdq@{J4tC;=Xwbb!u+i3WB7L)U~-TYe24RiojJROeoX)YXLyk(8wIjOfr9@|~%CbNaVjZn@D*(z1aabdiSj9|mKp1UA;+6pcRnRmMqY+MeL z`78uIYSa;eZI!`=;P8xDA;dwij4<*KLSMM)VjC`uJUhG+ZpIx3MzHigrBW|5V*Us&rl5$>fi&NlviK5bb&Dls-q)$<2vho5lK?xu#nLBr7)X zf-}}}K#>oz!EAmi@H3itM)vv0(^d7l!Qumf;30LqoBZLoYI07^N<$g->NHj54K(k& zmh&A9KFf(uU+Gmmn~)W}IqCWc#Q7WYJy}j-emBe0amHKpqHfOByKy7IiA>&B65*V_ zb2ee!wtTXp8T?*yEjHfTr6(ks3!>#=o-Y+0t#&K5IrB-QcD>ug$j!y}hgbFLui#th zozbW)(d|_Gy+4a~{8>&*HCMM+HFf$1G7cI1kupm!t0-lU2Kx9sI@+<3*MtA~=io=v z*7lnG#}C!KVS5m^A@TF~Nnp9lE%?k$`yzIaaMiwf*YtdQ)*F?(!?e?M9*J(c&eKy; z0IXE`5EjJW5_Y1tW2?LmAAGYQc4;R0w699?a0<5ztw7WY^h1N#iC@zflMxE|>xwOq zR5&y4gjNR%S?|8X@Rsp6@ zxWuNZl}5?p{EnX_Tso?}kE>0lKu&Ru;;8%Nnlq~pD3;)X~0HQ#7#TzvtT9Ixux;yZyc)mD#xf74Klt7*Nic9hhPZP zbcMY?|C}3OlHQR)$T41xt$4`~+Ut2I^ylEW!RozA1fK_hp9eDh_9&Tl(DM>II%KyXgv4l4HC@ra1dQDa}L zf9uPdFtoX_J=U!WztR7}GV9YJz91UtYh&{GH9g%^BS2hiIe;jE9>`);#qvsag{}S! z=R@V&kgA9kK{4eHxhBylYuQ!UaAT_irC6+g#uAY~w7<{dD4P0u{W;w~D<)YArEg9qeA!_`ctpi0H~&Ukdz8qiE`!_CQdL{JCBaTC z$ zmA_|K?R#Ub#C_FKu-BgXI|gu&1)=o*+Ymy;vhaBbHoK!zUt?=%Vie3YdgrhoWp1$J z7@j(xI&lJ@6x?&cXfJFmIKd(jwrsNB7WsJ;@h@472vW|7s6u>PfXx70^`)NY>;ODz za-Tyj{TT|Xb39U^b#_Su0xEhJQo!I^=#VoZQ?QE%&XNz|k}K~0kwhWQx5zkcPLkQFV*6{+Zp^S2JsaOAF@y zfZc)IzVR^I9B17;%9Z`$(qXx`WUJ*ZAPeLJ2)B0}`>MFfjPFfQ%T|TgD2Kxf%G!uB zepx0q=Q8)(>r9MCZh)ovs(}o$do+0Lvv}*df9CeYCt0*;Yu*k*CcnRCes?8kx+VR< z=z@<*GKa+H7onmeF3o23=>TAj!XWNs7N;O13CRq2?t`zWh{N98{nMyBbQ150t{cHT zW?iwXu(^*Vr}u?V+yTSgw%|7x79WBz?H*SH!ZO9Kmx{zf-{KgINY(XSM8Dm=%<&jd z43-1jsWgw5X3P}&U|FVdmzDsy&BTmG2Zf3uQUKgzV;tt(q(Gg$pUiO;3#Tk5_)m2I z@h(!HE&-(#sWh4ahwKmcmgpU`?On5x9gB$N%sPs!)C0p5x%A{M!>ioIIUJqKf!y>C zUBNfAZ6q9c2}|^pOIxMiT;vw1-RQ}PnXylsyAPOgA`Cb*a$Hot^y6|9 zG_T+nWZ>Hwxf~c(F_5l4Edp?Ce0Py+rr*dCEi< zTojKrYV|yEF@|7;?>@Z0{K%`c2Pmc3BDF>I;3O>&dOrhfAGDZrXAADBd_#MB+ z*8p&K@`$(Dc{xm9)-JuQ&Eq6(x!>YFLtRwpV+vveswn2HEe@K6@5|KL!j;Vc@Ejm( zreAImKPQi;(x(}^hEJUF4%b8}D0v2Hpk4`}IDI01S7}nYHN&!!%VftX!HTo^=;h_3 zpAx5!VRBC;-~FUXSrpv-fgPAja7s~$Dj;22If%^wETphx0=oheBnc{-*-+nbp^2iX*&$1mv7T)VExxZ`>F>`rj z7V^jp`#xV&&um;UkKUGqiS>K-Nz^9UOycUpZ+KVP`m@U<#DUW znHtd=E3V`NQ_MZN&~8P5&cb<%SIOdFuF{K4%d8E*GGjm1Nq;N7MZwvCSDUP#N??!I zXA6s%pW&iwjC3^=aU0p=-ukpxG<(vZOFb=uq|okB-;DCs`HC#`Yfb-XC>|^qhH*Wn zUR!uB4)A@&9d7pB9M?y0Wi$M6g?w!@CXDU_ArIL@wglXk15b+$QM3zhGp>!X@8N`R znXxrTiB;#aRZyuvUwYuPSzlU)8{bl!v{JuXw&0{vj%o|nhPkl{)`1E|MXh3|R}2*5 zrP_pLK=myyBW_I^bRt8UyIW!3TeLG@DV2cF?p+|A=Vf7zhVHvdz-y!8aA_+=@vZoG z<_GIW5raO_l%@EueA+)Q;RaYFk4_{{FvbL?#l&*XC$4_@J5$NZ6sXEmy^wxyDXlKr zk{Gq*`1X`aCo5h^5y)3WdRz(&Fv_FEsW&rwGr<$)HqKedWMreCd_MD#DFBBvLA zrKK72RF9^#(g?7-`IS=mK6z)xNNe1{x}Dg=K`rpu%W>|FUHKoO?MN$uf(=jsV@p#d zYeL}E3)*etcb0J**7=Ubg|s;N9Hset0$`#}8|oBL{Oz5!9l=f;Zk|{K@0Y>p3f-S$ zls4@fk8hA$ty#u0&nhyfZ7A_g@!Qe*l;}prVhvS&JTJavKXhsDs?li4A2(~GQTC}O zQ`0>jmAGQ6wFL!2Q}+#BJgtnOzgCDQ?(bYh`29~yo6ue~z6%=woZ|{A{4x}ctu;eM 
z6{C%05q{c4QTyp~v%Kxo{!(8l8<}So0pL=!kzKIks}9l@RksRwR>k{D#l6zP=AM$+ zjb=bmg^>|M&rD?xZu5ep;Gk+|o{$;Uv5*tr_qp3yG=-4dUn^u2=7}^5>^;iusFbaRiulNQ*dZ_Sk_zXp|tt!3X6 zSI}#zD<-~yp;j!jN`1y9;t^QAbNiG}+J`2cL)NjtLuf(x`_Ci6<&N8`E;)x-BP^J* zMnvyGU-cIQh=q@Eb)T@Bq>qslZ}TqpZ&e!L?Xbhmeo=7f8uPK^ad87!zlWM64@=ij z8L{HBGjRBow^JM7D?9Fz=Kp)7h_|wY+ggIz$S>r3#&>jxn|4Sozep#BdQWJ~|T+DHy7LMXfy9wvf*&qmQ*=R*BQNUtvFY4K6H;BT4d zb?*OafKQsCpE719*7<=fx3~KE5p#INnJrF#NSgR&!pb?L#cu7G-!eJ2i0JcOhNhNF^fYQ?3>5bl;F58}dT8_=Q$#gKSQN9_r*vjgxHS3xmmq z&pj_>hZ_vpuuh+rPh>V<7!!fN0mV(|z$R?s-?_oxg~fkJfPW}3ikm8dO*O@TYJ-0o zi<_B)&Fsa^ox$eb;uiG5V2fyR%Ve-+wzyR>*s4n08UnUPiQCY3fo+DwZ70FDi{f@@ zupL(1{tRq?C+hByjK{FZ?HR*?8ZuLSv{DdD6IaWahvDl2Pc!D6C}kS!48_ zWDF5B1}GK#fDRhVCKbmGjT4rNmw?7ANF^vef+lE6C2B(xjir+A0TFwtWM^oy_d}_a zU}#FTRBAFbHCrmJ7@Afkl}-bJrlX`Xx}X_DQkj#`%tfiJ2WV&(Rx0}pntdmgLj=nK zO6Sr&gyphH=W)aGgr)N(VEGEt1xm028cpd!ZCIhPbdfo%$X@!M$0(-pmM#f~l|)OI zCc{dzrOS$8Wi(aNAP5YEk}mIpl@Cc*Ou{N?7NslEuu80S)fuenPP&>1UQGj(siA|{ zu*uYN!)t|Q>LlQGRQJhrO7MD3nFeimgRuJ~L!SKdt8E7&!9GWcy zD~7|W?mx5OaFh&!stb-7l0irj+Icy znhGb{s1sv_Qwnp`slCFPGwRG+;XIfUbsnv7k&L>?R=6xiT~bsjTtQG*D23}T)b$Xh z!p$V=W>Mi5jk?7u+y$R05UZ~!CvAilZTEur$hr>ICr=~MmxE`wKyuRH_Pm1-b-v{O zq9c+>)#86;Ct(9tkh#h;GRp9-wc3k_|C8LfdusEX_`k`GEnP!X|2MgDS{bbOUT#dv z#4n?!sBi5ao(_Wd&0?;|m<43j4QxCjGRhJCb351MEP`_H4Q)LmGb@k-^Sd_`tU~e{ zMs{9NS(X1`H@=q}zc9Air)k=E8$RaMphMNW|9^!-r~OX~y{sJkKPdG2iAPqLipGsHOeNxU8&lOx z$iGwFPHRoOQ`5}_+Ntdq$L!P%zvJJn|EFcW+wgG847B?nAbJr0UgKJl^&a%$W)Wx) zwgblO!S_4)v54adYb^5N`34BvbbX4!q5wn!`_1@tHv26!#N6fkt>h9r`)yQ80tf96 zwQU|8bkLiZA9ON1?;Lco2MZi_KS_RIbJ)XETz=Rq0NFY06X_B->VHl>X>&9ni7r1H zls(%y8d4+@q&XgbLuY$DqRd@!JgP3Sdp!1mT1oKa-xqD$lW`sMijxTg=UtkU$#21e zr&FfMwx`n;#TBPBHq?;a(^-cu!LzwPleTB`F6avCvjvZ{-Lpj>BBAr806M$#WvUSF z%JY>7iM{jH7$u>LwFD|{yNmS{^U8~j4ClQIbWSjp(B)=9vfbrYNpa=nb~$7Zc!{a* z61v){pR~K$ZA4d6U+p2z_O7taM8enm?R56n2f%Lbs_Vmk3GDUJu#)i2G4P+Z{msdg zdDYG7oHO?347e05e0#o@Y=3*PSzL8{3EY8TZ?E>dgzv78C++WUD9_PVcemGP*n0?) zJO@BUq}+_lxbH)(o`Xx*-ApJA-1nu5$ie4UZXwm%_oMFz<`7DBw@|w6`?Hhh5-TZF zx6-8T2k@xpl4^Ii(!%xwMSu~x4VS@z=*tu#mXIgdIw=T{du$yV0VX*+d=p@@_hO(<<4hmlm`(O>iLY5 z-JRmFgGdL;hJ}dC7yG#D^Aliq#fSpLC`xPbQVN8g60S8@A_dDss zSc;g4f+yT6J*s+#aVh=m1zZw6JsNI@@s%a_hQR3>GQc+g-X?(e#b3Y|*dZ=Ry?%;- z(4ki^?^)8r&&|AsTdMb0LoyASA|7*fpJAS3LUTY0E}8~l$0g^Q&0E>HX9{LTs^&bh~JIz>XQv!DzpGg0A z*irK37kWHKp8=a^zX4<9n=R$KLqS$zxuk;hcp5z#M9+VxQ_IWI6LNeTGNT>KV0b}K zz%e}dOa1|%DZ*TtsA$mPReaKg0dLE>%1Ck+P0q7NDI$hpy@s!z7vGLP6*FcW$d<%S zG$yYue*H0YB>xtZPJ!eW>nr_Kd}v*8MsBXTZ28qY@?T+Rgr)XHho+5NeX%49sEM%b zYaNH!mkQ;uPAB|{yD$_FQ%ny3M*ae9jeRwJ;_0mnx#iXG~&eQy>eqUwP z;lrW?vBF&#kK*v`=raiH{iLGxwpKpjnSsjqwn3!(oFmD2!OU~p-(w9Qp{EHAV{hr? z4pcR!Ma?7(1zEU%_+A@RYI(N?a5MDIM#I?p^4^km?F--1_fxeCut_+5M zHC#8+Zsr6}{AHOj#4o1TUzW45;l5t6$kZ;|s;B|aU32AnZuG?&t)lD8H{Ua$WkiSJnXDP)%#|6-U* z*m+oIXCO%^t*D1qrzd&-?ELUF@yf}4_TW%j4~ z;jsRcm-uMh?Rc!GK3W|J>Y?`9%90Z-C*Yri7YHR~&xVb@6)9%aHWBC1Gd2_e8%Yk- zbJ6Nc^?y-EUuv;){VDcF@{xHt12ICgX})aJqJS)D9rX(@rn;c%ri2r&eZ`g8y^fgYjeVc-g>=75UngRt+Tx ze8@9l?1;P<;Oc=1na&=wRL8p<(+}N_~XSzv#il0;QWB3*$@Tk}SVa^QB#fX!smj z;|x+|xXa_=e<{kTRsxL z3xb+1J&?@93@jY&JYpN&bi5p)8Mf*y@(WCKi<^tRUIZHZ4Rm>7^A&t|?Tyc;@BBue zor=RdKVuWdF*zrxld_w-!oMc@*-n$Ya1Jevuh!AuE?d4S>TgB3t&nw`PuhDOf`B*V zB6qlh+zN9;*46_LdkB|%*ig7t#vns9Cp{YN|1y2PPy9AbRhVwisE06QY4B#mW(P=4)eB%J_w0kvY(}~GMX8Z6@AB>$cOeo9w|X!AX`-D(g}|)2(<%l4OI6coT*c zS(T)v4c!E0uCNbxb#j~2H+L3 zA#wtD5Nua`k+~A{XQjtU`XbTKGAlPU>q-hpR_#-A9K-bihv(7H_CpfjacRWqbexEG zHkouZQ)A95i5g|E)4wT7c3$mwO1L=9I#*6P+c7ybZrRkAIYzp9r@S#dX-;%*>DW#y z4hvkWB|zmrb@ofjnB;tOB_ICYJnnm*T{TLM@sHpoWn+(Zpcd{UDicPB{N0@V4~(t? 
zRE1x*W4|&M3OW=@Zv($ff7VjLrTy?3kdvZuNudWT{Onc~lU!t#_E&eN$byKKXVY`UGy7*Ubkr!W~qu!_AA?jUP`E1H*#U6esdp~g=_~0ND--j{sr0RZH zdx%3yrQ0r9`fK#lq_Ubaq1J=uQeVe{-<^hleLw*p5BWaesZ(UidM0R_t8HDIavgAx zA_cO-#d#otb1z7HL5aP7IU9qlOH5W$381p{hXT{SF zld`qSF+kPowjZu_)tOZ_O&V3g1OT`XxL5hou;Wly10I3^q|E?w_VP#jAQr`3cK{Pb z6uKCSPVtrgd!$p#MV5fjcILdQ>7%Z7pQ)|Vt$oIy#iL)h0<9ImxC>#JMaYUJK@k#t zbskdnDcTv*OwKQ7>)NTaUa3XB)yuMRO)V&`9V1CnpJmo?&;H01V>I0$)swTv01irR zXsK!Vt`GLM4EqL3&`QrTHM3^6kjb?LB^s+Wd}2yNyVGMRcH z-J%d&UC3vdM(8%U;{p;w#vDf05I_d~d)1gqmN6m)_0*4xyQ)u{ZA@o^d8IH^Owb$MR9FHAQYTYatyUX#m43;PUJrtSnKEkn;BD5hAK_3GKQkyC0kN)2 z)`z*xB6>hfePqbitJEPi7;+XmHp?^)LXM_K=)jTn(@4|-vV*BfP^xLf9jUjWhHSyW zmoQB}>F_Cj`0-T}#J35ihPbgtoa;86f>1>*s51jb8)7N^~czgejtX+{rIMlh<*!Co75$OwdfxEDZ( zcF@{%(5SZ;_cdGUWwak)+KoM0CDq&3+#7$8Ga0D2ipZk0uiM|ubF#c%JLiN+Bi$g~os(|nv=j3QYUj#Z>*PS0yV4?|mX9{uqLJG1+1qE(Q)2SkeI?s(jvCA-7^aoMd?-#E$q3f zgq7SiM9I`FuR~GQtyGeT^2KPDk}2EDoR~Iu3CcjQ1IiFl2K9knR`uKp4U9hQx;!u; z8q6uYFi}8E_JZKO$W(n-rTS2G`%sgu`750%3NO_NPSW79%3mkSRXH(7*X3IA57*Q zq;?p*t}x6HJNPO9S<+PXLfmq+s&M~tzWLU$DV4Gn*T}HZh-LnWUBJkB$B4_;i1W2k zcdpTgcSpTUN2xxIzJm`tZH-z|joIDn#|h4QT7&jHVAv3Df6cE5qYBxWO8IdJ*Z76} zxVhu;NsSq<-_no^gTQsBV#wl+?Si=2!vnjCbpGd-OWt#;_6DqBn_8-e3= zd#YM2Wlc(Di)}%`TLI*L#|`a%Y|MT9N3GSYPE513x46(63H6@IhxWdmwg_ zxJaOp2⪖CUy?ddkFE%T0fb$@y#VvU<>GO$7B8fS2Yhzt5qa^X&eMQe8NN$&c_s zgSWZSg<*4|&Wn9>not}ma z?!gvJ5Ts_ct^Ck=3)tuP(@Qqeq;`mT#lRM)MBwP($`dy;Qum$~01&*cU2@QC_2XWB zSup?3dFgrl>PNjb@BY=H_O%4;6_M!$1=yMw()t5POQ3S=r-CKt!1a$Q>zn7RY0_(j zdMlaS8>P-`ke-!T0RK?0Ye0KeYt5E2YB#`vn?-@EpY=$~xSKr?%LajKN{CHIye(Ae z&1TqoSHZ@cz%8egEwP@hTr*g?^j0M|Y{7YRNN;PMVzZ-m)3*m^e-oyGScys8oTFYV z)->C8aNCXxgbfsI*D0@o-jX2mY_8YB5Qyy^1PslE1Vd&UPh|^ddlPGBi-2YaR(x%3+;dUsKHc3<%EkO=Usacu7hz1fx+~wJ)rs(*@c^F~Rk?G&A{?;n);^4(xd$|AMqwNpgo0iFx z?Ydoux-^IGl846nhtGJ9(k?8tgN~l>b~fvDHJWtRo0KuWe@kMW1axVR9g~ooCXb!< zPuO>kKL(xn(VTc*o&?jJhRU2`hU=e3x|~KQoyOLks>aWpnn@1E0==$@v9z~?iTcDb zTVqMOXBjF*`7USAL(cLF2VW-?<@8z=y(N~J$XL#Ip|73+pPwH&&74VHKrV@@3o~kx z+!|d<1Tik&n@#oQxLjz3T)e5exVLfPR&_ChHT^~fp5xUjyS*nbb9^i5Kokd|vD;E_ zhT!g6UnxM~u^`r55KEVEVVMJ<4Jby(-JfOx)MDIlW@Eg=b;9)#&GclDl}C+T(iDYwvc zQRx6O`iry$hL`D!jITZ|F*U!=SZ02cx$-82GBq@bc2+H0Zhau9_gcyv*6N=&`cAq5 ztr~|M|3Jz8(*FHf#;2mEz>CvMykA}P1s?aI`9nL~k128nY@dk8G}cNP*#*`rfNdFz^eEY1(H}3 zSw`{!cg@^TF5@3EI+10>kbJ6sbq5u)wVrC8>G9sy#xc)q{Rh6GU37WT}BtZ=ZR)!OXHr^-`Zq0K!&p5E4B`12$EVJrT69u^``kK z&@o8+^86Ic7^V~`ewyxZV?cHyn2-2$E}UX7gCzXBjjqhR6mt4*HP!14iqB#s!!6dY zHaZ}qQ-^4V2ETQ-`4p>JVwh(7c=)@FZXia|mBJG8ir?3Tq!hqhVyUaDUn zT zleB1+3U=IUyO|8NXeX<_;P2uq{mkD@H=fCT(!<($Ax-NF9wC*GU~uIZ$6poy_^kNS`jnxk#U_< z`d*Qj^aUNRJ8Lr?ezEsGNqYRStY#+uVjl`+M|#3E*3WFJfRYZSCoawU%w7m8`uO%7C2eVp)N{R!_04IHcwN-3@tT@mb%0RuJ)he}D9OQh#KzoI)?IwOn zuqT-$BRxTO+g<*WkU&R9Mh3QaIn|QTNGKySi9mL{l0`{a;sqn?ZMKewfR&yh$;7Tk zlHKvRq$Hxik%{vKTc_qgNn{0-iQ6)pv{QEv@X{}sc-`5$p5d29w@5Pchh=xYBIYlR z>2YKhOk?XdR0X_rD6?>BcDE_9MQPmh1+yq1roRgn(yDqCOA)vJ`A5);kRB{XVK?sl+v` z_XG(3#Z4#rMNa?QO(*ylIi29Eo&JZMj{grk9ZvXfb~@f)diqa09si$tI{sHb{dYYb z|93zAtDcVc4?i96yP*D;pKh6pH!%eW>a}n#b3jm^!>2JVUVuZ(gLxAdCNkjIFN2LXMy+y&K0piNRM?bwm{;3-ip*dq}LfCkbDK3YxO2U zRiCG{U@AfW>RoIB;DV*IU^)ZWnw(nI09ItJV1_{cnv!MJU?KurFzYtg`oplQp$uuE z95u}R^~a@EUkaRs@?LOZZ)gry4Oi3(6v6^sE=%LuGE}P`AFdUyz$(q(GPSIpm_`VLo4BxG@4~7lm!(Cj2J>N7n5ETI zFlQ0y0@t?PAXfGCQLRYre*U)8e)S9%0wGd=g?q=101AZmu|yjQ3U)jMpmW$RqRki# z+`B$%(0SrI(H2aBg53a1=mPbnXxnW}?!AyO=pwU>Sch5x)?P#@bcx$Vtm_5$ehlUy zbXmAgtjDroKVcsScvQX=>vQKmNFi|Q{YwNO|6K(5it$&IfFj)$8ZHp>6?1d*2M-?9 z)z#g&af6nYHYg}4DJf}tdm9%QcV=eB#l_|6)2F??y@iE^yu7?u>3OeEv5lH_Arn#| zP#e_`f2+|0yEC-#NM*{j0M-a&%=SMjO#(%Q?mcT$_LgzU_E3Sh1|=`FsF#DU 
zk#gqVyOwD)oyJz$k&Vr&D^Vu9?X1_6`N8xyoR%hxQIz|rLhg>GC}MA*0=;Z?lUuz4 zBneu*M3?@8&A@!IXy7IW*3Rd9Pq5+!uioA&8M>{OfINI@)Lq_gt{$~TK6=doZZ)JT z(p}N_JV>c*?7gL@b?v*({3A4*g*clDhY>DO11AB>ljgT|n@7X;rJTGAt`gEGi-|-9 zyK@XI1$*-^!UEncgnQ{hmW5v^e^|nc<3?VeM_^U1r?w(UH#lY{#8xVsumu+j6->F$ zw?q1|OJR*JeiExmmQo@+T{!BFZOvf?0-L=yTM{z^Ecp)mlh@U2R=8}W4-aGZr>nO* zfKw-Dj!rVb-*I6s_DAELGIghMxxEO44P^W5bgnZ2Zho&Wb1ldPHQetq)vN9T%iL!E zJn4!$f$YsXPJsRsKHZzYE(22Eakqs$y)Ff=K#9b5pV6j9ePTe2XU1#6!}ttMsI8H_ z<+BIUo#_Sowi^KY_cCdi!*&wAO-~%+enauO)<@erjMxvoRDkF(ex?R7vv{5>ieoZS z@ta#oRKdPa{X2;9M(*+$;rVH9&=EGZ-OKkcc9fKSLAD{%h*q&1sA;zVpC`TjIlJG- z7IX2@fo#1Pe}KIK<^;_r*0y`MJc4r-UXCvUDnOz}Ur6+ObVdR2J#KY=FaNG|)}RtO z*+|t;=ks~iCvpVVYwW{Qur_gGi$C+zrVO0LND z+CLoMPpl7S&damwZe0sY`qG?vVOo(_lDuMcZB&Kj%zx{_I_2dA-UH1ub!UDkLsR8 z1eyorYENj5whYV{c|kW+!g4fPwT_5_KfQWRRy~xZpm$W7-L`oT>HNvSJ>s&gR~`mL zY?`IZM1|P_rY!0ba)XJ(6=@OqX7_JuL+Mo0@`y|fZ4`7=4um7}4>@0nX-_U|&4WD^ zVA`GH3c4#=QTTl|@$We}hT8RxtNq%{)+;)7vKxt^t4tqlwT~yVtt+b50=DUfwX3_a zhiknf3Z1cT&dh}z*A6YgoUc!MA;=4B?49y$6wtb!N|2J;Us9R5c23Nq56(3VH<_C4 zI%s1!j5cK5DKKZc&^xgnDW_3>N6w2``&soQ4@14!zq5Y+_#L>30AP`x()u&q#ii^xa4Q&dqOJs0j7vGDWc06^ph{#v$_qC3R(ZDW^ZdF|7 zk0UtkGEgp!9qC#YuKwAhqg)aHqeqt*5Lt)0%;qrG=W!GJed6_i)jORQsR4RfsU!i5 zTydQP!@m4cECFw57*3#UkC-yMH`h>CL|PfC591Z`=VmL$E7c!c#cDde=pG*ba=Qfh zQ|@8HIsu2sh{<%hpOwx$$rbJHM(&DY1~eTF8iy}*fiQ2Ce3R$r#K*SnrW?;G$S2y?@+IzSFO1LNzCA7!Qr>lExR^sC zIcEoW+Rx#!SCs{-k`2HlT3zNo*yd;*rzoA9)e zhOR!V74!2#J49E2ui_%?26MUJ=DitRQxUOgLz;eDK%u*Kj&eTmP|v!!a%xkuR(yOc z?tM(tRHKE>+vW`sr(?nu-DUA@iI7WdqrwS|op&M^b7R==E9;zA@$XBzp}D-PR&cHh zC$X62;dyZ{3U?qVgFl6dsoSpZiTy!34m?Gs?rdNIeh{K8MRK#y{(d#$2;5r(v*|sb zKF0&Xhd8x6UMd&!6qhG$Dz$qIQN0(-5|^iafF_WcWZXM4ee^ONzNxJOZ`5)*pX5n8 zDy5M?z_t(17kOOhT6vrgx`}IJd1HO{e(&Y(f$BrdfW77@k z5x<>qc_=j!2qb757S|togDjlOhZ+S0x;`70M;9)?7P{yL0_}vOx`qR-|C?geXglGc zr(s_UbzxfH z;}rC@PiIGUeJA|MX}CKYh;S#Wz%A+x85MzR1iotomTSbr>}XfR$o_udQ{i~YNE+El zcTdA3Ps#mHDO1)!+@~>E$siR7NGlBz{Ru*eH4DmTLr%&@3OWsUK8=m?iD}0}QjtSa zw2LFJj|KI`K+)oA?I>F`V-ne7>5HOFV&aSN;y&{sVe>{Z+=wTy2LaOE5E;dA4CN$Q zVmV)8%fm!&AV}wCM+pS!${L~r*eJKrl6HL(>mGjX(RHFxe$CO5MIq1^??|jQA*ys zuQ57gdvc6q&@3bYn>=lNER{SbHMuwjM2jSd4q})M2PK0<&nWgZ$?L@uD&&$q?85IF zrPIceg`I{I6s6zDL9$!}NuQ#L^UW$SL)V$Xe{elk8&m?J^$Z zq(#%EgJ=^#ym61vGh0G3Lf=J8qJunQqV4S|9nVOu<)gI4vmCUt0z%S_4KsY@v;1>N zuiS`x<`XLtQbidSOA7Ic2#(EeFV6OQ7JvIJBiWZSRi4y(Jw9`sGMk41;==wiI^ZbTs{P8pd(qG;t&5sJ1lYFyDmSrI0N^5t34_<9lkgc2czcJbu< zV$xhC^6KKzp8>rH~!+feV3x97lDjLiMd|&PHCUpefdeXLUY@j9ofQY|sD?V_GWTFq^-Ljzdl=vY?Yb4e+TGmRjq~zzAWJt<6(3sX zHc=~cHU8>gH5001X2~jAH;Xm_;YE*R~1QiWvF%`J9f3j*VbWnwQzQ^ zck4iVF}uw`le*J<$<7gZ6rM_6%?IAQai)bdr4;7=2S5eJE6H*J}D;1AW*^U8qX6=v4gy z$o<4l9Hi6<{dpz*l=AXP`CJK%Y1fL)56lWuC@#9`npZh)7Il}C4Q zS*N=Xvq!X>rDSq2Zg22L0^2P@Km%|Zl1&&|z1t>lDi||4#J|wQ_aJ~*MDk1St^I%lb)^`MsUn|V=I0gYj95F3*d~nF z*NixFeTfM{GFbzCkQ)7{JBn^T>YXs^8!$@PHyQ{Z4c;0>;~fjv9gCD2yB<6iTQe4) zKZdt6mP|FCiZzb%U_8TVJWF>RuzpbL#|x&$i{Rr0HJ~!xi3+C)jCv5@w?L;REZd_j zsq&npCYyC9Tbw4_;6qVJqmE;fJ@83D0vM2*8q%E_cA6U1ood)ZzFQO04xgH(nx2=M zE`*Yi22ZUfOt05WZ%$3aYNqDEhOBba2f8x=ObvmOb!|;wOwAzRGa%R`+H5R8P!RE% zK~0=Rr=CtIoCIPrn|*;C zoo2m*u*xqz-6JrT0g~D+{ei9NX(kkESR=xJae95(83tf|tj%c{3<2Ar-rkemKET%7 zK62hZN!&iG-M*OKra&YTE+DqqFp_L4o$AbRf0J!brU3Vs!OLH&Jt zi+xcXdhxn_>VbXfUb-7J2ju(*6nBE??ztQghaSkQ&?(Lw;O`wgz@bx>ImA{yq)_jr z)kr!-FFDi+qSd=RM8Q9Lu0s1#{|MRQh`<1c_IlkB$H0+U61_Ri@q7JaYiyTe+oWUr zx?{(gW9Q3bOjnwd4>Bho^-nxqPB6WbPJHW5{DFqt)?|!@aaqVEDZv4 zUIwwKk63m=tR^8afhVxd83gPSu|so-xhHdZpbz{%aCwq+`Sh&r@?z$aHu z58!=Qyvt`%vd3>(x-am)>_E_}{7^{r7e7GM4s!ZStnNb6sRsU=AArwkutVGU5uJqJ zWR-u1uGt%(<@Z=!TQL3oP)@@=#*T3ICnY9pdrZFw0^*!Ddz)+fEPcP!>Y{!a0)QZp 
zZ*<5$oTpW3v9*54F6_Rvsbs+ z_Lm|MD`E-TJQY~&OQpFV#%+2gxc<`-0D^UI1`ABjg<+FbHv6zGnxB$@*@fudVn@`E zV%7Xp?HZ(sCOHmpIv(Zw=#li;z6~E(v@q)~!$u7QVfdMqHd7@Bgp_L}}fw zk=xRK0Frf-Uz>F&v>$(H0=9IXWJX{*FDm;JI&V4(Ur~CWAD#f1-jA#5gx;U8LzE#v zXxx^9A@I8>aCX8F{Igoe7%ImOVGJWz5qvoc9eWjPT7dW&YkG*hINPTP zt#9EvlK)VyBl<_W?rXV@=wH)yME}~ZBm5Sxi}Z_c;w)1MB{++HepE+T<|vIDSXY>eN_g?jG=%sKma%cZFGK;{OwF}j{F^L zxlsO2e6DZOK4{RtIqYa0hU{%L9M8 z12})k1K8i~fuHgK?w|Gm&JTUyr#*oCr#^u5-5>Z%AHe<7AHex05d7^A;QS#F-~a}} zuL1$~?*;)*)9(fW_OA*-aUj-j3IWdF4#78t0RMM~0H5zKtVs7aR`dWs;P-VffD8eA z2rxwe76Jqj0EYlR1ZX0_5&uGr07?WXBLE!%nh0=2q-!Y1xWoWWL?))aLdI}~f*C-J z-`OP6b#x>&EF=tEWNZT9OZ>0%Se#y7^zOW&QVKS;GIkmi0e5mc{eS zwXFZiu`CwgT2{jU;8+$@;(tGu)tN~D|8^|v`>ibM|8Okp$E_^te?OM>^Hvt^za7i^ z^;Q<;KaOSnax07Kk7HTC-^wEY?O4`tx3VaHJ(l&ytt_&i$Flx$D~sgEv8+FDWs&|k zmi3R7G6`@jt8Dn6D`jF}ru;7}Wg=jv{Lhs#F%0rQUMUm)bEUldneZPgWniR?|L011 z#Dw6FmGak-GQsaFWpykA$!-$u&#zpa$Njg;|zT`7MXDdYdTQvNKZslz$y5;{hw>|6e0z zoZm;v$2h-_l*zGw8z~24|2k4mXU6$4QclMF#Cl(@LlrxAFFZ@}6Z@0)4)vV$IoXW9s7>&7&((UxSJVD6eeYj)Cyyx_ zcKFBkmo0Y|$uTRLR{6(`9b|OWk14&&@Q+T4IWf-?qatk@;Yg zRB}PH+i6)XX$g}>`ZA{28eK5qDwlzfN3ilo=AjqZPOB0mWbFM`fOtkg1iGh0+#hi3 zAe9=fasPHzaR^!HqZD!}wl>8aRlnl0PlOI^icHLdF}LzEScI)el`XZ3JXPOkN-jm+ zBoXgRJ~~KdTVjz7kS}t2J(#2@1%wLGsJ>*$A10FL=R6MT=t(ytdZr6sZ){{&_ed$t zd%;A;0z^^2l#~0%nZ4RjUFdjR#g7^Dxoo3x4?Ab1_Uy3G6)W%Qb%oJJO!A{5Z>7BF zB%Pxz;#fz;!2x+MSro=Pw2n%!qPz26i*%0l*dLW9V`1eR+*cSMh&?LHQpz{Ne9}2S zTy<1l6p(N722){TeEg^a+?{V`(}@q0hAnawM!6;8M3^9>)f|X_QYQ z&<&I$yx~8vb+tj15?>NvnUcFv7)ObpT?D>c; zj~0z#uATx`=)P3_6q}{3iKvJ7_U*i<7-6Ih%=d7813=G7{J_nimUsSWdGoA8w#_1* z?*cubFOa*z4Lr_+@SwUa2KM|{ctRid?l*0Rp$tC(TSA|6+-0~&- zh_vaUaA`msbcPQjtBYJ(Ai55^%8pdP)+l+&?RBXa#B2hAh7?c#&YFcdR8{(Q}aaSZP5FXJSaXwLqoB9c$~& zlZo-20O4F0tfxkGRfA)QgUa#%6SL!yi8;^mqT(q%3x$bEu6s7sV~$(TH^!$nVk*F` zrrQp!F=HpMgKF2hw_h*BPAw^6*+L?=7PPNGdt0W<>xj$}-Eym^@sq%+49dIv8kCd! zwYCjdJ-dNXD+{P9q78#wd%K(65I*r*~o=@Iao3^QLjqj<4HG1F#ozyHq zcbwlq8wh$mJ_B+bCW+)?fcqYBPJFN0qz}qL_FI>T-IYS;v%)$AXPL}}z}tsT7eV_@e?T2xAUgAe4SPD}S~6xqBfIHKKZS3bnpnDez!2#mTBw7C;`>x#8Gf;3)m zuG$G|yze0SSW@gRtl^?T>x7l{vAW^o=CKQdtR2gVo2sbCZEx56vX0Wx7OYBO5B^aP z5nB&sQio@4mb|i_MK?W;-gw+?^ZdBtsbJu#q2a-F;_-IG(|VNL{N$rWo0o0WN0%&e zH`{B+BOfQOfAq4Y@L_%B4)&Baa8?L9@wAQh3YEbn zL)qqyKI-G7;g#j>Evn%YdE#4yYU@ermqur6ALWtB>RFfNm)GFuGU8XOp^MuX?cV|R z?;7>*Iq_!`1zo)h3eZ;l3d{y&1McvDxMB`q!2lFj&^RljXas2HB;a~7Xn~Y*;O?T`j{-u5pw@Jo{QqLSkP!fIa$rkoPGT_9#m2t|8H#eF~1EW3RCP$sc3LCs*C zouIL-(1%aF&ENRT$c8}Z!-T}b?}+KM_yn%ENs>dt+->}n#zM>SL3aUW*{&;s7KCFm zE2%(c%h4X@+$OZkboD*}Eq5ff#okGy`3HIn_b^?3Vi%~H9ii*vCuJ8|PAYuJdKCo} zgaVQR_~m=FXt%X+UA0GVPXp!+A}xV~UG1TRDP7o{ zTdPk_cmOR3?Sa%>Osw8%=meYa6dLHdYg~V6VvShb(pZds8fcI#X!u3a3-2W0)>fvW z@U9(bCOffPENVe67Rg@t0xf9oBnhJ<94jYD*D^>}%T5a98{fK?LUR^pNRYB3nnGxw z(4(1>b}WQf6nrWxNmUgo!{-$%$k0O-*Z0<&Kr&lTSmoPr8eq`fNWC zKS%1eQ5v;(`l49!Q?;0xRS?f3ElDME-2m1!ft*i4DjlDM);}5C{`AuB zlSWRefmp_ZhcK~SAZdHLnO2m|S%g+v26k|`^?DjHdS)9_rs$Z|2l1>|^6B@DGM=Yp zL@BpRdBy7Ztfz;`XS@=~oQV#59-CQAmidMn6eE~JtSu&KMObzmo8?{ePf@Kfv%_R4uYcvukIHfJkE~v%|Z4PIu|cQJI}dB z{z-K``2o66u2$wUgYTmozVB>3Wse14#DuOa6vqntGS#o{z!6jT8zS^-l>h7f*rKO9lKC%==SB-^kQAF_? 
zmr0413uTpIQe-|6h6<(Ln&t{7CAmlj|Ad7kh79 zu0Shv154?#>`Mm$W)A3+h`XxEGt;D^XQDQT;)?O5wby%b{qE zO{UMGY+$1_I5+c_WpeY=IO(O%>Int#%tVP|X`Xm&O%;acGDqF{MvcB-R%3PWD22?n zLv3Maogqd=qC?TKUtJW3+hlb$s$>y*m*BZk$zpEJ6i2lxMML~~a(QR{ky`z3b;F=U zol&kJmv88)E=D#vX9EJGj!&@x1Qk3!D>%H}$m|GhF)pUn5#;a>uU5P%Q%peyJuWIZCTSZYkSkxtpjbSA+KNdmXcbo| zquLgD!H|Mp#Tc^5kjtJQ;m=UvR~_Te&_bS{kjKzzpP%B$kaQ6ctZ zk&nqRZ6BCU*~wp0RiVgGdJ(V?*ICI~P|L{xJ8NcXz+_k#um9Xd--HRmitl2lFLc`M zVzFqNX)7McqsOD@9#N!6@hha9fYOg$bhE3r&im6VYxl6+>RHvH2fB=m^ktqDn?1Ax zJ$sb&CfdCW^c5$P^gyYRcAyP$LFY_G(MOXCj+E=8`Ot~lO&164qe|_=PN2i0>L=%i z(h(-mbt|-!1kjCC_mhP7gZetDmFPAl28f>5G3(M@_wAx%+oENi8o;+`7{?q$-5U^u z)86Dv(sUf$z1iK>H7H7zuELUTU%1dL1DAM;iQ-3DGG95EU)lmd%<7e)2T?CNUkr8U zNw9G`JgPBK$3j)85qUuMh4xz8+{hOS60~OtBI}zH$Q8q@e#7QcsFtQ8ju)JgZ*@na zSw`+~UALDK3HBdh@KCvE^EBKFUAG>6HvHzjritYWz%BKLBS(vls_dvafkGc&}396yp5ahFlVWx>Tr}*}%$;N=;T>kdTi44o7M@>!-A4f!u!PTZrJ@m?p z^ti0>TStR+zc5x91KiT~{`B%x`9#2jNv@fx?&(!5+oJ@|D3ht^fJZp$Gk1@tk@X&J z+0NW2ncNDP^-Y*DmwK>&XIA?0FbI1F?{F4dc}C&Y92sl|J&=>&5-SZm;whEK93jf% zE~#lM=V=V-SqADkc>a`-)!dyg^8(IOjEQA})YC+;`MiMn&HTqs-RzlNIW47MxLC)z zVS1kf#yQMp70hOlrxzhd=@0Z;SOSa1%@#DJow1CEI~zOhVoxfOEph5KxtrQ)PcKLp zKveXWB8c145?+hIs&$nYULxEJPdIO5ch)m}0gi16Vb{xIPutYay6deHz)U+59(e_> zPUqJ@ksjf4nEZ^q{7lav7<<(}P|baMC=jt~fxQkK=NeL9Wv{Icsa4AkTsyuy{&}yX zZE*Gm_j+2vdUe2ZQsVOc#D!C;@^@yyxM$KDwyaRGk=J9Zca55}x-0Lponb@EY@Hl2 z59wJnVsyR6hD&NrkMY$*({E6C~K^@n4!2Hq|MoSd+%?=`C-=COfd+v!G5%Q|`;XEj}`9uHzb0 z#o}zu3tKEur|W5ck~B^)yd{sbg0DC0uQ`#Jb@+y8H?9VjzjbuA(TVe}PGNWjLSGE= z$my->^PmlEp7pq`h}%h&|kF3(88(WFHll@yIg2HFE6ejj}P$BUeSeLdEPhc4Zr#UMmf1nLxX@{ z&cWyFXy)`07ti`vw+*g$wr?cSP+7qd=STxrc`lD14jd)X2q#<~^VHwK#mDx!yMPz- zh&x7Ug@8OlIs5SjQQQ|Zi0(ECC2>&ut9`Nr^};dJV+!i1-ZiUps-%(aQ+Uoi=mu_b zeeVt~-@|ty#r?Z@{HhkAB?Ei-0&13FrGvcBaA$8clnxybnhQO)iYWhbNF=Oo9a%A4 zLC9OS6InTOOd_gb8x0;kAr;ev2+<$hcq*o4A6qqkMlPZ45Lc~jOKWaE5(k~Upp?>e zN~oEFQx)6okkw8jsAcqBlImtI^>IEt_y{5vw8edP%L9W<`2A}84c6;@)an{#wsoE|$CjQ?6OM!;~m%x<5`-3{0L=-`yV;6ci5A8sSI zHl*xSg&1M0GeV)^}|PU6Bf3D7HrmIS`v_$v?WT`S_gJOb24Q zHvDRy_R*DUhts3Oa}sl$gPaNi84O85uHmABmRMH)ve$rY^jR}tPgI5 z;5OqcN#cDDz(u_ujHTv?*FySg@_zG8*?7EGN~_8K*anb!NGq)eWvHz7;Vco)0jQCv-sBOb~ z#J8IHl+t07VZG8k{+7XFa1qMzzFUV^6gY1a>4rbt1`##a++#`zADAV)Q8u6`G!-s% z3Gxmfd~lb_im!E#p6JV?$5WkZ51xoQ5e;jU^AquFn${4Fe4=gU9?^4x3o#>UT&3lI z8X!eH=5I}Pwf<#-6S0=-M@o_QPE5ZTV4@ zvDnSBFT;G(E5ptdFp_R&PFU3zqKIp&`f%FemP-FJ#ZLWh>06NoEE%0zbv*s}-EGwC zPoXVLNjl<}<7t%pT^T?E}?2~hOZ?Y4BVFBWgL9059c>U+KD+FeKoT&NBpLa z>bQ2tN$Pmo$;J3C7KaN>;z_Z!SIx;{1kU>MQi4k&t6zv$;@P^A(e&A7xr*W9!-`Al zi$LmYuAHko-7-lju*~~O@S_bf{RhWONi<0zRLL}#7ZYy!mk2lui{T}~EGU~#+Kd+u z*9{G`<1I0})Q3{>Gr@Ox@cW;mozIiL)sV0Y!3*HWWO+e8)^;l*p*Y~()f9L$ z>wO9Jod8*2(TGP{D!zvoe5F~B7EP=^k)8-Y1YxN~Pv%nx62=eR)PBjxW>}m3swgz( zB5_!7Ol~89r7tYy%QPeWjMr7Fk6gu zcG>ksHH1$t(6cDM;P0`xCArMM)glJWO3foJ=>@YP!b5f}d9h=U;x>jemELl9`-`40 zixyf-S|sipHBfvM?H^N4iy2JV6&$A(TGH%Ikn2r6?IM#p98>YV2D-Y-@Zs!2v(M++ zQ(SDq#*>aW54`X(uRbPYjilD<|8Vzx!tL2RNq5E8dHf4ZP>(iNWls+u+L?x@vpJ1i z%Od)4KZh`5`FXbCtyzXwErgk0V%ff2HJ9_?cLfc0tvv`TS9`AX0I6EMp5va+_>d0W z5L1*tCX%NWC{vF};~rI8CN3}kO;*C(XAwk?D&_O8Vsboq{&+*9vm>#D+BLyAVk2vO zd>94Qs=#jQk?!^9AUD2sqN0>}W9=$3jV5BFA{(jA=VG6qAQMIwhfIZC?-&QQ$rG1a zrRE9D$7=Rscci0%={DcUIWYU5m6v4;J-S23tM#R}LV*MfGBMQFo|q=CgnS{|GP9e| zoq z6&lh8H1(YiWRp*?%ixnX;zjO!WYc+up--V4U}(Rs;0L)I>*z84xZ z`I1G3tONI9ZC!S8ek+~%6bHgP*NJi5&xJ^Ue@1uT@N;ZXSa9@UeASCi2 zoeg*@)+ZmtjTt`35SuboaUmZ{AwJBK!_hTTOCtZ0t$vuJIc4;u4v&1eIPx&hP}f*{ zhI|Avd{}^EH)X7UNj};@d{pG4Yw`kzhGMK;{iq~n%H)j<#dtqXfc!Z0|4BCoiIT5fyM8y@coLjY}F)!w?i(sMs%aQ`tm93|K827Oma$UTKcDq z%&hF3+`RmP!lL4m(z5c3%Kt{ye?WD zOSI*E4OFr-7#pCq2S$xwlEI;9w>xrf2A+& 
zKVo&V)xli--hadDR6t!Bg>}~TDzU-8V0ChXU%C2kSe;_tm$Tv*uKp`lC%b{d_B&Vq ziPcFSsT(du{KnP4VRh2w7=w)L<=8*Cy6rMa_DZ7T(b_+_I`O}Ab)xlu|goszqva8-}&#K^>2b-^WQ(XI>9gb@1ON=ykGL)-|OG_ zzvaKb)xYt6&3}KZf8+m_|NakP9q%7t9rw5VHvragf6sq=v*7-o|9)bI^LzfA8}E<& z_t^%{Z~5}rvnxZmvc22Pem`y&6Z^+H_{TUH^VfAS zRlR5zpQrdi3MOtq)*Tl4eqXJF zR8o@eta~CI{ekueX|%5lK=XEG! z5$-kl5pTkg|6P_<&u^~4{{se-c4lfWZ+0fm!XQ&AcFBIL|kVTq! z(ttmF?*mr-Gvjcw(a!@}#eorcFBve>+Es_|0!K&v_;WRh-jP+5vYv4oC@}&@N3+T@ zi13WmFjBvi4BsxmapLlN&Y_LfU-2ETV|M-m*YPLk$&bNpN-&zYxjM*vl@MfmGVb?C zj8ki8nN`C}I6%}TY0H8{W*T1L)JVpmcNSa;-4o-npG%uQtjer`y~1%^$W1&Y8z7N7 zswuSF%hli8sjQnn<#BCbm_>5PtS3cT#s-=$3D*QjC9s`~tmrwPtv>)a4nW9wJyX)= zup}~@sQ(}K-YTjMe&5!{-8B%j6xSLKrKN&H4TnOZcyV`k*Whl!H3<^5xVyW%JM`r3 zyKL{Z);{~{8{>Rq-Q_kH8TpU-d**x$2m&RGQ;1^#LM-AHRv)Q~zpa!*f3`Y33H4%H zrMk`Q;NQRxw<%7hB{;q6lpv^#43JxUM5Q$E^QIhB zr@rJ)R(jSvno}E6s3^Acx!zzLSAwmkXdPu&72-SrF04V(4>N^qIZr|kYK-0Gk+g9& zxZ8Q`%Okjx^nt9&(+1MDmB{$}BT~21;}jv)`6XLs0`HTeted>=)Z|6AS~yeW}92@yG`sMtl_S0RVI_l^V#BN z{KM;Wuqm6}uGMzkM2kXnsmnZA|I%()mhYn4r+v>b0X#cdX-``1c0l9hGMCC4>`*;6 zfAHZ@MCA$aZW^Gc$@I`&B&(hxr%Y)4vX%~$ccJ)aP_atMHRl=xnFG;kS6e9Cd%Ccp$C04 zp}Tyv}vr;4J^;#2)?k+o{V3k=nug_~0hCHA048*ORXv zPRi`=#f_V8mOa8R>K7l*%I-5xc@O7LUp{#NUU~$0d7ub7BCC1^*n46D0BBIpf_y{= zml7Wc@Vdb8B7;(rI(mJI@}ib>p7_2ABz*^T^c9Hmm60NqoAP~z?5i~8 zh&ZzAe(+PiB3A^0{J1OpG`mQ&ul(4F{l9hDsjB*`zVbIxB{2a){T~hcn*&HJ0RePR z0&J&lfUD)=STqkM`y4=5Nid>LE| zBCP?1oRx<(00`k%AxA`^ty2W;K-JJa%g`48e_$z|pQ%d0*s`y(#Q21DR_#OZsQ5_y4ia43VLsWo#+_N)*ZwFi%y@4=3x)!15(B03dTGS3g&l;xfzNOX2%tsjww5e zk=Vw04HSwkQi^?7h=cea$S;qTGsaQ4j?E^D`^b(X1XPP_bBt5nCQygPbx*}$8U$0G;B;LGdO*VP@0uvsdw!+hNv13CpSp_R`^wz{G@2Xo72Dxcha&J^>fAQ^!gZ8 z%NQNe9R4(vcEFB)M3t_{mwvX5b^&xsmyJxnDMY)QPJe%tj%0#XEtHYQkbxnLh7L^1 zNcPOYyFn$m$q*;XB&kCsQ_mE(%mg^2QUmKU`G+#;gi#r&v!3#0J-$Iw~Bk?M(l;wA-gCWSuWf@tAFhq^)w=faNig0%OASoQQ60EV2b@n0 zE?3X1q{gdOFDi2`g3c6zZgK_Ag23>kyk>Q9iA-d>dNJ$j`QL#r{M(<8h zKTdKtLiIz9IKJZ!ayU;q7cH73&%Wp>X{0VeRVsxNm9ET`^f?!|E|m^}OFM*1ijxZO zg-XkNN^2#`&P~c5WN?vyi1W?D+YE0V!fU&!2xLt|P_8d63$TZz$=E9{ah5K_fc zUnNnWGq6+j5?GBG4R2M?(N-@ZmodeEHe0QTTf;hA?Mzf7V_Ne%xQ0!7ruuF6 zSKzmIrqu+})rW!z%7s{$v(~7-%A~i3ALXhtj=Ji`u+}iS^1iMz2@_%!0--8`&>BM= zI4h1#s-ij}Zn98(Qz#`n)MvM17YCXh0F?~Fh1uS^#NK*X?-nG4#DsH3B)b%(auP~| zRY3-IkErSbIbaVNIZ_6uh%&LeY-RB*Vd-t%Rc$mFw;&sr5E4>n`bI>r{69Kp@+-zDe-Ht@$cW~(>c9@T${Z=k@2S8%{@ zB;cRz8($s;z`7lmMkHWP8ViZXXMd&@MAyR@LD)RFeV$ju0IXWfQ@tyvsr{+k(~-Z_ z@dc+dkD*gkwlllAO_46N(?6s0Ltkg`Ql|`FS0Y1~FAa{rYgc+?SMVHmDBWCF@=;f$ zX4mV!1`pA0Z-#D=8Fo@bckEDi8W(m315HmPUr){+R-S85Xk-tVrYlagvE;71thh0O zrq{8#hpu+6vx)}plhPaD*$WrNYNqM);_GX_!|ZVFbBpW)_B3Gj&Gk7Q^$odVj)?Z# zDfLf?Vos&>Tb1|I%-vxu-1Yw;8dzz-u*b_?PwC#A%h@g-*yXC;$EF!XbcfG02QP44 z2d`2F5&hx2xxq&db6AIc1HbSvk^6H{<%iH4J2COm6*PNsxzX|GdkMKQh-t-!&6I{I zN-zMa!$#%9G(Rv96=Hp&kw@~w^?ki`%=ja$sc@3i5jBYso_%ya+EGou(dQ-T{BEPF zk)w}L+5{&h-s&BqQ!o*!RDC z#9}i!*jm4f#Hll2+{T&W`I*xDnR41$>`JlOKdZ!e(A3#KyTtghe@);1d8-c*y&+mO zh{6n_QS)cJ1~Jr!Xw)EDG>C!?qFsaN)gbydh=L8GQG*!pLo{a)Z5c#o2GOGVvp<9A z)gYQS|CrzX`T0LveQck=n*Z7A`~S;U-+!ci|B?3nKO^lU_*-iCJG3J}6k`6N4?}<` z#QZ}a1|Qn@4}BQ?Jib3tyT62X_c7F}+@cu~c{unEc)xC@JA0$n>91TJ!u-RS-9>%uopMdlxS6Szi`o=>hrLwtI+uaPZuXqIJ)~hh z%7PJhZfVB|8q8g4W#y~c%HaUxH!1y^&W@`^EiF)1TXGOy%BiwrBJghR-qgOScW!#5 zwrXwvgX84R{8_D0WA-^mu}g&M!fCAooWt8f{V?P!*|tE<$-af{uB^rb4)o10@**!n z(JZtk{$@J7_$Z%E;RUQF$|8V0*pE47nSj-qg?^pIfnq z8Q*In_LnDYt%dM+Ze0=c%M>*>(qSeyDNKm@W#w~QHT>N-z0~#JOYKrE=C#y!XT-`L z&Rx8@_b!lx+DLyAoL7`3?cH&S6!nMKR_Dv_o8W5P&P)x|)d(9ZdI z!)M8Jt4}UJf9HFRWGnf8{$m_6GL8CI${ocvFG~1#o!Y*!N z-ssJ)$ie~6KLX9q(TSGUNdr75en_672uf_?4LMG9(gCx|<~K;+Bh4C0+L8)NtiBRX 
zzct;__C|Z7vqC&1Vs@pC62#lE>TQ9)EPHer#;Se%G2)N_ugL;k)cr*WFEYw=cc*dv zG&G&Bj!hyW9!Y~S4Ev;mBHQovZit)RmdU@08LPiU8ROMFul$9#|27h-h*xtDM%aJw zRst!EUhFEB3XVnHbl5D;wD_!N(zBV;{Cd-d<%XKAxt&tye)qk5)8~L!9jpgO2Zk88 z)Sk^3{T}X)rU0zW#mmuYMXJjtdaT;T%Nee1daB!Jc35Q)i{?&`87qgv_ntG$4(tJ* zGq0SGJRR{IY5Opd<|w@;mp#Wh486uP-AMDjCWx&`1)S*wX?Gkvnb>HLF_9`9yoaMa zkUU+ux@eHGkWgPB#4);_xDKxTP#PSlC3?BfpqH1#7Ucuvv)qzPzksjX2Va?=XsZy* z^?YBa7GL@fG*VSBiI0JlTM(3If*xv8lp0VJ#O~DXrJohCKL?-B3nPESD>tMIG#*ud zGgWi*0-fba)MsoCHdFq{mfB<;wo*#2EaK1rb1#2QL9c=Lfwoft?pOK_Td1QD3!;KR z6Wu^AJ})GRAT!+{&LiKr{Gd41078CrW~gNbh$01w!p`O`lrKe52u0z+@+r9@F9Z2g zL%gd16g34Xf>4V}P>4oF2>Xx^3RY+x8$!8<9(Co}Ar;b9ps(T?DBQnA-dcb{0Pq_H zk&hRk&~#Z&140M40zI$L7_j^n3&NJAzN;t&ld%PCVv)}lps;WG9RkRY3s5SAR3koH zMYK7F+z682flwN;kc6)y9U zO$mac-d9AiTKX{*Mr{Xvk0?MpRgIQnjArHY$J-3&+9v(%X!JJ2|Ct)ui^OPIq5vTw zG7+J0<;Z}?FP*5p9bMk4Q8!{bN~=*52nNYxQ&R-_eZZ!sP4xP5O*Id-pi`rwe%2MN z)lDT%6`$%Aqd!fh=M^?o~q{s)X|1dxze5vnAe5L*$K7?_Y?73b$!R?bE0R08yoQS0uJe z1JY@oQhZ)T_fP>ir&BzUWA?F9J-m{~x+#sFQk`4kAJ4!jBes*~*(qZZeU}R<71h$5 zb)&q4((J90kUZn|q$x*WX>`XzlwZ}-?W0n!uqh{C>6V%457QK@ut;PM%HFPIAUc?$ zFEPXXNeW&ah4XcW`EmkD4~3UnrrB`{4Iqi429{}10isKysHe&@%#4ha&eB`XV56q! zpoL}WN2TxxQ{Yl(>ttp=zabw9vJ=3exQAtHAZO5s2b0fHi(Rxt~+6@^Gd)H95mf?&p?VogFeU2H`yUV-S`}p3D=k=E6M1 ziqNC&JlQ8n);DB)>iKdlNpy}gWc+pcQc)hlJNeQd^1Ow~P}yVrZb+LF3#6>l=t51% zlx_-MJ;{yUAthEXlw3};3N92Q1_kyM3J@m)GpNbF34?`R734XSG9`gUb+d|RNMmon z&xf;UDw9YZ)Qbd&3!%cKZgoY^qO#L>ia5Q%turJNK=oqoCpld>r0?pApG4&haFEJV zm+*N7jp2}1&^ni}=;qGMkf_a+aDD(U$&h>%E@fvd-Jm8xt@8(N%aCBsl(IcZI8rB> z6f9$Tm46{a0!b=k$jk@caga3Ll+k#VqPP&RtC!POfU$asOY6#Mqe^H9M2M@YE2t{U z$V`droGVC)%c(hun`bI8imdi%aVm+I3mu&x zOC$tLu0gTN)D$6Fys1HYRsPL|5L2V}zPZw9vDG+k4Q(aE8wy=|CjhK*?8kE%D$4%JVE;MW1`8z+`( z7DDiKf$;7Shcy@cP8WFhQuR)6gBfrJJ}yyzw2KE3X&N4Cyy(RfO>P?GgWuKTy{5fw z>WlPd;%x5BXu!b3D@tx|Z-(Qg;0@h0cQZ86k#OOSX|%xkV1NePnfjKNrL>ma7B~?| zk+T(QSxdvRhZ~UGS`!K9;=&EPZLKdac#0SVgcQDT#XScmx79U6MEh`sZrYyF*t{;r z_0(u*<*Jkl#SNgXYp+FuDTv|{fbu@{;ZSP+EX-(8HN&B8_*tA$2-K#*MW*Q}K7#3+ z;c&WkWO%lkxZ*sS>&Qd;Nn`1Xqb}N+q64!R#Ys%=%&`3F+JJ3v*O?NT{hEorE0NEK zQl~3t$(uZ*E6cMhoC_x+2sx^u1JLXgcZVGX>LN|N!w!4joqC6jFM;Go(}PIs3UE_; zzzsdcb3LVZJ=o>l>o!M?`jGn*G=luKkG2Zl<9hdv2~Dk=ZTL3}CN_4j?xFHe3gAwo(Rm z8V2^}1`e?A29Ev=?ywQb-Jii7_V48G;cgHKn|26AYzR$j2m{+~2rG36r*Q}mdwz)E zeu#*6m;_sFm`n>X{x=N3P96S>1Q7dg7C`L3X#lYsN4Vxkc3)6>c0A#hxD(zaN#L9ea%}HugsA_Xr?<>X=O9m>eR%S*&@tuN{f0Ap+=Iqc z7)~?%?qU~#uoVL(>QLH+y`fvmw>;(h1ocN{*sL0d?WC*6rd9?wD8Y&A(=ji&BY+2nbEh}F=!zWpS(rd;`3&Ao> z&Ui7wYt(gpzQO{rS;vOyX+6t4+f%g-z?=M{J0q{W#Sc`Sxe`I50A{2z;5`P zBJS^|UEqyPa8}%g>rPHyk@W%2`7324$Y$Yy$37?WxrvV0jgS)0%&OLdf6Yr&mI#yW z-E^esPSMp@=03eeNq6F1&BWqIzE(j?n=A!IINn5_v~TFKCodX)kstY0YUtcyc(&5z z4Y^wU_bUEFdUMgSFhmWxC7$athb)m@C3=lYT##g zweV#zI6twaC+r8kAl0rv4%6>mnXC;g;u#_GLBq?SYy!1V$q;xMe_p~0EAZxMJ&x3j z(i_}z

hKn_5XP>V{2-OWe5B!f^WOk zQsZT!x^IVj0gN|PjYiI_XZWXYmn@jgyQI6F{oZS|X5a02p21=xx0>=*z6A~Pr#bZ& zFOt2U$FFgHZcj2SdyT5EpntZkcJLc&=U>|_2*L4apV?z(R}w%+(+l0{HArT3O!YPr zj>}ifpP~TW(IA&-30IH#jGTwNma5?oKUD&ZTU$rb zsn3u`QZVu1mEM4iWYk+-xaVPhox?w1)WGqZrMF-kkI;I0KydZ((}b6QDr67g1!?=N z9zrYZrSnmFNA4$S{kUik3*!0#YVB>{DMSPBbbZV~oof5FMrh*;c}cG8kwdqqTy^&w zFkB(ESaWR>h@TFihuM#KVI@08JXC;KAk+b6HlTQ7fWtMIO*batIU~?zeT3f(Vi#vc zhC_BoCPL157TEP>>tGpy$_l!1?96ihG%jxh{6M&xmz7+OYJo)!rgCI#B_4xMIL9*P zJ!qkSoMsi8P`{4v;}pWLwSE<=P-O5ve1)ke$!IQL_)MR4AF8(@hdr{;F1c;318$9h zfSaL|gpGYM zKefcFF2%K~m6WudSuF%oI04q_uj{e;OOO>B%}B-3$x38Otw2td5?mP9+jFZR<9L(mprzf>RmmALA6=UA4jPt&Y}=xtXL z$}}tc`~{9jQm-UGaK3c9DuVf$$*>SE&iAw1+q$fcN?a$Y(f2(OC04s$Er~}k05G!`;P0#G2EWr8?6mt z11W{4*o2zeT@lBnGTPukDCCj-pxoayQ`z_UP9R_%fOh#BbuVK^inbudpqpVFhF)b zRkguMrY{WmKT<&yMH)2h-uSAPE}xM?=3ste@WX$M?k8aUPXy$?NVw0!ga+s$xIFKI z^^Oo`Em#@%m#FJTOL95svJBcY$6oXO9+)N1gtzAzyr@e#7R7R7%(?=T7rz@EYffiD z*Op_?LA~P0OIGfs&zeQ8p#mSnzwOSc!$>E+C)zwiD-xxx*%1|Ca} zfLYe4B$@A2u3*y*@t0g_I82J`Wi|S+rTK|be_DXz09M8Io zaKfMEAaMT?{R@`~Ya0RP?m@W-s@(Y{ zI8-MFN)2CR;ZYusLF-4c$1L_3`Fj?<$i7%H`gDRX@lI&#u=<$9!vxP;`LV}Dy>{=5 zs9I4d38SQ#L3fUlD9Y37!F#>?9|Cm+U}qQhBs-yHzD*hI%zcdcPohs~0TleL zWD7;gsF!~3DDYL%HT^^O*V2KE0Cqq&bJyWa^uYiXE3#1(XA|urp^kwC4(S7wM>Bg^ z$wS;tjzi@!p?g!2b8T0 zp-J*Q&N0p30cuKLcVGyAa3m}L+F=-cx5Hf-hb!$K=#(*$IDEp|fs~{aIa$U9SX$E{ zOs`P-Q%bd;;pr15mBkv3n|GVKoWk|^l+Feh7BpZ7Umoh-s-k}|!0=T2T8(9D(xa)N z%ru>%^sTC`ZF!9Bk+5C~y=N$Y_Zl^~{85*T4w*9(35vJ=236qh$6b}(LE^gW{L+c1Q5bIVLT4@V=!kMaV~ zDp7L^fo_-0?MoXC>3PV$Lw{g2pD9X&eGXVlnF#vx<-H=>DFK*Ap2m{-LNq#13xE`675qM84q9#p7ty-d`jzHtNoR8Y?KH zq0r3G4@aplDP5Dhx1*g7oP+%?MBB~mHp(~IeW|u=Nt*Tp%QN?CS=B8fhz@-oSTCIp zopZ23{XW{;uT4}ig+N?v%WSxK%*UchBsRU*Sc>xZ^CjN`uX0?AUj-|!W*uv1< z^Ven2Zp98q!=!qAm>%PjG`vR;6*^f#tQJr13&*(xX7_>c!Ail+ysC^5ofKIR|1A2R z*)8t*y}oAYtCsMg>&ifmyIRIxmT5kAW~QKwspAnx^OciL&0ZESti8UE>o^X;3q@#U z`izFI5z7~Pj5hj?zh>MGU8m3|Fz}X+8BL_mz8`?)Mdub$zI;_-r_^G0Tt^m;kv^pt ztf)4rNY_TpM*8G(Mw=J@+O-_^mZ)WM$o;ju^=C9H&^hd?F7>Z zLNkANR=_l3daMM}ga9lHTHs;pQg#d$>?Pe#<&NzBaR)LmGP5uVah}i8=YuJ1Q_TTbQEGAamOBO&4$%X zN;kz@j2V{y9HSVya4=>FCJ@|!1HHB1 zKsbSO>kW-ZH0ttl3bc5Yt~uid?h4po=>s;|lS~b^Pc9bW&pi(t_MW;gl1cBPEp<1E z|8PMm?siJXhoQk*ZlY3RPKl+{8mgjD93k^Y!gkbm$~*7-+}gUOO#5#SwtiXUpOEuJ z@8pfBU7r!k3umrdeSv1^EWKS+S?(`7BbE&l$RtvY|JZ5l$kM^K)@J&?`fAmUIsqxe z=C2zNSnC<^gCe>QoIR#RYPeUP=GnBHhB$ZAmJnQo}eNdez-Ve*p&L(pp67R@M zZcc1Q95Y&2V6#ncdg{{eJvvKuHY0vU=s4Wj2n1`|jjkKc4)<64ExuAgLQx}YzUKGy zlK_H3wyx&J^qx7+EnwI_A*&DzcBE=slvjY9FkI;(Fv2RSElltm>?#GKGI(J@p%aal00T~%Q@B@R zW^D7(B>IET?UKvi1$s}vX?F4EK5KC$TFNs2m#YJe;8%S@&f5kWm5p0Ui~4YeVv20> z#*p!-s%^t9*tdBUduEMMIc+d2l(0U8n93gD3cj4z^_6P1pQoP zSTU!~1U+^3?>10e&3ccPel9l?Gxe8V$a_qy1ppg*mubF$%`7Ji-9HXIj`j<}LR~&= zpGZb8@6^Z(GTFXXUsAl8LN@EAe_bSN9@6l}&OBHt_RERF7Z|*YrNa}hyg;iSqu3N_ zb4)d<__St=Pk%k};<{=(`-PUI^>N9QO%&wZ5uM$xdrNj~XD!BJs{5J)zy0C(64)?L zkRMSJwDGefY*j0Xcl(ezV8m}EZ*XVOX!;x73M^S^YlP!^em8|@{O7f-%us-0IlCy> z23-{VQt_7x2UoRa@cNE=MY64rtjYaw9`z>-xMQSjCFP~Oqx{S{FCNZd9r?1q1&-mw&?v!mh08&z5?o3Ih{`=Kt)exZA?3puexNdWwO zUVagW2*PScO+4T}**J8=7z`lJ@jovKf0`Ke;8SXf4V&e~I`DSjU#ngP>53Qv=fu{X zfNVTB$*+Yl>eZOk z)@R*x%5yj7lic>82Ji9bTLU%lGj3;N5lIBI%{jlGemfXw2RiVnu^BPnDI+ec~MZvNcdqFGGI+r!xRhRIn9i;e53dvJC&m1OIam;xnLj06u2%^Ev5h znlqc=1{ow!`gfF?ck=$`cLE(!IOo`b3m)q_1S(Wkxn=S*h) z=0D2?iCiyb!!jp9qU8ZlInYb{z@s+$H$A2-0Q130ws6H_lAX^oGXQ@^9Psvq|9LZv zS+f=3m}_11y$sYqDW)ALFO`N_*hYe+(KpX^#EU!ppIHudBQV1bCF{Q2u`ku^{8!Bn zHr)?oD+M4Vzpm{C0`81RV0@6lA}=}eb&O3GP*HgIzVUv35a?;?3pN$la2JNNN8c_$ zVG@JZ1g|tN2bn^UK%6pjhP&O}0M4YB@C-P;E3uk01hZF{w-$GRC^#lB!~vQ6-&Qwh 
zy29<}z+k1BI||+%1X}u&n6QPBxe(&T=IU!PNsw@;EM5f!$KGAg~9i0dH4xTC?NbH39t?NP&J5 zCl`$1h`6+0=UTYi{8}DN-J;`ua*r$*x+drBat$`9gL}JEJy<7(D4*J_Xb2`l`44B8 zGI|CmwgJ2iKM+^5gtl_Pg96k`4pns^p_mB#FrYmB=IP2td2RidfWO1wJs=Hz3mo(J zm^(GVD>VZv`;v$m7_?ADR3gGai9!O#w~cp!dQ_aFihk{SmUyZkJ1{Qy*9u=^4pn06 zqh(J1efOWPlmA&O#1+%eDgSLhQ)(u}zwoj*4&LJb7^oWzJjex5+T7uV__JEN6SJ}v zSAhJVa$c@syPfV}9Oz$SVhww%-}rC(fsjCjaM%kNe&JK>dPQQ!#sRhAz<3GpIV@0x z6Vj-K@*VVpsX?p1!dk#SeoFl+9BN}O?T;07X=nLL;4^RBgY3vqF#vH;@HZP;wSHEn z7?^;s>WyTsPV)pV-4EKE;?;xu^ZnQZdn7@QV!}OY1H40qKJ3o*YlkhEW&=h>#*KnNI+hD} zfkP{>f{B>{A)3RJ(F>WK&iPOlO&JSqE8Vy&^&P}Q$Ce|3p38#;`%@%VsEN@@xiyQt zg#-!pV`{hdv^k*$1Zw!9#^yCBumMdL&YFM(mx7G==Iy#neltf^K%C;|X{R>1cIV3aKyJko=7D$MdVlO!~0xne>k z4j;Q#?NK+rU{nD-x}S(Pjl&I0nhl`gq+qh4T^w%QI$X_!`Dr#-GX49%pH}?`!=Y{%;UfwSwNF7!Y&V?^s)1e6Yk z3_6n}Db9znrtrVPga=iBw?ylhW%;>(D-wNHj1A3E=G)%slqU zEv$?H*Yp}^6CwGH{ERW^V+EtpgIOWTf}d>`sSpi0isC;q zBIvCDqfDy(*#v&QHV9lOM1-!ZhZ+PC&ZV=_ej8zydUfZFj03M_mG%KRjOL@@Tme)A zWRkA%|GpoZM1rm=b=^Q%u8m#FYf+wYM;<+=3D%{8BP;5b5m*pjsv+)|_mP&$O( zA)u5f9Vvm(6r_eCLI@@h-dwo%z31$6?|b)+@!lABym!_g8OloLnrqH)e*OD>-}{tK z3uwtKzPrip30z_3Q11Prtj)}U)1EihL1D=|PjG19AXh$`#|d!09Dp|hM2Z(Z1_>(2 z%t7L*IWx`M)3VDYn?s+vuJOwhSj=8|WBOWsHtC{sPehibzT@Z?!4K%k>RGOiX)V9< z&9^~rWs|kQS;Y?E>k0#~If{Z%-f0R{qyVL8oqV4AR&G618?3=}(m24;fg%qM66OMN zjGvQ;N8bSq9GQg$J;>u?FA8~U`j7ul9Tb>fr>r&Q;|?k`dUpD!E|1x4N8G} z&lT?#orSa7>icR{P*kNpCA|l{Q&O~E?M*J@>md@M{#@5e+*k1LXCy#w&?S0F-`3?| zK3iEDKfYCVhKusEvASU!U_W(~xP;rrT&56of`Bz79@lTUyd z2=raG1OQ?2olqX*%@V%&E759!HNb(FJpS$KVm(N5Qbn>lRMA0|zR;yHOY&PI9VxeW z0T8~X0wK!E>-Mch9!(8vG$IQ$9e_m5CdUy>x8 z@SJgA7(*J_;zrK6OH+TtNx6FpltzQ_WSgO==CAr*=%K$Zc~F7!*S}$XRuF2u$asvEA|Afg--J&4FB(Hcuvp z)wEI@;6nurGvIs#i}J#c0~LHn)=uHqFm0uT(*iv^U3Xsi&lrnJLD*-_6-*AnZMJ~1 zCKx~q&foa;%LEU$i5jvm8ejqI0DuIxaTjpmfGm$mMg^2Q?mTJ-LOF#6Tx(A{5a6}H zYrpZI-iLT=eg}QO{4&_d=Fek(xCveR3Rd?qQHUmhq!9zV-z*2&mllg{u!+XCYHfU7 z_DOarfj_?>_tzH!qEfxr=d{>v)SJ0|;107b@*6BZyU*f| z+s*d4k+tt5`lvU#Y9V#Y{$Ds={1DTtM1;r$*$?sc7z}cMODT5QmBWz4xqUf-@0E=( z7*vNH%`U?*1sAjHZ?$dLGS!&a6A?%IvQu$pf2hB5{u@-<)t99<=?*jP_0LP^rtBPc zx(wZX=3GoY{v7G5Rd;iy@bKNUK&YNSZ_Ufpe5g*siJs%G56l?>S3UWAEBiD=CV%)h zL~j%#irHjK#4bjBoG>3;-QJjZ3%E;i+6{{4Bc!Af(I6nCw@zJu$;%bG1KVHe_;iB~t}()o^jx4Q+R zp@F3~t^mHMPG2I;uJ9-RDzl>Ku$;72cY6{6nEYK~r8Y?rk{@sdq|Z-B6SrHCJRR#V zJtq{Wt)6fmk=I}2J92s20kOq6#ey)bExyxsLu=4^&_XLQS7YbyMYUvNC{jGH_ue3% zpTdz~{259Za#)_k(S!Rw@X0X=D^qDrto-~Y(`-+9*udM~p}*E?xM5dXe3i%?=H1rI;=op1@K0j^%58oKjzYi{;hh=T zBL~`dIno@Y5`@PMVh%jZ_=>^zJxza{TlDO;@I`+6mM2eKbXuMb-?{6GR!anvu^%~# zZ&*zVabL|$#c#(dLbMs46)sZz6%!$Do?JR7*PgVY`B^d&JY{w5FoL5h#rA1kND&hm_E!L8brEu>qU~q01$2(3vxrDZtm+m3vY)8Nm(>-XM-7kvN$AKD z)=olTHmVuk*hEh$#G9vTN4EdCCjBtL5+VPPIdS}b1YiS2Td%pLUtGCg-1p8&Qvsf= zoAcG{1fmz7Yf|i$tLvnATpF423BkC;X0nL_$NjQnT*0IHT&%9ylkN0|AKR``S)Ygp z?%{eIif>=}Snb~O9s-8Nu>w29D?&V`#Za_Q+^V~&)Z3rKcBk->>60rE^^E0$SDYFI zcZGW+#0Ql3%eIV;ANd}D7Qn+Mzn>MPWPBwfyRrQT7VMyBJ z0rrjT_s_$Aw}zy<7LRh+UgGh9Ih!qsEuGz+QP7+ihM5iW9$C%rF$cx0o1hY?O#9{h zMO`ELr4+Y-Y`QP~!0Lc70>1$Yep>W6<&)utvtA^M*?}5awRJYX$5Km?_5mVabX7_S zaeer4A#;L%dy;+TFE2$vR9k@VfZEQ1iK?&clMnnGu$Po}{xH8nI{kAyZF|yXItCpR zg2Xnf4!HzLRppZd$G4*2zb-rQ#tu7vvpUznsDA!bzG`azkq%e#&5!m47dP~BJR^sq z1y4o`Z?4F<@t37@zNOA!DXj}@O&i*m)ZPSR zZGaGg&M7SZ;}d!ggCX{x2$I28W9=c`nO5#VLjDDw(3@fQC*upxO_@uCyb&$&+#y8Y z&T=RK2~Q79W)rY7no2QNe1zTyFD9&!5=%Ncit$5SHg}J>jNQSfU7d#IrE?rs-9&MD2}lyD137Rb5AbRE zIQ;#prW9>`kBr9wS&v>l9Jf%X)01FrV7XRcFxqT6n)4+2RIKJBHiFn%(kZR@N<@r; z^xBJe(7^mh_NmOGuojXn&98v*(LB_l5TxK#Quw&JwCGu{3#vju)Tz%U9#w(D%5>LP z9qmo7BFder14%2e=7#IT;{sq#%nnCj!9ZxGTdk|P4X906%bkD!4i&K4K|RX$S{)~` 
zI#6z{QBr|OnNhHpMQd534hqS(3wB$jo=<%SVIT+0iaw%~CRrAGc}IwQU=7_U5DRhz z1?zy2U@8OHhEwjup&D)dXv#a+(Xt!m*PSeiKjpNKU5`c>)1dvKchSEf9-KlE9u|YQ zC)HaP)-vN1VZxBKe-%Ob9RYU~T)ggNK6B(ZtUSm^5f=eJRGm4Z1G!WJ+?gF#gG=Y2 zI0)?g8?XpBf`>FpV+Fch`}{}(TX=6xzUUDHbeT-0bgm&*ZaCCCyOicU^icj2R^8o9 z4raa(|Jj1X_G=aPdr8=5=U*reNUrK`&L|*GLtm`u+tdlrSn)DEF&uwnq9id+xFlwKPHFl2xxwc(EQ@tz8LfOex8MX4Du!f*nfNB32Avi zc@U{Fknm*eI70ht^_U!NhX3P<{Lf@M-{0+&8_v^))Fbty>rrrjyHeb+!;u$=7PZIN zFJ_}(CW23=*uWr6ZCHG4>D(-YuT2J^DzHl+PYR`-=t@Xn@~dctAh9*S)`zgf@pJ^WccBpBrPRboAnVL2eQfb0^5@hd>{FNhP+6a<0s6UxD>yRNg&)e+*q zXd9xPbjGAp2OBp-DW?!YHhsow9s=vthe7*IaHXwm6sPi{{4T51cuhKY z)%v1~Dd8)g%6k!85o#$f=Lyf2djri#feqUsat2~Y~TCd5)8 zHoU_k?f9V0aAO&8yMVSzn~fnIFODK3{~G{YEBi&1)|Hy<$mFWqR|1?`*((d)d|l9? z3v5mA=b2g3*%Gy$j^5ek@{83Ka@ulGraKaG21Oh7EbaP?pL^x@yR9W_{A%_r( z5!w^j<}1*iKXLp;i(s5TY4L=>5neo#7|UwTA|yF|eGgW3zmGAT%e9{fK4jxqp*{Ha zBt$)r>e-{NcN+;pqoiwy#M=~l0h_H8GPX8BwHBz18sK%qG)rf7Lg5=SX;x?}MoeK&N>mR)u z0(}8wPlXque;ZFNz)V8&6@S)=@ ziXJ~Ke!h--*!*x_ev2)(dI~|(Kf2PPoSLa%iAen*MnC}LnsFZab;D|mTN(cF%M)X; zT&4p+Qml<&9E5*c`0(GWlQ;aC9MV#N;Q`dQyUBkVASFB+?_q4P=HDN_BBt-ICRd?( z6_`);NgVPN1GQXGIleS_1=!topm2hH_)D9E!LG%1-=Kft11&A=9`KOBx}v~}{&oH7 znK{c#xh%E+H;CY16QGh57LXQp#m9`?iZn5*KbNOxpr647;kanx0WmPt|6>AT?L zCc-L&R1ECnDPWBSQ2z^{zQ1rV*R*M8>m~vjp){#ViXWV_601Z{=Xy}Y}Xo+hWlYKZ=1bCr+Kv+J>3Y0yw;Dc z8J^a`6}p^LCrAs-v;IfWLF8#Kk*Lm-1Fi8{L;{0{?q*3pBnV>UpPGey7#~H)_;N63?|%M zX7pDoO_au!0ASfdY=IVdCL}=0;$PSQ+;jPFhc{3@{<{93@AdNU<^A{a{zYB?H{uEZ zW3u^I2UO_LUvY;3%!t>@Z+Xgwe`6RM4ZF$O08;=cEkbEk-RpVEznXqPjO5q#zmp;U zU%0RTU!A<6-F%u0`O7-mD(V`fT!in(A{flW)qL4u;k>I~XI6i?#1Y7ia7A8)-L2x^ zdZ*USi;12C&rs~X0&D=E?&!b9oF7*1Yw(7?45#FQAu0r!uHM-gpDAY)q@xK!3gRZ6 zu$Z8y2k^avV(nm@{s;&qgo8kW+r3~3Wy7a~jf&)nefj^el6Ela2XH`dDSXrkC)Im8 zyPdY|UuoD1)8qVmSlero+MV+bv}&RC>rVDr_9IM7W0%va{2cHIa%LteRos zlByzwtvu3PNE-&YfAf#9tg2)h|EgMCH^_&zzQl8UGZFYAtAqCv6^pLiB7sfNJ6QL- zggMNHW73BP3Zqy^|JvvskNKP;hwxD)^X)3id1oogAA6*b70<~to3KufNxrnEnWp1m z2OB|Ps4J_kTtVJpq97bNcOAb=5vRLxo~>@{L};@&CvfM@k7je<`r_l*X!E4#&x(RWmC#R+#jsNG%e^o>20dz))%e#ZEN3&sDcb(P}xEr713 z@pQXDMHDbZAes@CkFnJ=ymTw8*;zTnE%vK8qbl5^<;y$>%IsKd^_em(pnX`&Qy@K; zi=w~TjZ66+TG94s2EDkQ9%xj#2xv-%N>qzIqB5am_;f$6N&B}JV{U2H{u{h`G@WlB z^L|c1mN<&M9pVvb`C3MU%bu9qG%6{)9E1Gz>j9KAh0%knjI?O^7`D%+2x3f8{M@Wr@wAG!f& z5RN{5De*oO+KL~O?5!1xLw`E60qYb{xJ%D|FAr?RW+>QMgK0%#0adA~$cUCp>;mO_ z!G)(>=z*^Y{j2DoHCqSL1w>+eQ;pDI@dKzA7!Ts-%p^eud3^Vi;YAnG#6z2LU<}EJ ziGBHdgU)2NT*$*;WZ{$Qpo9x37XKv4Fzu9d^>JMG6B^Xoznx(dr70p z&$1wZprGy3)H3FD8kG$8{P=g0^*k^HXcGhyjmoC3blOz!Mv8^k$p?szLW{sbbD;cM#Sze+>LuxZ}cfN+%$Pz?UiDMEd>vnwZ`Y z-%`acy59Ea)0v&1sx8SpOQb$2FU~P-Y_g^5xZ`qhRGIt9ma3DERnGiVvK{5ajdgJ| zDeXjsXxwDS26HyDsFe`yv(iT-MC<71+$S>QJn&1B964H+Nh@MvXmhp3V8~!Deo%}q zUc6J-AlN~+PEbkiHD;$+F}6>z1P8Orge9tBiRCK0@Uy>6(GPzMs;A7aO!=x`d;cRO zOC8&NgBV=So~XVOD%`?z+W#y2xznC_fYj5ZS>LS|I*P@(w?`kHjYBXWxwBd|R#euj zw#?T`Rb@3;7TscmxPfh%9JyxZWDB_xUY$VPNS-&9DT60fZ1AjKcpHBqYFBYSVPZgE zvI8&nbO>FGU!HkON7-F^q93z+nxbxJQGgqe&cjPd(c8Q zqLIeP<-z|m9#<}q~4!D4?puE zv#XqzC00POF{C_9Kj?gW;^o2;vbfp`A6_P&!v0$D_^nJ7nEAq)AY>L}d2R z)bgg<3~9}f2}?71x=r_qiPa?w6>DzRq!i#!TGP&O*{!>Sq>QuIzcTT#yiZm)+)M)1 zCZMBA4;O#O^)u&FF@kts(|YtN|Ja~x{R>Al+C!T^-N7QQmtFJ{H+`dJqV78y3=h9; zbGLk*+Q_G zRi;R%rNQD=>15Cc1yAbrY;-yBULx9me7EY?q+Yv4^Ak=C6!LY2l`M$=35nOGvYpc< zcI${^C`={LH7RiD77f$%JlCdh-iX1VtsPT*9*aV#Rj5qSqydfKKg33}IdjWCmnb0U z2xk@Rbt?Sj%su?qcMC{y4kOu`J^ku$dvFb*4l*BfXY?k{s)o*_zHxBo>9JidU;G(i z6-w>f^wq+nf^n+JNM3xkeX=q#Sd_oYo^8xNYsA@Um-D+s?^>b@e@vA}t`CaOAx5Hi zti~$;?71QKF@|E+n-31$11w=F?2@(RFWH1 zg$2@z>zlTzLzE=tW#FN6SeC5zV^~U_FEP2;`EG?Zm=rcHA@ol(I?FLZd+0>K!A?~W 
zgRU+6uDzqSm&tTQ4ljIX;YpSyGWcuMoz`Bg-x7^VJM*f;nRRlDg?Yu$Wnw6bwaQ~> zOza7h{wg0%I7Q}@Na3(5iwCbw*JF(ekBiPM_?oTE3GQQ|QD?9Q+#fGs``QUu1^kjm zjG0MBw}j({T;-@1UmJC4b0vY-%dDHMASswWc{xR)qlO@QD_nXLyVY77^=g3=IYJA# z_D}k+iR0h9YXUlj$0hE3w67d7H_1pFF<=Vzx_ZCb%&}pv?}^c@mrl-JY$&qAtHzf*s9I>E z!B)m#iw`3yl@%K%z}-< zIdR;(sEwU8X#!=uhG6>{w$WE}8nYh=5KE(uMA6Qrfeeuz@&!OEyZVyfBm@gb9(;e- zeNTJYmiYLU=FqXjmBb{=WSH1gaMC&6KNyk6bh0XIn z99>6IHw3w9#8dhoTN_@n5eG=To7rfE9zQ(2w zq*NtCXHD|1_>%WAYusI@F@wra-zyE)_?FV#x`!7Bl%(hF4U6xQ6wWNA6r^m3*&|)Y z?8g!8#ftge9jhwV$lr}|PveZH0M5Wm7YaHbcYO>7`dWsjeYtnG~v$OI0vNFv52lHNaTeq~n6!c*9(+0-<2HPw35s`x{0}n_s zBj!W$yn3uAPK~bU(G>Exl(m%a4_9}q^36(7Wf<`y2N^6skTxswpI+*=6MOUhY74>a zlhXJSJU(RNK#Uc5j@{TS?}N623!9pv{B_1gvT7lh^@3FNF~h>Gx`XcXa)r zjy{l$s5j6fc3ZDRm+KAdBPL&1-ai&@rVCOipoWD-ape!X-_qLZ9WHlmIm2)-WR0Dh zyWzeyQ(*4#ROQsprR%XQg&}L9YpmP{Z#aImUz6I(C)3T+L{Af}9UlxUs)l|U{yvop zYFH7-{({2jmn%xcid23RQ^V+32Hk{3beE8aLlXf3^gQ;sL4$lQ{XGSsn9wt+mlI0ZpRr-PgXNg$P(i;3{?is9{J?O+iimh~RP^Brr zf7WMb&v~?$v6NuYVUOMR9!12ksGnqb$ZG4^{jEaVx3{D>YGBx_JI|m8p&(@AwZ~uI zboU3F4Ei)W?FvGCmmxq91Q11k>x>i3FSAg05_ZVirj}j0+Z8-CC_Mg0LH_D~?u8?b zC85gWK|oFqb(#|VHRt{k&C_7Z+^aW}^S$|#%IMb{;65%<*r2v08jR|vl@rz(AO3~c?^k=DDs+_a2SqJo3CZ%N12MQ09ryO@p9u1oGeakm^i5fB{NY*b_^8uk># zDAS*nJ-HoYF$|@6q+@ID-HFT2h$*|FHAYERj@c4C@m*PdY}_2KNzqz(G~D2IJsM|b zFjyW3Et9aymvl0 z-GGIZ@y3M|!?DW%UUW_)&^C!Crep9}chT_BvU`U8N1mu0{@*L^cG&zDDe=gAmFtqE z7R!9&iULL_Cg8CVGKt27TOQ-GTeB!GI;hg;om<$;UuHC^2tb?%DC4-EZI^5Cff~x< zD=O^`0|?|!{7%9k?*_umqdWpZ!wRuTG`@vJz&1~Bl+xT9^xR8)I&`uSgh2wF+0?}Vey$VUqdM2=K<={pFdxCUBYui(^~XRr8r@@=2DlK zZQf&(Dcj4*q*rsCq|^+rhOIifM{_(U`l|Z5le=h5B^;(J4ptOI%4$2ptF3p7moLVN zgj*qNp^3?m@1$;suM+dgaa5pw-Oyz?6AIH9cv&)H7cOC9Gej3!{`qO*rX*{PVZp;V zR$CUN{pTwcYDU@2mdR{{8rTyPM+5ggM{l!<0YYb^e^mXcx2QpB^BJ?^d~?8xo`USA z^z1V0k_RF)`KFxYjyL;YZQ9&AJ{i|% z0N1oRdl9EUaJ+Bskt__Sy8VQ?dr3jk$%r`OALWwKo4u><4`Pi0e&LDu;P=^vT5@a* zoELt(95x?pz7~@(;hrKjY?oa2YOqXLc527&KZc-6|2YJye1C{^0g%Wg9xE{A4N^rK z+v)6L(N&{iyp`~I0ToOwO}gBHVw3q7Z{_sU6%%~fzB z@{1RGug%d!$kFb$+>4klLhkR!$7w6yN8Pwd&&{bN9^{fO6R0d8Q9lj< zV=B$0#^qTWLg}UpyH4Z5{l1l#bQ|7!UZ8N*!Z<`k@GdX$^~mqb!Iyl}uBN8lz3#DD zk3W&>Gim0qrw9sL6;Szb=QW?+QQMWuVY7LRTa!moZAZ7El>4iKo{P50@aLwD?eZ+U zE`=+7*Z>80X}?qOe0f}PQv7Qi`+)?geT#|XhYOmVbZ0d5aE@cNx;G8_*A64e-FDK9 z8Z0JtrrP$yACOME?c#4-1pUtE4%lpUn~uJ@ULXuH6Vt#@8dt6G*@dydYq~two^KJmRCH%~V$$60Py+&=V=Z@lvmKH;WJJxC)m0&OE=yGw&fhBib%_Y?(((#3^eRjW6 zigoNPpMZy--JFEq)_;z`Kh?iJkiXmb^eL(Kk#l98D>dy=v}QECHXHxqXK0okLD}x| za}HfEY_eH$>ZtGs9t5Y{>tqH;7LLXQ27boPk3mn!!&ZyNZp4AhxM4;`@@=aC>yK`P zZ+Jz#$gMl@lX-hO@Rx$t7v8deTdH%O7~~CVUox|eBAKC2Ys$M9jj1x%HK^vnDN!s0 z^5&Y(&uWHMz6nvQ@nXlvGigrOt6c;aH_YA?^hzL2{l|->BgTY-SbF3htT6PLHk9K{ zilT%}^S5C5bt5-fk~yXl|DuJV{K8y(Oys?e5zcc4+#iG(f~9G62fff{YukOAZ-cBj z<@AymiZxi~Zs{__huzSN$>{B6u4jAO!&bK$^w`IF$(X8>yt$}Ri;=D{!x52@j+^VH z=3E{{NNl~rE=F%^jFuFKcFU)D>hivmZljo{N=Kpy6@{)yT^yjU6Tr+T=lY6oP0deuDtYJT`r zrT^6T!B=Dc_J)&flmUSCRxBx4yys+DoO zpeTu0Bq^Dj_>Y5rJEu}h7VI^4_gf$Iv2NOR$I5W}Z-eT%i}9Y_ za%Z<^g}ad{;nxR$j_0r!ZKcJTJf{vKuvmXgEm*Czs-@N@D#^0e)keG3`k6>}+Q6Ga z33e==L;G^Q*7ZOJtpZy|M2XeyjDAxozEGpv9vszlPX!l-)mBZ%44;$f9p7s5CR4V( zwyyN%ZD|!GTbHA!S`if|8-QAP7l@sa+U)d?m@Yx%JmzzSZcKva0Vey{1_B)NL|nv;Hl&t7<;o`-mSA)%dAsx`jSf9JP5d^k>}J)_QI_t z`R`Pco=?4j@vY%PiwMiSXWz1oo+sPPCyzB699NCMRcduD&N8BTDd9lx=ZG@hYeW4h zlP0IS$J=X7zqsw1p$RSUVx@fv+oE$^=5vq0Ji1^Dt$$o!D@&^73+?QupzZ{%JWY0R zFds!bCXa0=p72Ersh~`rV@`mz$OCn=<@b-9D&MdY?CgJILw|NNz}san{=;*ZJ<0AK zNvw~MB#Y&Uet@_^u*kuJS_!^LT$aoyff@Nq)5K|V;LaV_-OJ>a1dYCZaSB_Bd#+pm zI@$7xyx6dTleE+O8T(nh7{2B6)a6>gGo8KBwKR=j1!BPpjM^o^pK3kpH@&Cc_d?{` zL2t+9M|o?nv(xQrh%0M_TOD)mYjN?h3VXrH`Hl9PmFJN54FZS1sn3SO2;=xWAN?ig 
zq=3U-%Ct2Q?-CaEKEk|oAV8k*(J5|y)pAD(2nhY?ab**PyJ)keBy zVN5($KR23Z?u#1B#j|V?^WUE9C>_+RKZhHv7IamAmd_h!Bil&oj@I*;h{em<{I+sd z>TRRlz&L37D42CgXaJNV-xc?ZU)`;9Bc(DJ`qNl|toG(YWmq1sFT>3++8e}Uy2>9f zr%66-Z_?@GqB~ZUb^Ulazd3KhX05gdCEAQe^yL;oLVZ^T!#Ep)Vd3bP=uF2V_3~e8 zx;2p04-99#FZOgwc8_d!)WYGSy?GjKl8r1hX73gxFPXxVmPDC)$My!w|q2?>1&lv@|5dp z-J@vUxFs*U8$zpuYWQBM*;j1VSoOMSwX$z2hv-2-L3e>=n2>51h2a&xzSIJnK; zYSvWq4zgH>;p|PglZpLmv3b0$7mE9{R^FNtNK+UA4z}4S2R48(u03OFAG)oH{EP6& zA=XY@<04yE`L9Qr2JqmE99h-2x!I#vxf5!GPX zoPcI^#+vOyZm7_3xn{gUAE**cvKi5^jHU&lQLZXlzjp#i$Ho)_jeLCZxw~Nsf86p* zewK0Wco=4&%Go2e%~_q{_0rjMn&5C%CjKTJc|YsNJBbPY-Z_`l?8M``Qle(gy51|E zckjHBaM&2p;hzxl<;T4 z?kC+By(ORgKJDCOXrDGOQ!SeC*r#5f(I%`!$JEW#XYgiz^2oFl#^AUgB5}g`nx^QU zN_~#QSS-HL{(NioB#s)L%#1#FP%E=r=SwMiExSq+>A8!E+O4BLS=h4QdCW4o3h=9C z0N>)h$f{U`pVCnCKoxB17x2xmI|S%Z~^#1lY0AGiwBn;Dc3e(TF4+)K{ty z(56vcI{1J=dGTxq9~{?WzDycsu#A^+M7uz`7FoY50B2iO!$Cy(II;73#@0LIE8u7} zmVP$;o)t=BOze%1^F*wbelkYdI(0KII&X|;QKm<8R14m5eXAx_bXFdo+;4G(fH2coAAVwxx7cwPv)jmhbA8(e81{c9_@?jO&Y7+`%yH+b}@?9 z=-VwGcF!H+(4;44U#X^uo6*?!&Z@ij=WD%LH^^MEy2qv5|2TLHN#HPums)!Tdx-MB zSZ;RfjYG&KJ;7Jf?Yky2d-vd!ISKOiNbw)_I4-(zTmAM@GSZ&zLcWGkc3re&Br;ke z&tXcWlb(1s+k4H)ZHlzgf;d&AxT=%qh`p#`iFT5u+^AGDT(a%9Ie*wjz|RGBi$gd3 ziyBH;Pfj7Ny2!Om&e108F;@^=UG%Q%JsZZ%$1W)4pY5ROED+DSQAUo>-Sf2ChixCz z>BiaS+U1OYUbt&@XT~(vvM}AQZ(m1{=46V!rLOeui>Cct2kq&V7qpWRthe}vI;&Kk z@ElaOI;)N^s@ay;giA~|?u+RN%?D>>L; zR9lInef(a~f@tlR`(3%`%dw71vAi7&9lbn5#`nZC&9XUk-7S`V zx6*!1TTT&IObHS3?{Zh4)gW;guy_}@yXo$#HYGhuaadZDY?u}Y+p#ZK45Mg?bjRef z(se^hECer|EAMLKcXngM39G{RTfOF(GVea?7eR~mFqE?$4_h@0+-;{js5aKVo3iSR zt`eHD9oY3+=|mZ>I{O!Ts64@XLv~ThM`8gcIM#N# ztdONVHPd%&X^m2-t@x4nj2Bz`*(>Lf?Y`qHp6i1oTuGi`Jaqo?5f9Ew6>Fj$ZQ?30 z*W0xGL*1rIvG>smD{0^GsEIK5NQ~?^Hq+Cyk8%7UTFJ?bN3tFwn9 z6-9lijrLnPfr9Ur8epChajUg8-rsHgyovSxq%|`n)St2u^hbp6&VdePpVqbaCYqYa zwc36*L?5)zr=?`fI*Wd>53$py(F`?%qKqjT@lQE4?P(YKEzA2Oh#r1hO|raG{yRs_ zq?ZV?lT$NIh;^*N-p&TfsBw09Z!@2Iad(k~`JPp~SO~iqMm+U() z3(Wn=exX#-LNPUbh@CQ^IFctvaZg!Vd$}gpL_WJmVc&kVn?+e0Lt&vcnI6mic+u0LAdx>L-Kqp0omv3FtoBGmTC(zsK8 zjoC9g$eqDMd&DI^ZR+qH$~XdhI_ihFV>Ds)&G@dzT$1yo_s@9uoi+KNtxIZ?KMer5 z{W6}NKpv3wTXv>w@VkuZu z`+-lOC?im%^b2NAyvdgjaQ7V^@7aChT(kwV#Gzu?s>4IO`LJ*jxQ58 zz&Y)l7z2(hLUi-UKXH+&RTk>1cCpze=)6P}FmICiN%GifCXr;Qhl)LCms0(I9MzZH z8KLO!5gYe%hI01iXNDN5>$dv-3!kTCeoC@pNi#0%+lti8sR0U+d{Ht^2t#GZ%%F?o zix+*PPS0F&p@`|jAYrWLb7%K7o-rW1*ttP2e<20nK7Lk!yY384& z@Kzy{ZJHYHV!4c)k~3p3cAdR5o0^IIOu!Q3%@i?HbG}N^8bpTK+bg2Ogi3jIz4>rw zwcp-Krmu!<6!kn&xkgZU*KH7yEQbm_X%YR_unOQ@3pUMQ@@)pBQOeB z4MOz#20M_87c*(+m7nm~yfB%Xc7WTshc|zltKAacF832aAkLcPCnM>?f}ci~&Z33j zT$Hz~3SE^2Rjh&rI0A9JT>{&Cu50`Vkar3|VJ`CdA&WL=ZCsBmV5H(_xc8`d7OE&V zidEKqd8Bk^^rC-v@?-_sp=q6^pqMh??@b&~+-ZTPwLy*ccZuC#pLW^V*@v?$`TbTR ztSi-e;WWTqOH6hqb2rs_e?5(HA^t?|d|dMXxjXfBq||DEY`Nr_y`5x?HPOR=F{WV6 zlWZFwE!Wt%zPOTPFi}SS#NOxuhF%g8KH1ravSIeNphbySTQPeB!*t@PvwwxV*G|Hx zvliw?!PJT@KTrhVeF@5mOdl2UUy3Q%aIVRd@g6o2Z15nqj{(ql*qoB#rWmkJp<$px zzGja4oy6k8RpbWOP=C3Z2R6Km|M;@FW&Z%=7ubi5yLP%^Pkw09{yCJk1cS*vT|DvY zf6IqJnHB8Im$rz^HO~PDgQm^Um-)qV1nc6utbN-V_JtnYLv;4OGoR_>ePDF)q z^nZ^(u?Ef;bQ#t~PwlmS0{D86w)^n5Si@=*;{zXDECH=G_J6wnapE5VLP{+D_uo>( zY8PygKKD<|=Qj>5b^P=Jz7uer>P`&z(Dx}AS`X3~V!I)$01A6yuuJ^Ja*#G~aQ{`I z5B`75h;DukZ0&#F`f~;Ty&C^ZXk+AGz1M3c^6s-q`+UzZY~NM2QLcYtc z|NeN#XP94t%C|tOvSBbL(R$$-$A2myfM;dZtSlcIVtapm^FNLSY?&2eXMbJ+hoK_m zzb-@J|4-f8YKH_M_n6u$Wqn9?auYw)EhLpcG3`^@HUk^c773nj_Y8RDceQKRG{KK8 zLclM0zh)0eBKM~|6_5OdSSOHZ;J!zvH09uBAO8wC>@qfu>;WG!e0Oj{Q4g6_|E6&% ze&}SeXZI2$V33DzyS}eCC$}mMnVHp9e+X|Pv=4b;cH+1x!ojp2Ly^1vYm=m}M+>+l zRBoVqfKR^LyV_(60AJkxkoD+vIx7&^ 
zZhNyPKsJIXFYiAof}p>{C2#6U$&<&!f@#mTN^{;gJ5Q7L8%b?qSXEANwPFT&c-tZ}a& zJ*IRklQ1o}ZZ?aAurY?#T1 z%jo1O*ZeH^nkiGF#`Ni}t_|NK#$OLvxfJ&s?0mg3!G03ZZP)cp{`)%zKIkd38^@<1 zt(eI+0ADotI|!G=XV{vz3GL0#1xYIS6Nf2Ig;8mhD`glU=8G_-Du;r%>&v>SZl0X{ zPuG&!^P-X_)-#gz_eCvj{FQ?)ket_3%Rzow#~JvPG(GHU+8O08-)#s*u}htc&j77} zHeqMO((SswS45`U(I-4x5l$9@DOH>dQ+*|OtDj)PV#5T^M?lM2(RKu=Jh*v3U^kX3 zmF|2`%$M3Pl)A$Zk*NCY3MfKw7d)&o7G_Sx^4}|!-)WoCHSra|us?g2QuXzK#YKi6 zp`)jQTwn|N=gw(@h8oPmHW#B5j~GvV?hVdDooJa?*tM&K6i0qIjSAQ1fJztUL^E}9 z1AE{KgH)v3!DUgfIXU0(crRXP+KHaMMTn=1EN&xl3DZXxojBS`dRF$dcPOaL zN@I#}1|#9a0)w~tXifCGty5zTW?oq@x0FMlV~R8?fwI?U!D;%TqY4_2>DN;k-Vz;dzDQ#FqdonKmQq-st6wN9)Ek`96JGopmucnc2_Pe9e zTg@=DL*Dwpf-*;PYg?Qfu?6;e52WM?2?|``r0e6pT*rIxAoBU)8?r8_Y)awf)t@z_ z1L@(LY-jzq2S~OppM7GpATVQ#Q79btXzdZh@!iQWbmO&i7qIIH*j6X7v5Z!R!tdFU z4?Sa=$5mHPddz&n&4^sl?J_=Ia`Inwj>eX)eb-h(u>^3T|90;14=2C(_bWY&16_gx z3#QhyapzmUtU4zt=hiscFC@TA?1P(?B!=WqYTP6L92oh4srlNaJH#e!D>!MiQlyV% z6a-r6%64IGDn)3@c;)@qki@sH&jM)AMyxpE{0s520=G`QJYzcMo?8Ami+PTd|M$6j zrzWs#{+yujLF5hcf#eppFrzx99!;#75>r~_*nQk?`HPub=Ap!sjn0jPW@qId{$$Z4 ze%LAL!!%4-LeT^$Kh2gM-NK9b`)t9*@Sx+b|)4L5kjjZGkfM z5`;)W4D(UTpX!~v!A$#D5HJh%TIRF8`{7`i#yQL01XjtzDwL2!rJ0`yewEY9T%CPF zuYf&MKmPfHCF%_<^2->Y+=Zo5^PQgsA4*){UxN)A=Dtli3LYNUL-Pzk+;8~6gk6*H z@XuiN+0kWP0%EcXqYDob>A|VwnJMLu3ES*yHU3VNSS0}$7h#3D3R&Drm^YKXyGur> zTj%p>qt_yaoNpyx4{{vVeNDjKM}Grxxs80+ZZBpsk&-zm)FQa#i;Ev&pXc$2Zml#7 zPD5J!69<{6HwyNn@GJEuxRPQqG`%Hs^09G#f}x`t-8$VmhJP7NbvHm7y;0yv^-k@z znZnH48@e$Z8G+rnLRo`B1V71~@#81npirHVtd*accabl+l5a-U`Sn*$0xv#Cg5W$0 zm*-GUsx|19AHKe458O`FAp|tx`N{1|z$OExf&M2ACAKz9#Wvo;ecy|zS z-GoIQOx=$z0sgub7qr-%+)mp%bu5N}Lw3}Eqe@9;QAK8b4n^ECtbpo{0)a!5;ruN9 zv2Z6;36E9npGMnIfyK#0>r?R{_qs!!6aq@?Ir+WuI*Jg`XUPKwDsQ6CM+-R3DbCmd z?!xn&o(#Rw&F%MG&^+3&KhVQlxLpM9XLq})@La{#$ z4bMg(@sGpYmyES*eH7u^xa)If{w$=YF>nxg#jzKhO2U_Y;viycc_THgdgM;%UN5&3 z1%`(bIMS@5;jSTF{B=mT1S2sUhJ7E;sJW0v*KKAr^5=$ewOR^a>(3=LUk~&A3CORL z+I26!e)3g^r2FW@;oeRZvaHICJm{7R@8LJj9}s@fSOw5_i(|;UbB}$IJt(p|(~PfD zMXkXHLAzz0_IUE9RF5jOpU(M#>2B?OtN$}3w~>1z{QHs`JHbn$y9Sufaj z9Q8Cf3&rp`;!A{@-l4A13|)aeIQ}7mxi9p3B%%oZ{4!$qa>tqWYrP4$;jEQgiwaDc z+J?3!@5@A1ccsUD2-!e;sB)BkkV<#VPm}rFN%)wPE!XKf;;9ovJk7{6ILC!vj^S<^BE-_TD@k>h}E~c6W0pt(LMzX+ievdlHI} zr4WO1+YO;XG)RQ3g|S9tNHjwjAw!!f`##KIimYQvjAh2~Tr+jwpU-#se1E_1^Bm7{ zJm25(y#H~SG^rfW*eWv>*87-yL{P%Z101yD+ zPVkBE;-aBU*w4HZt<-lv0pKUPQIA!9bHiipzdJW-mfnT-L&#H`y=pn$GN(va(@ybr z4!^+^*B_`0EN{RlL-ELPoLazmjm1&4FBIkRf~B{g&_6n|*oe7@?njVaj3!&-mJ&I6 zOwdPDKOZj%CPFvgHidA^7Y3YKrE=Y`={_NzWe$`0w?hjs=J8d|cySWrsv%d>!-<{Xb5ew%$>!DRZ< z_+67KA0>z#>t#aJ7ef0*yptclv|vxf1| zK}=FAwBx+ZoS=zIUSA_GffZYy>(yvNTze+JSW_HY1S(nzn1MSxTe#O_Fq57ET^I?o zoHZYXvYW(ZOsPF+1c@-HHQFMBtFwUDdZ2Udcf98^VhAGBMT(exv)mO>*CS9@*Pn?m zXIJwB=22(-kKdDUwU3W52CJHZIhmeQ>R`VX(qWQQ_52jX&C0v1tK6mT)df0IuImb< zRcX-c!AK+6vw{}fAXGC!IN`1!{LP}0MQDQg?pYX4)7sH_ZMev*LAm4GP)$M|UUpQ& z@=e#kFl8evlwLU&n)8=kf#8{Yx?I(a2RIDM?2d0L+X)}U9Dm1e^0pjvBaS`w!$bff zN+r9f+sJa}WUKud<6fS~OXKp(L%HlMi3KKbT#ZV)mK;s^h?076le`%a$kz;U0f)afiI#*Px6BB%SRyYnHd(1}Ayx4+{ zlrJ)=Gjf?LEPX;ihzWBqSe~&YhIsZC+e2xCLX7oJx96nSphp_Fr4Tk}>puqwH z9FAM6^?&r;R9jWIWHS&g;<}M8|KX0Z+x_`6V{!L#{62>L^Yj&NIq>j-JW(@Y07uqb zE|pEbw%aeMvTMK)-)(2-OlpI)sZKkMp>SOzug?M`#A2#CbWw@!S1S{~_8GW$0iim( z+<_|BP&O`XdAh8TZpl$xz2Eg^JT44R2L)YY>_?>RO`3>_y6BZzobG7SqH2fDquoQ$H8Pd2?i6@kc2Sc{`i- z5Clfm)Tp4Cf`dL>DGosUF4vrPW{V!ztRp8%2X&~)_t~eib1Fse%sFUS%LZ`pDX|G! 
zj<(DvknM7)MU*UV+>?#Mbz55l8nAJV5 zr06?~Doi#Z9PQcaHzP_aniNXObHR z6b*n|-s~D7wo9rq*M~J&l=U$fn0pzIJHD84EMGV$*O}AVi2aDSNGbj^;apd}6^3#a z>JX6Dz|-QXKz0B%@>GM~7>s%u){TdfDHboD1)F+4t4?plZVpjiCIaM2uGqjr@$JOr zObdQo<|Ej=_SnhJrs3jB5xY4u(rb(-S(fss_(x#r`(h`DGz|yY0Y^wyVY!KmC%1#t zuy>yo%KoT+&WG(0C_)}3Nh$ehf9HNjRvxeB*9n;BQ*Uthc|a_toeLjnv}p2}3t$O2 z1x`larO18m_dlMjt^3{tnQ~D=Tm@rstR%UPjURX5{3|b&XilN6>lrn^3#kK6stAO&FbXtl2~rWhTajRyOAzD*~dR^1S3p0iu0<@ z(;QGbK8B?-f%1$|+{bw%oe^D+5r=f&pECN2$mI$4c)aW6vMSr)N7yPqy;iD-qZ0d)!o-*c-?Iqb=D(S1QSNsjsKN0OqWGa5ka{ zqUBr4AAGifYnIGz70hPiYf5n&VY66?n!e&=pURd@;ms;AjGt?4!AaNgOHB# zL|e*wx4nq>^+*Ti@KmyDRiMGS$;a3)IFMSIdy}5_?xZuDRDdp@`&j^DugjSL7%TlG zwF&RfX;l-H{ZA%a(=}?;JUxD=OGj<|*%jxeZ@z00$ITXWx_*){AWPN8|g^tGLlP zvTP!;LEstaSY`=(*Xo|r>HGFPx~$XynGrnukSr0MaNtt4uEd9d+JzIxG87CNl1FiP z%AQY%);_UVrB$jw{q>So>SfMqQzN)*{(}EjyH?YZ>iUNgto_0hgrFHWSpMy2IHUj* z9kid{JJX2u`^|X%o}0wn{VLcaB**jOif#aZC&8lzY7Acj9e~Y9tP|?AU%zBoC{}x5 zoiqOe6$rJPuxAwsfQ`5UB`sccDT*X8#+^;acit8a-vx7X$cm$D{=7t{nqs@cg}%!v zF@=g{Gbx9xC61|7k4_KbQj44It?O1McZFW^Si==4oHA&baxD|$s!sAc%JstjMZtvI zvi3YDp_8r67{J%TQWMO)$}+anOWEIHE>q^aj1IC&RlhVWsxbUijLAHFaXOFn{s_Jgb8j zq&r{%`(L-qXATFQuCc%6by3bE!>H=vfm0u91mSMs-d+!F*{20q^6rCXEow`#eI$R| z6=T9^*GqP`jFu4zH;+Gt{*4^oSF8}-T{c(h;f<*ieH7+QupxH^isi#B!1m>B6s5(4NiBzOQQ#Gu-v)RdVD z1lha@I`{EM>=mn{Q6J|Ph~8Ny3MUthNG`oP!L%^?Td!Gl+>9}p;3ReJBK{jR2QL{G z&1<9SEhOxM=jA2T!bVoZ!DsomJ0n4+WYlYTh{OYOsfJ@J{H9#8#1f5bKp7o=+mbO(kRP|f;J3M4lg5x zTZuJRo~$Z!82HIA_1=}`emPmv&qO|8mcCBIIX={{&v;oYyyzUr1qe-|qO`uw8rLfE zb}U`qwW6bubRlPf+!Y0u*L+iEsWe;8O%LLbXo3ar#W=B+M6@UsOrWxnglOm8b#v{0nBh~O2E0b9dPH!91Mqz_&cI4ZydHHeU9)fU zylw=n?~~piE6ckNzT0ZH0AO1??yt$n*zc|VIs9`7iFW`gcC{MLVRAGQb+4?0_UFTn z##&8@OjSyY!jyU^1Z0@O;Vp95ti%Qw@?3+yA1m)hz=(O<{cjT!MCxY(klV$h5vPFi zo9m5y#Pb83r{kY~O3S(O{%ra?DM0bVIA-^@zlGe-0<3@+UgS!ZJbAnyiacLvDWlso z98^363|;!L_FgCG!!M@)1w}#q!#1C8k41jN zgx=;JF_yJQAg6T|8S#;k>GwQ{DCKdrohx@kdL4+&YO}gNnB$T|Q_aZnTM5X=Qh5*5 zJ#Cn5pGJC|wp#!R-e!#c#pgUHjlLE$vX!Y{{`8i2wj}^09bWlLUGlj4gaGl6lHSYD zVpQA2+uS_JZXM9A`34?1>D0{hmqRpXkIl`QfJJjB4+e_hjI^3o?XbS_Jw*MLJ&K>% zk`bC@Tj1K*2j2w_md`w4>UB-x*>9a!*9igksPF>|w`Zob0nio-8LIdjtW7k}{){mW zWvo3uzy{z@*N>2h&Ixu!^ki1Pgu3@jJ@dBYIBnXvoKOc)x3;H@70gX50rXcjaV7{k zIsL5bH>pI%fSBk$f@eXhc_QxT%ILm1<(_@Dy35-a>|cnp@l%<2dEB|iVbfaY#R+=} z)z@Q&H@99W48mrTC)`8w?=S(>U}89F7*(a}f4O*Z)l=TtE(4T@4)HhEd_Odv_atP3 zKj5fNL#K@TB-d&V-0k5)w@sL^>IIO}9h`LK-nYDWM~zL~$R%3Fh#pc}^$>vo>B6T?eMMkj@Nho9SDD_ zP|*_h=HMOD9|%zXV5q*>HqWRn&d1JU&#+qb^cWWG-+ye%pcHEIwyDKlh^$rB8wqMsT0-Nh+r6b&}s>n01+K#ZgtE9);5KwaC; zC*KgudDl^w{ov=f8PIQfVEd^V3eE!zKiE&yxV{Tovwko#W|37=b!q9p(<2NZ{mWHM zNPpd_85{uW`)PuS((%UusV88faGNxJNM>oa-}S8SY1-4>&f?LK?0*Fp--jkHhY;3Q!r_2Prh}DHS@uP#PLnN$TLVMJJt_@oDXi=QdbKsnuL2vw1>81jC}os zmm{!kACX!34gllx&u|?UI`L)fwh6MHENCN5K^EIOICFa+$ ze5G2f!W4=E8^Ks!=QNe+xD3he5pa?KM&L=D z)bWBED}x)4CMDTYLE__wV+Hh~(^PlZs6nYCZoN>bC+NgUf-C>(>trKJ1@LwpB%0L~ zHx73}`%s*(J7%d=^Mgpdx58fUQcaL3tL{tTKULyg!jrH_Tt|45TIysJuB_QcYJ<*j z;{{=eKG${-q=O?7p^ROf9dfUQ1;~=TH@P(PuAxkxX$8vPmtILDi&buIF$o5Jm0Ow- zU@|yY3c)R%6fejfdv}QQdF<13-Y`)qPEtHh|5Fi<_)KAG?(%iWJK@sY(uB%A!qSd1 znf)$xi!*M}$SggY3HhrM0419<#Mi7 zhDn!nwEKjRxd;>`AR$(Sz*Y!PxJ$)9ji_+ z0*i6zgvD^Js{V;|vRI9q(I{;yH!gTOqXDG_;%&Ezd5)KA>-PK!Ok9rbqdvZ?XUshN zI(~Qy<2|xV7a`do*CoW)Iohyi7QIT-u3{bE!oHjDqQuD zL*7J^pr{h+Rc$QssqHp3iQh6Y9~s)TI0AQL9TdUw(hC)7RgC^JaPH?bnrAYN6yQ+L zY<{V@W`iizn*fbW=*9c9CL?62SNh7+N}De5i9>S=o}OIH8i8BRc@eo>;zNs}$0Qm2 zIHG)+RbIU^AQxWiEjH~Xqt=lY;n(pl-j@m5Vl{oPPZZM`K<-PQc02*Gp=P#L5AoEo zL}!iHd6mbRaoC=K7gL$BIuX&+6{oK7d?+)#O+4J%(19-n@&K(lE9TKAAfBn=1f*leo)$~ zU{!4?@#fP8La|ey|ACZl)1~lvv0YW5#hO*^&?t?Tv8areM|%(HH`KhGVBbk|)iz$M 
z?C6E*f0Cd=`uv-Z3tOt#IlgMUW%i4RU=ZseT$pEk{7ZjO!$^2f41S3Ww?BWEPQO_W zUD+m;ELXK9wCK4qo>AXwSn8aA0@QtH7bEI$x+Z!--sPB7^o0`}Jt`lLbO~yl98`JJ z8YVypw}?1%As;vxH;-V7`Oo0ETps#^BczGcS0+6f)4@yncUlfyY@z6DVL1OAxNPT+ zGxet7^DO}S2E6EaiwNBU()g`<5nP@{!0#l7QpOAVC-XF;RmUrrLFsBYCaAS4rL4)r?~ev}5I zf9+oA+Za}UGuod_=mztgEZLaBA|wjBZDBaj^cnAcw_s5jnMI~%P z1u*U019IOvaYlyp9n3orhmq*2r%Z@Xz<0jznnmBJM=W*iq&J0GV0auBuIF{Ez~8If z)vH-gD3~T{j;(Dfi|o`mhEkzk?su@;WGsph+S0uOE8Xc2A5&;=%s1Qg=?vf)oLzpJ-`1v83t)hol_V9 z1V}2JYN?)tPoi#-!>4a|moJdToYTb8Pvte;3HQ;)dTl!V811UItcJT==D|wuE^moD z;##6LGP#wT!n43n$dkp3Ji}TNqm|K3sM|u*9qD2%v2x$GZma~)()}h5 zf#-AJd{l_hZ!PST7oLl;h4y0j$1aH;OWV6=!R|Yfr}Lp#GWik!!S+$~8~0^Z`bq&p zNc3$e8S3D<4l|i{72xcKa<9fKmy;VzprPVbYJru$6K*-70iU>5$*Z^c-21X{YMG={mhmBX7@rB?iV^Wv4*!+Rgw8{yi(gT5oflyfdD^ zU4=%-qR8-|(d~hy+B=Q6g}$Z-huJ%B38F=I4B$0)wx|61%|`EOlFHM}PlU3bvb4Funz6YxC!jDfS=9@o||9ye# ztRr9MX(=x91spK%>Z1yVJHys8bm~Ih)@ss4rb>3SLcZ<|J-X9_f8$RH-=H_SQC_Mv zdQ!MS-R83JA4FVPKCMT@j=rd~MMw?r^Si4f%$vOPerL^S2~Q3fvBGXpRhqU!vDdN1 zD{9)g6G0;|{H2}eUDIJRld21 z`9Z&(kFVv{z^I^;t?Pai{V?s(o~2BL7^6i7m+UfaCqg$<)9$%gdIi8<_02aX1f`2Fw|+WC}7>LcHUCuIzZOiVK> zMX6dr-R`wr(k;1?4VB#`#~GE?@EQ;_osM@Z;10QPBqRaedJ+g0Lv(xiV=G9Rx?kLh zuk-up9Rq7FTi*3Lxa_O9pKK2$f(Z8qc5Rdia!uC$agc%LGQ`q`kxvf}`)-AY8P@D> zSVyBmm*;d|legw_!f1}D?Ql?`b}nOSC#GzA8`<>bZ=iVm1ZkVBQyDU8jeSGKhOZ`H zlchqMyeLgu-^@_hB@(@1z%lqosd;C3EG{Gb8~SBZQ$e;T58ps7Jt^FB;n9k7e#mox zXd8p~0gm!dw!%acMS($)GTRwJqqir~J7se-)jHH^>@$L3*ACrx+q0(jC~1XCYM>7K z(u(KYg^QoB*f*N`(Pe4QBy9a+Lw+q)WalmVs-wq|w@-8^*1k6MM>?S)%5SRhIzbc) zX``l0b}QKth5b=Z#!hdFk(Vyb6eiFY8@qU(^Oc$sNZS1c#ZdqLSGwCt&Uw0Dh!KiFEFWJ#c z!e^Iy;&4-@8_A$L8xGMs*5=C5o>C<>}6NOf$ z^$s8^p15WfIc0?{3P9I-d$*-hNHtswCx{r`Hd4zHuMJH?IZt0 z3fL)iq>s#T=WlJy+2w6#O*RFG(Avy=W$zA1mxXSZHL0|29;EPPw(NJyd$IOOEnvA5 zh#072Z zzbh<}rc_O#OSeI`L{h!>Qj<}ZI6_{?%ckDqOn+1^T0NDQfSU-w>M24<#JtXeFzb3Z?wUr&oM#aNtE z1P(qamK$Z_2V|M=o_pgjhnQ8P5D^eC07z-gK zf?sH7)Y@UWl_GjFWcgV#QydulJ3g_lb% zg4Uk6u<15OTR0;u0npO*s`c&G&(}C&PlKRqeJl_PV`W(ZBCY`Xcby6S|0*k-A9aD* z-XP|PJk4Q$cs)!w0?y!Sut-_3>RGBBc>=8i+QuM1kk}&VfdBXNZ9tp&57dIob6;abw1*iL#`Awx()4Cjr(#^G zcPM3!^nnFKP(#{mTT5MCeNPf2;$$I$h>G#I`J|CToWU(CCuWN=jN(e*is;clq`0f# z)H1v(f${O?RR%2rTtgBCOn7qUUj>Mt(LRW2n=jPx7W7y*Q3uX5j!szQ&i+CwyG}4# zwkMI}=Usw--Gy(IQ3Bc%En?gU(9tZ-|9HJag>?7T{!@y~w!Qx~W^!?k8_|w@n(y6j zdH`AtsHgI)BzMH;3(W!T&ChKaCvq%L8=g&+T&0E1&MXEnc&3wET0f^>%+x>QbTd=g z?-+aukm5XlR8kelp#&~A-Szs3?)(UK6M4S&&#-Q9eO3+U?MQhy0XP^)Xfd<%?#!Ab zpxxK>Un}f-2AUi|N#7O@+xVK={k)ku@ZW47FjiHBK7g~$m$2XIyx|d&4@#!6upE-o zV-jjuPQ*J+4;lA~96g%l!5s8KMnxrXRRnM@ZbgGGo9jK>#T?J?`~DiwuTwd`+Ixbt zU6|i!0f(R93(fvp$eoqXN50DCbYPn7qBLNI1I7dt}>WP5s%hS=j zD;;G6!@gxnjQ^eP{Qdc*5PZY6rtf>3Dz4L)e!glW^Q#jFOxsTj9PbHgqKyt%Grw2& z*Z24&P6`k|GeaAeTP%OQ%6=XgML;Dx^YaFXiKeY8G2-%aWrr(q0y=+gnlYPFS2*Gr z!*}cXudiiB$B2XlQ|7?5{4<_j(EE60DQvx=49&Pz^9@GP8E`pGi*~C8UhIF&h^AQn z9_PT*=L-!D4a-F6Jt8l$XD>^s$)03R$&z1FvI6Gud9(6-AtILn3y#qI)iuoN=nK+n z%;!sF+5^TU`B0wk9DU(_3UlnM|9ZKo#le3Y0>)`>W5ebm^KoAPt5IhyrPM?e(wMVh z;qOIIsrT2}8H<^M$|e3SrgBd6&OT&%Crcx)c^fVH!;*j$hN} zvS{Q}m_kBv%JWlci`>%bIcgy_J1@ngwa9f^?w-Q<*#nB{2$q`i0F+Cx>>n)FzuO`-F#5Z=2lz& zrV-^y;P|p$8(67-SsY*&y7Vuu0H@5G!f>|cR+7VJfq~w*PHc9I(oo3{2UHQY%@9G( zeTle(5?~SZ0sh9%v5U|;qT9~yG4)|R`&~9Cd*Us(CmSJs`wINF*&j@$weJ3hp8|d* zS&Kn~br@r^4It8yPwSq#TW~}^Pz4JxTR`A+C z-eyS^JW;{Ha!%)$nFCu)GgFhzVR$%0yOuznCS>TK0h`u|_{-igew@K{dC~aBt`lOt z0C{P`z6vNZO@y-bdJPAwg-JNWuLB0mo9VeSJ2zj+9*TIE%Xu2g5rx=Q;tJg_jo$Z5a$41dpEr?-rTxQ?G#r*F&wB+$a=BQyw9dxsIR%xNqDm>z;* z%^6pwNL#lEJ0QR{-vtkROVH;! 
zt(Giq7l0ra1GN|<&zFFYnjTqr9E(@fbC(ep7$=)3M73-N z==NMW!{D7WuAe+(jn$uI8Tevf^0iRw_+#MX*=dna3rzT>F8y@HSXgL`5XLDcJ2LMp zgDvroiSYAI^RkQ$OCXWc=GRPQ31oQ6PRWfM4L(4s3CrLVuTks19!8h`?%Mud*ssK zB4HKkFL-ski_iXS&R^}tiMH4Ae@gi7n)B2>7Wcs!@5Y`ujq|g-q8NvRcFGZuj6+$* z)qMBpztoEXZv(QL>WWp3q3XYge!)j<9{xj%F6(D3z_>BXk$tz{g2{0|0jG_#w3r>`@cnD{oCe&Tl|kBum8W>`2U23$}p{8x^wG- z-GDgEQODO-3%^xQ{4XL*@Rk36X#btR{@)+k2Ffxm-$o07*U{QLGZ_qx3s}&bnQx>j z**l$c^Tzq1uMRec8Bc^g7o<4`Ppnwuv0HbL-p-Z^8>)BOtZ!m-Re=s#Jj$Si#( z%{(bfe^^-|9+736rpI|Nn|<)M(nqab0XJRE5zm)0Nq;8fpK6iG;jhTzR*267w2NO& z7<#4sHkbViOGP-7yyn7z(XHop;mg}7kP2^7^#-Me3(MGUbx zYVWCm;4>9up&@yAAUnU`Q`|=T&K~?DdslXb4(-nxyu&PA{9Ef`9OE$JQvTWBR{j1X z1>CvM*&4;%eSz=7Vo=d{m1(X%s~HjrUHrQI9R@U30pmqnU(;s(0LXp#dx9-LgN ztVC1*LRRIyJ;%d$1))TW2!IMHOug!wYuQG@>pdRyqI7c7J9&Q_P!4~3g{meY)Y=u_ zj{gstBLIyymB~CI&QW16--7NsAaJ7{^>w;O>u!w%&T8AJzC#PumVhL@TR0MlFc@Me>5}QKBdG%?i#-w|jhK?SU6@c6JF=Sxo+?MF+z^(ha zYKAe8X;|St@N~a@E>Fu3Uo{@E&R+(k<@vhkXDTVM$$JMVO<<`kl=* zxon}4E3gMu7y6>Ksd0gtd^NNL>eb$HiE?yGJ#4Ypf;J&Ex%lWc;-|ZPWI<`LOT{^F zZ%ZI?csJPl!STb9BU*=;^S(k8+whDkHLVZQbm6V5LyrEUUfNicb^-7ZtPEPBCWc(^aXULnOK*YG5e35b&bFP5&wZCXH=ZzFBC_>~C z`|3i3hm=QP?a9tjCEz{RQxTGlt{KN}aG?de66+H&**W5UDS^rAanpu7Sw?v6F`nxW zSITSkGgi>?>O^n+Q1UhxiR_OJ_D;J&2TZj)Q)8&oJVNZD$y|WeyJ z@R_&(WW<$Rg(TXXC8_{jhiXf>N9?ZINxQJaU)F(#Njnn{S*uU;!47pLuM`FBpOu}p z*Snyxud(>g7l8}@jx~E&mU?9XX=%Q|Innlb5LDjTUl4GF($2q!xaCRL1QE3bmnQ2% z76lyGW|$`!jFt>(K=gb#Z+(u|t8YJqzn0N#nacj=10{(w;$cAblP5laG!>8oN}cR@ z0P@~Jc8OgG+SM}9lc@>5Ikm5KQ@ zG9>ZVN7q9qu}5?fioe@P>E`5-tR7?@l@)su1j()}m-w@fLICuMv`9&^Lf%Iu!{O}K zfnrDH=vZa4Mmc`7-a_kk-PB_Y3^G@CDOVEZrEus7{8ee#GY95IXKss1##5G3TGPvG zgBGdA>pjKwq2kr>?P3}DpO6f7z%U#uOgfJ`U|gRjfiEEwYD$xg);$+tgt5KO!;>3o zlQiGO5ojwcmhNY2M8Elo-ZMp?WA3q`hNU|FTEc-}|IUE@C7s1gG@F6kHdBP+UEw@8 zsFz6$w7Ct5vb(vtU_T__^3=&@Ps#flH;g6<#!q;lXV9Kv%Zrg!6coEw)RP zGgiQu`0C^OM5UUCYg2CKC7J1y?1ST!XAXt?>gq3zSLWC)s);F#4QOg(l#Ps$CD?w0 zYTeQ>VpJT#Xz~oid9DdwtrqTHKUDqQ4iLkz3k>QTxr$$*-pOI%i1-zxS)6#?E1zF23`VKvJ_H6Zr+eT$rv%hY9UPX1BYq`an5fKd%`p|&=yqNc%kziei4UaHw;Fwcy#G|5&Yu9f+SIEh0@jaY zbX3e*%^}#~s?@15NBo|Ly8?dY)Bzp}IpY5fzJGt)x5yKEuC>P0sD$&SNiDg{w?)!` zt}rn}R8<4kV21RyDdj;3((alV!2w~1d;Csa;8Ct~`~1da?vde@V9Bm;X;0W+_ir?n zLgzmNod%?zr)tG~OVnJ2Jg}$FKm5DLL89LBin)py`!BcLGu>kA26|B7r;e(ZIY=cs zx+NUK3lY&g2=cSr0Dvr;<{B<^JqRg>PkBliI=+tNQxBV3e6L-wFv7LUq_i~lCVX)I zL-AQBDUgl}NnEMJ0zr(woUZ1vEa#z{NpG>;ZUBu}HQjMlX6rC0uV`HLV~?Ln$ zWVJ5HXo^Xep*(Ilc%mK8qdu0lS7>d&GC7%Oe*JOyr+qCs<1nLw{?p4N?Ad_}87$5W zd6X&UiSyr5JgEy)yg~y$Kt$MHQ#gWdrGmw|dhdQm?TUsu>r?N@3v&~d(_1kPGQeMW z%r4F##{5`=4v(lkI8T*H71qsw@2yKi!4fL0WIibUDM8iSb9k+>^CnAhQN3UQrq#*d zZ5Sa*rH=NUc`TfLST8?145_ntt=i59(KW;COm9D=+tUk*WL}?jdwqHxyZNXwRF-WG z9Nku2B{#Y208Q*`IVPfk9YczOk~+d_@ncZWMatkQ?*E33;oph zT`PyyZ;sLh1f_I2ZW!b{?9^45GCN_XDK=JH+>7xhPa3)W&;}C)1ZPPjXgrNcN>}KW`772d>iKNKAqxS^x&T1FYMy_c!LPPD>CW2 zmQ!yM!w2e$z+ZJN2h3Qk6;!jrwcHb79wO7+0?}@aj8&LY^($=6BVobyjOM*8mht&tE!eBj!K8=l zWtF6|2Wirc8K(Syf~=HNgXkNc8^!Da7Y>PI>YrY`D~++Yad*@oLY|yE^-dvD#x)@~ zRsQ!|bpfzFBEl~q3HhAd^FiP(7h~OP zIwj!jY97rFz3|Dy<7=0;D2pZ6*DUt>3-Au5d5bgnOkI@F@u$ncZ=>MkF{-?y^C4S8 zudG7-%Eeg^Pyi3FUN%})<@a28`keNdm@)lC$Jy&#)3THk5tSfqn;Lh^cIy>)Rbsb9`>4=g6+Pw~ShX5ffCFb0YAskcYP96BEC6+tFQbl&n%5GIM zhy~|l5esqrjiNLCr?2n;t?Jk!=_KRdQsDU0-(JbM@+5RWx9zogaf-AAZ-;KRKS4`; z(q_=6{GZb+`=<*{{oj5?`O^!@ULQhd8Btae>Ii-u&dn@!Qoj)$Otj4!=bc|}TZ{!F0)19=p3}9)ShB(_rK^sxAQOPI1V|IuSGrs5jNtrCsUgvmj}^THte9^kschpFlh2 zQ+(zr$dl-}ijy6gV77hf?xtkflBOi~h}~VCXhRhxwm%}=KG0J!uk)2_H$q0tOcaVv zxf=ccv}_{_wK!eJ)6!1A{8}a`309eVmlug#e&)Tr;#q3XP(Gj4VnkKIo5NZ>m3KW5 z?_?O2JtfRX02pp|nEpmtBKu0FOB&Qj;hJe(2fX+*G<#x$A8oBdbWJd<`%ba~^;y=Q 
zNMQ8llTCo&&;ZHkw}}}ycv>qW=OYhOz}VfW$=z%`V+N=-IIadAZi+Q4z!Y`e0cx6{ z!RpJD*ZHFJTn*-Bb#o{Z>+k?eUhSacszP%^0rTu=ze$L$ZZM8|)%fT(VxGq#xylD-COoV0M zLQo<7pZn+DeS)NoY0&-H*yAjL$N2dz^`SWl2#mDXQ_~& zOR0Qwdmo(N+GtplH5mqx_%y8aLH!zYyigXg@{y)#`r;*k#PoQzLMlL5C*a=O`ji_o zxxDwzz5Q}04s)ght2X#!<&sRA`t8~9D(>GK09rkv&_ivmGP00dzx6|SMr@H-?1Y(| z`EAB|K7T(xQGNy1Cvv&}@BH`>FLJoy=Ioz@c&BUqwjP)o_R;y9E!7r7lC`V1t&uD2 zIf=f4P@`e(g#5Z?_W)3CBHu$$vtOjP1Tgh1-ZfJ9G5c9jJULwg@IFNDxAVRJO@2G! zK&A?v1Wuk=6+lXa?|ua7bz$dC<+m12w@Z&_j`R+C=VOMF;Cwnn8U1`nWLOlrTLfo$ zJxn`uztB^Lh|kdPOF+^&NJ-b+<^i~4L3NZYm_~qjGkIPUim(zxD+-cOk(PBVWShy) zY?-Ub*|hVYFZ+tLbSWG6OuJrO_?v*#YvUCc!YowagtXQ7h0}mi9Gx2!9{WM4WRz3* zsX#{LYs(FMQCgRbSXZmDXxd{jl;-rzquyR|i0|x#o z%(2|jY!tRYtVz@jw|~&v^;&H^7?@}$HCCkual;$_Z%cd2Z&?CyDpi9!ftA(Tlgb09 z8x%lo0$$O}|F&aYZ1+pft{&VkrF4+t1<74x%$e6KrB(bBiMCd-n9R zv5k8z(}y~nz80tAZ~?UI2R*-a9)6p+;Q)x=rR9tr+ujfb+(9lE-%u>}?IWwj>H$iY z)e@usITp4S@QhY&Qv3?i%S+v{6jgr_fBgmi5c-LO4#rB>Iqilg5^(sc_NiQgADa`M zJ2}u(ObIcD-*^!czv6>&V~l61G;K_ltW2{U^@^#-=eN2(=}-5_ncM@N^Ju*6oG#gE zUJG5rSz)a(9Y2uh1>P%xk>w`d`X!XULUq>c~XW>QMpjZ_0wCM(_n3bIVx z!7WuVVh{Bra8=#zVrU<;G`9s-p3v4Wk}NU)IEr6+OLXWQ7MFyStlF zU-sf6?|8v#{JPphxpPq(MtCc(t_AL>(klM?_!WbSXYsF-tcqcrJeQ6mWX6{vzjMP1 z?a4468nB#f-pXC+b1&L9Eg%$V1LeDkV;zurDjo*Ib>0lyBAcqwb_ ztB;}q|2j6LZT81*W9HVemc$N_Za46APkOB zv#xpkVYfa32xeRDl+=S`4m+E-NT)gg+il)|bF+6dW|d;Up@ffNX#s{}^%^xA)TFYZ zex&6VMHoul>b}PDu6qCBZy)YhIGqH#QLRV1Ox^-ES?p*$GS3|s0Am239nP0hvhXuVbo ziXV6u6xX!yH3cv#X*j)t+R{>$W{T%N_ovB*;a z`L@+S=V{j&3B^ib&DeJV+I-)Xnp`RUMlI2fHXS zLg&SdxzH6VwbQ-!4z{~^Vf>Cs!6O)Q&2CBW(sk{+(xiqlLu7r}M*KwMXof?H$!N#y zm@(l&sq#`rH;8acY09w24xuacg<9E*!zj2qLqf9dovTL8(Tw`)t4#GL8A@8i(u7bB zkopQh$xc&0VHRhV1JycBY!B82)JotB78QhP0M8nFV(Ee_h(5-EnJN zTP46!N3VdH+I&Kg54iO5*vGy4FrRJmHML0M z=|FK3Axebshivv-L&u%R$}V0Uo45B=KoQb}Gv`%Jx{{~Ucj-ZXKaB)w**qbxhgqN= zO*ZvB@oY3G5#E=KfuUjhyOk4-q3>=(38m1wY(lBY9gw^AQ1FTA-7E|{zeufxnblU> zBA`J3jp{>mfEm%NN*#r7hUcbQ5ckh=7nGVjUR5`lu*X;Of?T>55CvdD2H4|YnHl|- zCLeh$hY-cmqKG^9mIsDSJO;6RF4ijC){WZjIrmNjTD_;&gmZ4zE*V*WiSqaY{KJnv zlN&(h>D`D|vfmD7kN?N=xEypJX0##A4-{md$BU1W8#v+-t<^F$(4qOJ<9G~&>}mJ) z4K#Z8@vpI7gNGz2tAn8~F2*VqD^rHo9PE7ix~Jgs`~=($@xSHU@Ac?^cV09yMwEOp z%5j{>Ig3_U)=#mw?mMh|zxaz@+Nsvj;%y=;?>KGwb*{4RLuJBXkbQWm<|TE<)Z+Sw zwu1yCDEG86bl|quabfE-hW*QYbT@A(N}8b~5_LB_@;bDQ2Z&4QxCi7%E5mvX`m}m^ z2==@;N+0)k<$R!IwDw=K)Tog*ySR{Q$uSZ*V1;^~kAZvPdtvjQxouiH9Zm#8k6c+u z4nuxr*&E){uWdw!L8?s>0R94O2yFna99Ll)6(%YUB7@H8U{RrT?i)zB1;^{6kV|$# zDr`4JGZlCDOmWV9jw8pCZGBgY1I@Xt@15IwTwhAE?G?Q& zSA?zqW5067uCacGxD{DcD96nK6}?Q_=nBxcLA5h z*j8J7tXmO^VM#DhkZxXfS5-itGRufzZjk9xa`P*2!mwM6_rR>+HRGFEbGb5ARwy#J zqIfis8YNDBFqysja>*jrZGHtNuXyP3Ud;npO1?gfT49l?=EcycyFfbyfhs$>ID8yP zL)rd#(C@TFmqg`Eu9d75VL)=8OK4IEBy5p4k3hValc3qTG3$tldHEO z)t9P0MD(Z(_-TWJ0sEg?^X44;I^|g{rx5}8tD)~ddfyA!XKs*Xq`)7@xgaZEAL@6| zKm*7a#%qhVr`MWIOjCb+fwlNJ(<>V)cdp#H?7T{W^~m0w-9%ze$(aH1fthSq%m~ur zlk-Iz=?mh=cs1wMjVG_vHD7Yn`(4{vK!KghXxb|~-T!i()0gg)Zu@qilLq&7&mxF- zd>IxSYLrnqph78uwh!U2qMo#fo0UGHu^vjr>XMLx>OiUimcL$TpYo;@REur6`XWXi z_ndEU4kELQpXSCW@rjXqC%{RV5txhnVhR~LvFj8~^O^mNZ zwYugDmq>CSZCnn(cZoYkJ3hwVNXX9J$tkHj3b~MxyW0}9*n5aD(lt3`ZADhx@l!jB^)@##4EJ^E+^p=tv zt_djM1Zs(_`p-=D+TC*`&3f0}j?NhUyFwh=ZPEJc@*fV$o&JaQzR&qvLN1b3A~T;o z?*b`aI#Pyn#uk)1sZ0kPEg-O!vG;B?lvQ6(=Fhaqiif;F1{jd&)KNG&QO=n7b@;zCfYn zAWHynrG?(G^Cw6IVeM6;T^>3aM-_<%bN%*gXf9F)WPgmleuVcvIYi2no z2MuHh;aUXdB{L5=8R?-*Z!0LhgeK$22I^V0LQm+x3J}3Sz4hDbPHK>?+%sL(uSI;o9 zd(Tf(zt>wdorx|4xj>0%)e~QI3^E%%rh0oVetu#mK-52pS6j-eI@!8`E$; z^Wq)^rF|YgJalgKj{qLs6upB~pi1lSgW>ESL7?!D_ZlsP=A1>TB?QRlI6qOow~RIJ zt`g+zm9FtZ$=<%op`8J87}*G9GFd!2i2$ULC#sglw?*nFJf#7lIH_n$8B 
zF?eGT#iCpb-ya|?5!K?sxgE=LH3uuAkohp3vN z;n-)f>MX0P43c|YrAUxmdCVjlT%HE>rONvC4od#IZKQAn zLlNNeyY@@gg3ei1Q$Iz7qL8AufB{K1a}~BbsDtzKUT%b&XWxVVlsU>Yr&NOijBrGu zEZzI;y5Bg+#!1-l-)pZVl&|IuPAe6nr~j|^t~08sZCgjwb5N|PAkBszK$^gzHwDqB z2&gD6K~SWIP?ajhic&3rN-qK;%@C3hDWNE#BTDax6oC)|B%!3dwS(u~JMMV*jPc%| z_kOtlp*yS0J=a`w%{Axu@nqVuPI0sNKO2LpCUECq#>Hf}QI~bthXuu2Vc-94hFrDZ zjX$ynyPP z6u%R0pC=EiKK}C!+fa%8NqK0M?u&D)O=qRuGKm)Y(io@IQPB!vz8p;Onzh15Dc%Z- zLX^?PR8?-1hmGu59Y>ZCE@9~mqzD9e+w7bK@1|AAnWv&wOWCE?S*ttjXq;>xvL-!l zZt(Er!=%4u``ri&U3qLj!L?s3gxW2L1vv=&d)0p>bXZy}o#c&Bd}z%&QcjVl8{Wn0 zq<>O~5?h!OPIG>aoKd;0%AF9A{N5>7CaBQeNmp|4ML%!X_dgMPC0JV8tyQpgmA1a&Gi#v92u~TC-3A zmIk;x5N|{@bmio|7xNgm{mCPU+{V&i11K3v!GYQbS^b?$W>#(*H6m0#KZ{~S)lnH` z-pd~xntbFA#;jM&*=zuLEr7ZBG9bCtS|^z8zG60YL~zIP z&2JgK6EZ=G=P?xke6^|NOlgbl&BAS_VOrine+qu!E89wIMd7~cV|I=q60@efS@;=% zxAgV6tnNWCRE7JJIMmmxAwKszZ1as73NM@iJZYsO2xoIDsfFxaWspQVw|6xeh)0a%7L^`nDL z1a?171Kh=&X2sQ2f@uDmzW$VSlt!J!4yfM}!MmS%AJx$T)56|en?r>`ueL_lBP=f8 zvhfzIg~1eFojZN%irnFFMB3@k8>c@7U9u8TE7?K?w0Y7k7x4whvH+#8u!ZbHO>aA& zTrH#Pqq(t-N57jqpOMHNj^>7xlhx&Pb>A2nQjl4?W>BBN0!jgGxc^Ch-VhV}_gS{cb+p-G7Zoi^s%MzGMp^Gq2rgcGez5NZhz_);xnA zFQ@sumPNxRD~6c7VNkBKK+l&*-?FW>aedBmA&)(r0C#TOT$V=A#y}9q$f!I&N=M1) zrUtri4H`N+GJTgl~I6hwpIIR zNqnQEOz|D-`{&f#%`@7zle@vivSOtme44o1&&4$5OT`r~s$-S#qfNwb{s2+J{M;t8 zxBb)}2=~Y@Is575T&?sL~ z4rgg+Y&RV3u2kxS9CV=g`+ zOPSZcF~Q+6CW1s-)JxPrK>#`ED8IqWHZCG_r~m}Qv|pn)Kc+GxIe{mDs(~^fWLz8b zzvPI~=byx9QJ^mH-+*4_Oz=vIUWUG&D|JAjR(xjs24Gqh zs#_e>O7C#Tsl(luMkTv@4!7U-E@OOorR-xNS%^l7o#`bBD~;U_i`KV`b8PmB_6hfj z>NU~jp&ng67LlMKs925 zMePAYspO$T!cxHX#nL&ZjrY>5cmlwdK?GJ7`TO*^2LOna!Iernl8DooW;6ZVeo&1_ zx5s}hFk8kxEMf})&P~S7qqhORNi>^3NIiuZT*>?LBjJ3TO~{?cEE;|A&Yd4>_Y^M0 z_k6=`|UdYA15&pUO;DNiS?xql)O0 zQ@U;bZ2ax%HX_?3_%>~l^G!+V{U9}!)9@mMNS;tlX7=L%C8{WmLE!udu?XTEEFXQh z%Ew^aa9SIPA!8CLTY^}5lf>ok!{}?ox#b+rgh^7Y#Gg^WbCy#%L^Ju|Ug_I>H#vaX;j08k!$(w9b4Rq;@1*G3!st}L_Ff>pWjkG;1qxsA_p`?A3c9&mDpQj_ zBUVFo)R}+wBFHx7fx$ne^8vm2nQ7Y^yK}y8w}8D&+Zut052)4;;oRgz;F(noL~aN7 z?MT(=vG;jnx{K|pKx?AiIrJAH6c)_}e0No5ZYL-SP_mooDe~nd*|6o6hKO-J$kRvyh7pA*V8@r(-H9@~NV_i=Cav58MVo6ZzK`$A`7l z4=mDZ(nU8e9s?{N;EHdZ1_n4WP?;g44=%uLV}6!EL3eeD*5-j`AU4fI{&f z9j+Pc(}ej$Id4_728aHr0IKfxEcfBj3{$o~g+O%8`X=VhtOqfl_64^{Y@_lzkURX4|eF6Ah>e)zZ1mV98opNCaWxmvcoZ~c9E32%@G>W zK*HI!p@XeE!qy~Gw>!duZniYjS{BV8L0Vr}jf|-jX-$(wcIC@78tv)>#uXu$E@mhh z!Go+%cx^Zb>SjHxzqc+mCMfrFtB(JVC;9(tRUS%Ly;dQj&8CEK9G7>%M$CZRU{w#7 zY4RZo!hJ#W^u$_UNMKd9x#Q-ZT5syx8r5z@2CA_m4UhLGuf2{8G8PRS1h3#S0cOXO zsLd*#S~s2c9PXo9&^?$!FU^WI*@H1L+9?eLeLbk5(6IS14!dY{)?>NzYwAwNehqe{ zE{n>5EE{4cFhwKz#`-1N=bw{OJw->jTjkT?F6oeziJoERoa#SbJ4N38F1b~d_(t|F z)Gn;2)q-6{YfQ~i#~FaIYw4NB3|N&VS?pc1_gd+$pORwfj#%^K_IR~!6U+nJCQAbeBDT3bu^RF_Ux>f6vT;+7j2d2oz1X6X+myr`6ml={y@fSTC zdcOHCaG|)W=o1s6hKq(1QvQnx`S{1k({&mQ%c6!d(nmGiai=;;u?pYtw;>*wKeWsQ z?ax`dH6053kcD|k`lysB^2*?3naX}m>`|@!C&P~9GRabq-~$WQ@h<*OC^IJeqk^x5 zKZPeIt^npjj%X_mV{rXe(b`7XfC^ATTVL;cuL<@9?V=WwBRzeq?3V4#?D1iaomzjF zz{7U!xt-8=J~2GZ9@E|FsQX&8A<($X5udga5dM@ZnprPI!pMiLO=ji{s7ov+HzvoB zE=4cn(TrJ(a($#v&$Lu{JT+-3H5a`BhBpM__d9{Dx?vou=-Wl_yF4O0Eu|;7J!?YX z%vgf*u{3Ersk0_0a1M`r+F0rbx_@YMRe2)M-iuPs81QO!db7~O-^VVjUZPA#LPx#5n9ci;1|x=GQPOrVUO9~sz%$t;6! 
z%|5FVpG+zzL#ka6f|DxY+@NShJ?@c@ee|P;u*kX?dXW_eJ09lRxIG{$&|hiWAsRa< zMu0i&IE3R}xSaK$z|~NT<#OV=^7uzlj*^t6H04?XDM&3U%rLq8T?E0`J0x;YKav64 z`qK8cO2PwK0*ezE>BPk`xshdFLJ}?RXfe8-l9|nP<};ybjZ{W$kbZCvC0{8=52CM; z1>2J$!$(9`L}HLPT`Z@|6sN2V&M|(4WLz-q=O;=$Q4JM}<~{M5PkruF3#uz%q*i&r z%srBVUK`^Mf!R%OF3y56n$rY-*+JPc3~{*3?K_mb0CO7hdy%^(bO3SlX#W7Tg#aF*3J0lT#RBmp? zBW67tT;UFvxQPNSq%>+xgc59`3GJ!|-RQXEsLfTFjOjKhD^o$=((ZJ z6O_O$?s?IhUiDh&xXATnBsr<6M9TJiwhN=qCfnMv?v$7|(^>h<$HU8gv#MqN>qgtU zTj#C+m%$Bo@PpgZUM+p1OM3;Oro3uc$!>A5{BkX8XUo$9yL6YsL)>45$Eky56lS2a zEOyKbVHwYu#x)L!KmR$hi~5wJH3QyjFR5J-qf@r9-EKCuwa}EJGlFK^>WAs*rEzt1 zgrT(Ym9d;_JKK3|UL~l^W`x>U?QY-ov>U#P^`oZcV&8f>8!O^Drw98EJY`#&E@iQ4B68zSNXIA1h zxB1OaN%6SMr8-vKH-PEK*E#F3$Xo^bO>O$>aPqt4rahVaGW^hC(#CU?TKHvWj&rSV zo$GqUagIlp9jVf8>2j9gbR)Xjh}Kf&EN1SHM;moSXB)6YXS%ZR`{PezMdqS*c-H|R zc)?ds=Q}qSUz6xzaSzy*!(3p?UHNGk0Zt=TTiwIJ796Y~hm75A2d8V5JLbR5I_m$CpM2$i zO!f7eqY<@TUBt7w-ant#b7e#z1U4GuV?H~HJ?|_dJ`VGG0R7cN?RT0lpZ)E3f3KQf zFD;T>q0$EG(b3N2=O6sTKO6q{xe2VeEqHv*zaA)w-~(=o58Yzy{T%QCA+S*5FE*-4 zy2ek_qA%ZEuFkG4of=Q$wr}9hDY@V&*2wS4-VCF}?fev>sMIe4VK4?|@Ixk$0v`mD z6fc9kZOOFH&YK`$G%GxYS_)swV$SBI{ZUeoqlZs7$FtHYG@fJHm6Iai5CM|uoZN;oc z48O|Ka&QN|#)DWX2oXlH6fH-v3)axgk(#S{p0MxEVAOCi8?|v8l|dIbQ3%0}(;R32 zG7q}^5Khc&4F_)O#_Ey=FyktZ<*w@O7>ETwM-H#?@3?Uv`LQ2$!5fQ3>SAm6(9rJq zYXoa8wWd$yxD6DKNR0r*3l#`{%+MTF>R}d;w;IqNIkF=?k_v1gBt>#0NwOqO@+48x z7F0)2R5B!2awS>PC1o-uTQVkL@+N1}CTp@LVUj0zS5kFe@+Wh%kFqIqk|vLGDR+`8andP`vMP15DyQ-)d-5bdGAzY%EafF7&GIbKGA-3|E!nay z-SRErGA`wEF6pu^?Gh!)J2Efz(j)J(Fa7c_0W&ZKb1(_BFb8ul5i>Crb1@mSF&*1b22HjGA;8mF*7qYb2B-!Gd=S&C_ytcMRPPsvouZfG*L4(RdY32vo&4wHDNP0 zWpg%Zvo>w>HgPjIbtZE+d9yct^EZJrIE8aKiL*G3^Ei<+IhAucnX@^a^Esh2I;C?u zB&oAHt@ApuGds0&JGrwvz4JT4Gd#s}Jjt^>&GS6bGdBy?;WIwvb3W;_ zKJD{9@iRa5b3ggBKmGGR0W?4bA#^|qv_K8?KoK-S6?8!vv_T#8K_N6kC3He5v_dWP zLNOsULp5|mIkZDP^g}^3L`8H&Nwh>w^h8lKMOAb~S+qr69`r?FG)84~MrpK0ZS+QQ zG)HxGM|reIee_3xG)RRWbV!M`NR9MJku*t_bV-@CNuBgbp)^XRbV{kTO08`4O0hIc zwRB6lv`f7dWV%u)nNm!pl1#&NOwqJV)AUT$v`yLcP0ch;*Hlj3luqGvPVux(^Yl*j zv`_i;PwiACzjRQ$&M*!2P!Tmz6?IV=6)*|)QE{$OC3R9MwNfqhQsoj-HC5uUF&jHJ zbyGq0*go|ZM>SMUHPA}67E|?9SvAX6^%7sTRb_SA^wCpmbyjh;##E;tk1AJv70Y6^ z5`#ANSBZ7FhSd&_wOEyPvyzn#n{`>CRj!_u3a2$%t(ByzRSL89TDf(fwp9kdwOhp% zns${R!Js z@bv>Ub_hlQ0q|7-w%}urL1wL_WJmU9{l{j@!T|7f1d_l3P__U-Kn}8iW{bgSsYGZ0 zadv54XkEWU0PuAHmcRpCwg3`fWAD{tv-V1qmTA5AbjsB{KHy#-pa}?eYlq-$6#xnF zbqF}tVDD7~+%|3jzyS=PYBj(BRyJ+v_6XWm2;LTM<#q_>^=&nP0S+KTy7p@sS9BH^ zEUMOLnScQB^#X|CY%SMs(Kc`mHgf@hV-;Wo;Pzf6_j2#mbBo|_KX-IbcL*GyYT?!b zh!#D1wQ+4%j~>@66aZiMHVHV^08F<4U{-TGRtQEQ0TSQ=?iG2FS7U$oUMm)PmG@qU zmv4oaZPE64lXrQUS7WQz0!BapE?@wrmjNy$cW*a+mj-=NVgShY3C`96f){mIi9mJd zws?&oaH}^CI97h=7klydevjY+ptlGJV0!nrJlI!#8Q5bM*d*fj2}HL5%GP=lcnItl zaM891I2HkVmjJ331A_O0g&>1hH-N48YKy>qt(Lu<|F(f)n0g+#BqG=epmu{X7l8A3 zgNcA+rS=HMwuP;ieleJbPxuH}*nEq?0W1J~U3mRsIEtrdhDoA^n}C3En0j@XgN;~& zulEQhSc!$;Y)KdZOc-f}zya{p0*Dwxs91{OI9T3TB&?VTCRhXhmw0y=h3~hG0pN!H zcZ^$CkGYtSg@A?el>r8zch6&v?=_GCfQ+}ehr9TM%NUC>S(6Ew zlMPu2E?{~K_5n&4JtP^DUHPtFxHWFL02uiQem9Se;EX$&WBnMAaXEF37>$?M0D?AU z&o`D|IhmQpkwcmFu^fOF0QB`H9K-pY6t+KVp`D_-FAIoTV9| z^O=u@`GozL0q(V+-C2mKnW4MbgiE-S!$Y6}x}#}OoLPfs-x&!`_FfU7odF<;fmoNP znUo&@rR$b=i9n@|Ac!rxrG+49VcLufAf2%Uq&+&Q*@dG$qNMLtZ;fDrd%AaL8Fd8! 
zmH9Y`CHj5~00P?ifbZ3wiJ+;sT4Ie@s*OOTwR)xtn0A3$sL>i&Y8O3>nrti900I{Q z*4CES*Z>$nVi{n1lNw(Qpn2~XWf35HKcH$oc#reduJKy0i$JMH`+8r60E91j0K``S zY0L32vEBKJol+t z7Xk2DfzNv$w0ra(&iNFCqV0s_m1BO?a2_ONA*>mSv02ttV2Vkeu zIJwD|x#QUgp1YV6cm(o#jj!8{ce}UI`%r${LO8aJIcB}nJHAZ=zRzQxm4?3MJHJ5# zzssY&*Y~{jJHVY}tvklPlZL+qJi*BV!N$W*B|O7T0>hDw z!!`WFLBhjNO~gUG#5ID%OAW z$^|{W3VqO%{Ls%!(GlIp8J)Nuz0q&H&b2YpB|XM1J-GNX{L(pn#+`fyGriM6Jk&dj z)I~kRO`Tak{nT0gA6GpFQ@zy%oYHL-ie>%Qi6PeQYSXT9z1Mv~*Xs({eZ9AZeXfdq z*fZN^lM*VKec7GeCYj*bn|&*%9VnZC+NC|)sr}lqz1ypO*}47OyQ-0@p9+_>S<%QnlQ-0`;-sNB2D-K@f$C>7R{^xD| z%@_XZF}~-i-r|e?=#4(>wO;6ted(D#nVtUXv0lo{lI+XA+>1cx(ca%HFW z%Tn&=UhP8z?8Dxb$G+`7KIf<2@7*4mRY&j#KkB7k!&zYaa6_9_TmU??r#~C1CbP|K(5r zAN3Jl^*NvQUH|n<9{|{1@NFOTi{J6TKKFGW-PJwzWj~e+|Mxq;_k&+`XusrN)%vf0 zM5aL zmTRo+EN!jrEpD#vE^n{zFL1E%FmbUOY#XKLG6}MjGIOW16ar_obW-T*_ULpER4RA4 zkhgJJwN`n$`Z|06yZbwQy!)7Oe*PW-3}|g2q=HPqiQ}el3B!i0 zvgJcXu_DHd8aHz6=}x~s&t0#)V7-v|S12Ci6%&DLOB z7wHY}-0tWB($J??CL(gAF?PAcPT0SkZz{bmz}^EoM=6*?%Tm0Ef!rkQ3IDN0se zdEAq7E{dZzjy4)3rxZ@ADXXoz`YNn_t|`-@%OuKZs-fC?Dv7+>nd*?S3Og*Z#TvWf ztT%dglCq!XS}QiG0@blG>%I&%!Gfy}GR`^cymPD*XYn)8K?^#xf`JMFcH5xea$lAb&7z5D(< z7j*+K3+}@ke?0PY7oWTrlQVMvJoM4OzP$7pSARYB*;h`z_Um^4J^114Jn=c@8*+Xo z=$oHD`-RWWz4-CVKmX&qt9^LWhN-`Q`~0u(yvg84zX1}kfbZuds%Q=T+1EY@!axB~gL)=BAqtsfzn8(xgEaG?2;GJx3#PD! zHoPHHIOwwRys%WLs^AK7D8wNWF)lhB*f*^8D<57fP&6ze6r(7`_5?9MJB-4#mPoK4 zI(l6RYC;C&bY>dc?FGZ ztYjB4>B&{HvXy0wq!XHDmkjk#f28DzCTWvPSMsu#QAD7jGN!Rn)^d+rR3t2Ysmx`1 zl8|_`WigX@%bV;{HhyGqCN{IFO=r%Mj9iK4YG~%hB~??K<~-;3xH(K0@)41G{8Ks6 zDbIO6Po0}#=P{)TPFW?Bp8ou22Dv88ds;GH#B3)rRSD3AGW3MN^nw{JK}%CT@RXMX zqCy+0(T&p2o>Adc77s}ciyAbeF1jX1Y*VUI0(#Uq6`iI;CK{x1Dk+{Ty(v!L*3#h| zG@s!_C_MX_)1eYIbvva|K@rx|LKb31gf8G8r7|GwO$YHLPY zOMYeSS%rGq+0wSDpd9E&LpWD)FpH#0b?l1nT4mee5_fB|r4?>nn$%LZR*f`GQeBHH z-Ra(Dvi@2wE}L6Y!BV%o-lde!Vu)R8YWKI?{VsabdnSR3H#F@vCT3#`T=mkozVb=! z7Gjto5x(@jwzV&S1FWLx-p{{KINJyNcw7NKGq}N#Au#du``j5W*ufUQFk2uzJPl*G z!yX>pe5J_Y50kjW602|uM{MF1v-lSHjiZWNEaMs1F~m_6E{!*HtYeYHn0GziF_43N zTVDQH$VNUg8H!vxB_p}XPWBCoQOM*cQ@P51nJ^uvY~?LGSn~ z(fD&?pCgTDM;{K$cp$W)GmT+HC)m9rr1aoG9qLegn$)G91*SEvY6ia;%E{%a2dUb! zLbC9go^IgG{z`~cue#UuwQj&hDQhdzde+5;?`Q$5>p_6o*OktG_7#B5GGZTFnxDO6 zH#-K~Z!WvW(Ec{KgXsoYu2Z+e#^P!d@}qRCTifxCrEi(b=46~X-1bJYv_%=~YfI8x z7(2IUotx`zd;8u7KX{gI3vjM1oZ(2cHoxJ$TZhCft`VL$!ZY6FR{uNVS!#HdJpOT! 
zzpL9>GV36Ct?`xPHQ%h=_rtMG&{}7E;GqrpMI~31Z`qitNbFFXv;!yv3p{2eOm=nFJenK7>WHUHUF0!a0GtjW1r}8S!i2C1U+)>9_1&N2A)nf<9pEWn1(sl5 zUEqHFUiGOUWd)xGPFzpaUg)9P1e%}?auf<`MbI&d{>@$I`CIW(90axz_toGH9^pcT z9|O`I^{t>So!U=YTjoweT|ydM%~;SkN=0wUn%iQl{MU|q}+BKYxM z6CxH1w%`~l;1@0-+Yuc9VWRsfA|rO9{5hh~C7^OeVz9X&5UO8x4dLSr8Yg<9D>jl6 zdf|aZAqLtR6h5LPPT@pUp&|AG7QP}drjs{02`#3eC?*~$VnYuO;vx2*F79Fxl3*}S zqc8p80eWC1V&m&k;w-`-G(O{3v|=}PQX@GMl{F^gv~?jFW}`NS;Qw8sCZ?e@mLolK z5~tzXt>I%n;*Y5PjsJllR~T6G)gwXTpd<2^sWs5_of-ozT0%D3Lq1wG79>S76nimb zHwt6}YNW(*q~GYHTzTXORwPOK6Bo*tL-Lq7mLy9W&@XO|={B{bOg_>}VopuYBu;wH zO+%UR0JTBfC2c1u%2QCq$xT^h?=zD-qU)}>x%5Bbnp zpta8?{H0);(qKwjTuR7Y?xkX~%0+fsNMhthie#B2=3-W+rQ9P#Hl$-plVsA(R9Yrz zX3aUGVk&|oVVqZJrY80r<`2qY8qMHU#p4@&rfSya$&6+h!Y0n?W}oC=rEUJ^^njum zk|ue**3zLSa3&}4;AU)!Aj%*jInkzaPG`O_=Wga*64;R9Nsx-X%goHDj0`g+9;5+6^P1Wcv@V6?kI!?sgm-H zeXiq1=4X4--j5y-jV`H`*3Uc+-(6TJ5%TCAt|*pXshE1@kltq|j;M|%9deE-n=*`g z3Zj5UsqvN9gZ@{W-l@JUC_JL6oC=|jY2qoMerT4qs7&#xH-2N12CAaM<#C}OGRhT{ z9%?vF>3qoNqF$;{{2`aFVs0uGn6`+eVydXBb$O|jzG!Mx>4UNAsJ`k3X=oW} z>GRPkl4hrjdM2g9Dz8R}o}MXuK4&i~Nq=GXs>nK~(h8Y*;}=&l|sv{q`P5~`d@ z>S;pLup;ZTMk}{who-7(bI7W&I%A@CtGUK!vX9*d7tDdX8s>PcUVw_rQ zhQ_FolGeQTD}|LR9g--sW-F2uNUOgd?0p?1WCrBIek4d{RlFjs#12JcK5S$*tc>W_ znocaoJ_u%3T10~EK8kF)cC5)ZNT@Pt#GWk6ZbZFe<;u3K%m#}q88+<9-YjLLY-hzR z&i3pe@vQIYtj`Xu90+aB7%kBzEgK;1Z33;*KJAHpDi|>>)K;zCyz2$oGOX2Ztx-@d znPM&1hHX7~Ey<9r*q-fb5Si1WE!&P{BZ6()#%(#gtPPp1+@h`BMvT@j-YwtWL*7En z-}bH52Cl;puHZhc;kxPC9xme!+1oZQ8>t|!YxawF6+jwX-MrCp|0%qtL=J=<}ThY?+)bS_Ac-mL+&!F z>jtm!@&oL?Px2lw>kcosI4|?^>g1-_@tHYUUMl$7rT2z! z;}VWvrmtbHZ(+8t`noUszOVeoFa5%={m$?G)-V35@A#T8pz0R@2e1GSFaZ~^0UvMy z`>z5^9s)P813xeXN3aA>@X#!<1z#6124}DaZ!iaUum^uI2#2r;k1z?Bun8KUFbb!z a3a>B=x3CMpFbv1A49_qP*RYrc0RTJA?&&cA literal 0 HcmV?d00001 diff --git a/docs/examples/te_gemma/media/substitution.png b/docs/examples/te_gemma/media/substitution.png new file mode 100644 index 0000000000000000000000000000000000000000..2df4cf791343d87db84d94733cc7454ac8e846cb GIT binary patch literal 78210 zcmeFZcRbbc`#-EwG$bW^6oq7F9H>qMpxdzNBrxg_9podLlUF5h>nrfx~V;;G*Z(B^Hphq)`?b;n!jzXXQ{g{)q zx&PnKX`b?&``>T<@7?(C9r*w4*?4++(>Ijji^Z3_w5Sjk+w$@d1TyJhpEO_dZ>m^_ zfH~0xXBvO?R$Yl_P#oXDQQr({-gz@fKvvHB z6G_;*kiNj+QDdvE@2uIP4uuPpwCaYnbq&K29D!o+aTL!$m-FDhtEI2G@S$nRfm7gl zU7uStK_{s_OWNr?*y}XlAcatBtb}#cJ6zv0o%-UE7M6(=JAA*1{nn?{kyDW`ha72B z3Eng;*e~o%UgN$aK7IGqv(Addqb82$Py>bkU8wYBkrO4YO2vR^R~j7KB&=Q!lXn5NLU(P`1*_PVf?Ra;@ymYf3wk+iRGw+u%O7bn1kb16}m7f z{dewcXH)N7XSQZ~lXBc2Ptid~b~NNM{t{WkC7}~vv;O;!)qEl+a0a?g_IN2vu_;$a z z*AK~B|L29Xiwf2?bM1^mvQH?jQ|40L(&t#+WX;CHrFw4O2V1}CR1-0YS$UK}=Zyq^ zgAK)ukW*{U%P+(9bA@D!6~*FcR@q#RB^LGSDi+F9*L!=74wkBjY1m*=;6#}>g6Wbu zvtdh*y4C(_0^LE(2JZoxH!QSaEtT}SSlX*MWxQ<2IVUwzP-*fwc8Z5mxh_IOvz7;j z#`nxd+;q5F4mJIc#ZeRD&+hC0XId0(l)Ug=JcJuoE|g{O_#FGed{8fUDe4PrjP8725&vgV*F8t(hhcruTKsn zE^xG*M-O15tjKJ>QRE<{mlLe2tdsYjs>v)e3f-pZj=03qBoazW^2GzV{1J|1iFeHu zv97w?Mxw|F1z^t`GZzQ+=SLzc#xQ=^tCsgszT^J$S9`t4SXS1Uko8Y=QptyD8l~xx zhoyQi(aSR~95cyxSSH|;MMiipcGLTxK9U*x>d)MluDQ>!4_NzMtakB)72k%y(h%wnVYN+Mud}0>!QHQk zFixrW#UG}5>IT}M7Xp89JpG(?Av_0_l$#f?Y9-(umU;D7CW4Y+<0Rb?ljZkJ?d`e8 zi=G3W^!c+(6Sz;fnmm;S3WU8|4ceQph$wOjrU$M59{rJc5NB+el$Kf;p}b@RHOVmm zL>#`mrbIpPohQ&gdct(pLezoz&ETO#$Eq9cV&2LwSC;b@#M=};`E-zS+mseTzq!u2 z!xLYS{Ko5D{3NLjmnrV|?PMfWmRJ}O6`PjC_&t+`MmrW9#)@p^5eHp~kH177z|e66 z7Wn#@Oe$sAF}=(XA)g{d({Nw#%P=-VcJ^YyiC%)Q96)2>M@>A34cp%r z1MZPA6?LM@QdT_se)ToaT)Q>bcgw+@5`KY8dI3*h_8Yg zlbJ4zYIFA<^KZQ8c#~1&aLtIhZSBfpgncL}r997z#irXn^_wD*a+b-mUB8vK(bmM89v% z$XL-(=mC|^?o!}Vsd@k=jd$-i9rm~&>R#=Ps5?8YmT_Pc=eWdu?n5+lx}Xib0AtFe 
z#s?w%MsE3@HS3`PRLR`62+5nNPv3nMYIO3xNy%6S6R(&zeAyv?0aeZ|WO>Q3AxCVN z1KKlpp3mi>r| z6Q+CAiGyH|{cU27Vg%5IOlo4g zm^0S&wL<^ZKJ!a0<(QM(Kdc9+pr#0n2Y(n9PYxCpwOx$hdcR>t_OjpL9Z=RbJILF7 zv{v5gl<)Qp`i;N(`JT+45|;>#%)sVDx%^ ziIbF0Sf*t3Y9?#l7eCN(QEBtdOzK?sxxD-ChuJUK%?E%hi~Y})2@+K&`Oio!v!_(k zBWH1C7~He4(WE(tZr)j}ghXPm?$m8w7k+IMqnNKDqtbU}4Tv1CcePHH$v79!fokoz zhR;)s?eqU2ADwr7jL!-0j^ELHCxDK8=}gVhUQeOt*Y!lS=XElAd(;_G$9reCoflD0 zn7H0l8&cCAWqT#MgVA#-O6G`2y&Cxc10M0{bg?{rU07Ck)W^{6gizDkEv&dMd`xk7 z35|@^Q>l{Q`aQ8v1)VWye$iyjyQhIhUkat1hkGnr?EKpLdK|q?f84o+S8sd1wmPTN zhl#-NN$=Q&7u$A8JeQ6=Hf77p6!bcHc75Z1GOdv@Ptzp=(N5Lgn-^os<=DF!nT|}4 z`8{^2%y5_BEl{{7R9ZZ_!BpJ6zLQu!=cQBTQQPf?NQ;%llCc|Sh`v28EQY&p zpaUi9-imZsrcDdCFEjH|8f$@YwL?X+ZPtxnJG9SFCU8k|;=jK7qs#iNGUNWU40FRn zD{(gEHf&#(?bwVdm=Tfh>Qz)PW<-D5qzqGWuMcM-OED9o*I~}ptlBN*nod|=#;-b| zmXe7d+ph4k4e1>X*5FY4y8IB*YsN=q&VI$nsrj2JluVQg!U!qvrm@z| zsbL*ecZbUimu(mWhb{C>b=cCR> zX6%=6@6N@i89f$RP}!?!8|19lwZK2}L{|#1^oukVuCWxP`ZT0^~Py)UUT|r1O z6k5j3*Ll8WV1|^kCcmo;?oWdU-Z>}hL6vC-J_)YAy0@CpAaYuD?++M$G0K~4v}%WF ztWBW-@A&?$-Z|K_=K|iX$0tSXeqWHgH$X8}`!vy?>LI!_X?3t+#mJK)i>Mx$0uYr$ ziLeN(KHXO!U9kH_fjA)7^2gh@A>S(JmrJ|t><@dld5p*YJ{qsoE^JmYQLwc3X-8*0ZmZrLK}$-ZRcoUF5 zo=!{1CSGzh@&i|scuWd8`}*TtMg2x$nW5WaVTIV=c^NQRV;S)B+5~92$l{=xt5{RV zR6|t2LM}FRuO*b(YNsb5(miA?1#6#R5nu1()1gSJHOM0Yv?DH7`TIHHgTC&S4geg! zmF^Q*8}3>y>xQQ^P?y~q9oY82v(`&EwS8$5p#oCqPEI8S34$RupOwHuOb^}vDB1~% zgRnbs5JJwkq6!C>D*~p_iF6N|;{1CgZ}7g^Z45o2gys~BMI*)Q{km8?mDX@KIGeW~ zmf3N^%j~1x?C%UV=uR-uESsRWALZ7*9Vk$}5Ep6NdqX|vtVA*2(1Y_AUP&~b1oW<~ zzs|;~SUXl>1~T54Q3TClVw^R8}tO;u=fIKv7obIt_{FHDY?csoeXI3%G;8>x*uJ&K#n79K1Df)4=QnAv zVSj_ro4&vPo9R$33D6z87z9!jBjg6^j)gIxE9}_d!e{gucfj)G$v@7~6>}HwHu3ID z1kWV91@xw{D6Ul6c@zhZxk$th8k#6ssy>gqIj$akfX*=W$-tNrkn1&;50XxX$0=A+Zk<9_SG{`>TwrXwW`!=Eo~#%=9y4V%7} z=q*{PN*;h&UOlGgbwr)4z{JrO7pN?kRTy$kqOD4yAoFLEp>(w1cRxg% zzj$%4dvfctT=y}H3H6ly5Pd-*g~C9!-o*<=2csY7E)9Beg&Ge&>aUt31ONbU?f=8| zIn1F-d!Nsrx@eWI*>8L7!HsU0ooPW;Q>8`IPGPsek&)Bi-$4T!Mn9!?p+s1SgU@ff zo>Eurx7}uCIl7ic;aZv+vWwBNO2tK}?oEF4KEoRyb7sMgXDyLn2$;`>FY>6@QqO&+ zew_LOa4;t=GvGhOZzlF!?tN8DLq(N@$Q$&liFVQKv|RijDDM+$AN~m{qPK2lGYw~^ z;9ZIdGR6KnWUP3IhwDA*AClOm!tNXhSTgc3xw$pLCw2x1DwL|ukD~nMTU#xv^?u7P zR*HUg^pP*WVxnj~%5H><+o^hpVoxPX-dd>q==91|rD3PPm<7^Y;7Y#_F|w_{lNXeX zbuOYetY5Mh=cMK5zrFiXuGU84+!FNj#aOn9MN3Nc4SJsj>ZZ$j=oQ54DdqjGDLUdh zM6k%uS!=BykNsH6@OH3%!LQ}534CBGOAkYRW(`h3*7OX9A++-TW@Jy*k$EF_I@U!8`;)gPUG z#}6@ByHjk?Kf!wC%eav>dWBBF^Ki2Fj=g>rRbXf366s3!f!E>Z zFv>1cONwb^xnD1%X9f^UCW<%sZMy| zirtkOy_v9fBw-9lN7l7MhZ9Ml3!A}1s9x#2=H#A)`X^^jq1F$$G2;P?H-t*GMbbFX z-bKu9{;E^#oowc}pbn{Z>n+Uo0D$BJi7nOd^d<4zeK{(IN}HuIqMPG^q;G5sgY}!` z>v21LCj09-s=;S)s6K?7S4yjdjfgLg)^ss*#vqk45&Br=6NBxqJ&oIy(j6MKetIg~ z)9Xu4-2;+v1g#H$rQ)w^+e{MCYiDhyoFgx|(dP0m?bQ;M#dn(yd*7ztjsqj(!VgZO zGtYkK$*0&0A}8KvxXFJ)4r?yX36Cl;^wP7pYp#g+A;Jmxc+O1cPb*bNd}Hg4V!I4a zBMkL)8(n#dUz@facWw4}*`;1^{OHR;6waCW@!NzmSnhjvcbefs$UVE{$yTQbF-L5H zmx#rme(O46?Zd)|8M~L;-sU>S`v1fn4X@7u)^JD zztNCLF3C+~SceWAo`Y1o)}fVUD*<Qqce&uCM=*(q^>;Br(?4I+6xFV_-YB^iUA6pLcc9ox zb%-KeyNqz>#QmtYX?l5(a=gODV@yp*)*(7odg$c!1ArFv?V1HV_bcI_lc)>H152;b z*V?NqHq2m$O2Dz3c$MUnUp8!Hc&JS+sbAhvLWyE&=GY}F_{Y-mPu)D-RnKVE9i5bn zAE?cO-gsEB>g_;`H&+$4#_hs2imKzRvQTm_FDq;g_>5h?rRE;drefqYxVjWF9@&NH z$A#ab;|=;_Ezo4oa;TObcIUdd10O0rku2fBfmk^0TX{?P4Xc)a1igzcKXcHiM~j8s z%T|*Cm78(ypz#BWhcuNvXgDApXu&te=I+P|yLlktp9{$^5!)!DlH-}$WwWR&%~4sm zBx7AgIC-@&zLz)0W=g4`W$=W{*yV16d#XM^Jijy4qGD7##Z9^kr;$hoI)e;44S~)T zIFyUJXPq!BgL^BL?_lM)+{$pQebG}MHb^S_9+Odc&x3XHAtFvt<@#GMJ*WJ_o%Lrg zQ1=KDY$h8aZ@o^xT&Q(3CMX%zwP4sKcv@jaqZr)st_K&Q!m;C{f9k|V&kl7b#oCb< 
z3N}`tCTH9DXw?=?J_K}Vrd}|horoV!a!fN_Dj&2Adpjk<1?V+)(duBQ%ur86&FZp4 zQxrRY%_)~6mSR?|GWC1wdi&FQ$7YJuICZZ{dFaCLm`H!G6GT0}8IT<)?*B{cm=-SG zd%zS6jcvmBlj7P|^SV(nMxt#KKd$7Uu2l~36~J6dMU-hi0QvSyms(Fa}H&99Kj2y6SMb(9x)>nVC^m(pU91LPNyt$+yDh| z^Q(xszyWlcp=;GlytB*pUau;(*Sdgs)oNQaz(Lm52zOGf@&aO}1{~@X&8gr`bb4L0 zc6#sGw$G~T8I0VxOEA%^j*|sfb3|3Qf9ak6Ue~)>*`izoab@$pJ)){!=M(q$7_J$s z{TZuBbtM85R*lu-Lt2t3aMSW2;)DJoCEL$nHuI9N=34RO0O}r`^I5jA%V1IJ);mS5 zxqEIBUzXXYJ#Xc~@`@&$Q8jVd?7~I6X81fa-9Hle-1cg5+s6SPP?sBwdA0gWXKLMN zY9sXQjJi?Rt&D@~-?+$8>PHQKm$y$~J!7)kyD0bE?#dW<$bbVt6Z?~)EZMNkJk?YU zP)Ns|p4Hb9qd?C7fLW*khf_XLW=%52Uz-fNXPw`yR630QuFBjEM&T8oikd7^YXzuq zgVu_EBY0MiAIzp5nqOh2SJOAkau}kguITx0>EhCEq`e5e=5Lc%O zfTo#vB-D&v27egf=n#*7L*uN%?tubRDIn{bF&&Qa+WHfrT3fkM=$=~tDGfsWKqEDO zhY?uo@U2I0p;dR$WpUy=&F9kogcH1r11-q$6}4GgM4%TFmfb^ltyzA{)sx@+744TQ zCPJ_q#i7*obN+7L`S9~{Rljt4SZJaIL{4byz%$DI6V63Gl*OM2>tl$US*MCUe+08h zE^4yS(Ux)@c6j`Vg{8dEb5jOsdY<)$?nrsCg?LXvCOS={Vj)<7(sZQvlz*`Su}+9W zNyWrm7W|A?3zhx^UZ2K2v+yJ}FVUOaj`NTNCLA~7m>w)}qD1s7ePr|MxVPHJYd2j= z%K}9-_5oOOKa+>N+yw;LK^91fCZ|zq9ZMeiiMU${MsX0AAtk)WD2?7<*EYezhdtaN zbgPM%1)}3fx!OsFu&GxU#>p+bs`4-C?>3*^gmC3f?k_A`Ps+sJETKnBTvXO=)TDNI zjKmXl{Eo%=CZ>YH!Ib)UkS7^K{m$$cF0Q2? z1PRNqN=5Sc{eI?rRgd;M{}yL3|IV*NbX@nKNzaWl;Iu*}8S$Uu1e#n-vui7W^*MAC zv35pNtu!L>GAYIJttXw^caA&&rU&Vo|`jq%ADuqi@%gl|FSnn_vVstsD?i26qu~pkU z{nrp_HF?#8xxR5?mRj(*wkm<@N>+o^1WfU_A_MN0FKhQUc+C69bvlAOc0*FvQ0NBP z-CZGAaK!>OV`1-w*sg5~B?o2ayld7$4gT7(`{XCA0_hn;Q&g{tke3HEZmoLi!SM?* zr4CQAAJ*g^CC*IvNOD?rJXmKs6G=C?2{CKzX~m_tm{zqpNjdX=L8d=J%~EOaFsnbJ z5+vbvv5gPc=Mz2FzqW~Kbw;s%>rKp+Hl`Jckte(=iz6uvR(1R)f|Ekba$NnvX4F@j ztGqd6_jq}~o0ob&AlodQXRr{SfXTyFJ-Vi~m=VFFK#v_L$FjhgA*Sp?$=%wYL+W6+ zCbR%ewyhZR(dsl%T>Z&U7h?_v?PGGJaV;OBpOyD>%&@Q5Cs|4ZP;L?az&evs*0&Pz z4sE@X^BELwv6WQ~>4=2oKfN|LG2(5DkVUv`-c6YdC{(!Jc}FDAPg>N};0X zODmjv6sGrr27^l%P-WZoXYkeWQyTsITTtXYL*+vBrHv1%cNXiVZ2$@tfu)ZzotiqWrr{rQ$MCLFJr22ifTE*=zW!?0#O`3w$^I3~drJ6oM6r zJBX_vo8_*c@9RwV2TQhrfccE zw`=b89$LKbx#Ep5Z}Uh;2DJQk>DYC4>fz1bF^-*}0}wVb8t9p~L_R59t@2*Ni;R3^ zsIoiKCSSibaX_8GrXOslJJP8^9Gw-KP=co;Uz!{aNy#nPb!*?o8O8+P(-fz+c5p9xuwP5=Br*?VMlyE8O|+#M5A=5bFN z*q3^Dvx1+HqbDT;{(*&WyQS7>g8TN7eu$W-#H5a}Xuh6a`i*2ylvADwr)vO27X zrN3L~Oq(Yl`jv?RsYGiEAI7jlrn|ZE*CWG=b9lIW{othddq z_v-0oc?PajIST7oxpCLXw#ho(!X+lyk?JVASuX7FAXV0_b=9? 
z)?_p)B`G{f;7&4`wM!NqM4!|s>f4}RtfDbptCljV>KWZa z3In%M0uEWcq4ZBrak_IPQiz@ zww=tESQd-$8@9(Izp9TAvo(gYv^;`zR~^GO9}Zn23)?0kpoM6#ZM_@npX!b z*I)%w(>C9a2mb+GEzhJOBp9E!2`)dQa(|;Z9fOZWU!ow~14j&6W2Mf`L%g2Q?jGWW zf7GG_xRU{gxP!aWO663~z%)VLvk%dffx?X1#VAJ_*DV-2=X1jl>=z02MwqNN436P4 z1f|_X_Y|JYCOHE89@w$dXIYLa82cW}bN+Qa7>ww|gyGJ39aLN5q2i@?UT!SjKBe?c zvd1HJN0eS1A2~fyj_Ae~*P=LnpDT2a95$>xvAL9ji(SnpLlV&x@3fi9mhdY3D|2-csoFlxe%}&b-cG+I(-Htmq`-|Z-PV271 z4mn#Vg(KqaEW&?Wwkg~&bZ9VG6k+Jb^W77f_Xd-HE*oBEs+L1myxp-eHY+1{vMTse zj;ijh@Y_?j&cq&5%MbBIGD0E5#!UD$&=gZmpX@GI=w+=G!YhjWR=m>TN+kcztEUy_ zJ`(;mSZGX1=u#l(aPw1s>8;$s0Ww!^tns5y^r~-rUEW z1ti;N{BzhbOMJNe!dZR85(LxAMsKBDp=(awAM;vEcyN`b)y-sdNT})hqF)xhdvnBv z?Eb%V%SxTc|DkZ4RQ{Z;_zm8AAogWHsQO#p)#oQm+r%4`5ISzeRA1e!pwd%2fITdh z8Lj}ITPtgGM&tE#fxOWj^8mc=!xBvi4SO>N z!h=3=%O~Axf9(fMGdvsvT#3Dw8*`XekO9+Inw5;vfG_{rZV=S!I&v?fM5wC@#r_4# zEwE<%7`^cFag*oIdAp`YEjT(}ZOEyzQ`P}~(s0l{##=iWa?CUy6*w15QL6Un%kaDnTXC- z#r(IA+~DK$y1trOkC<>BWZ+6E#o@jGr&aV_6glxkOwd3gCj3Zzn}}mE@^_nr$+>5V zejkpy5?q>ET4{7>D_Q*6Nhs@W>2|M@nixbsb_0N$&Nctju#3~9iz+#9yd2J_X8Rv^ zHY(R*8%&tBG3qoOmC95m1Ya3XH>$6Q+sK$4j36Pd9?7&Y&xidDjc3~!G@{#E_{47= zKt@kop3YT~9iS;o#^L{p{SVUUo%TU0>n?nzI&oftD|o#d3v`CoGQ%P_6KO%Gt-uCK zAXaw7H5P-HBpqrIKtG))Qw+``(BozGF^nOOQM@ z6czAe@MTQVVd-0jHY8`UgHeaOh0r+d_d}m(D~F{J3bfHyApH@enMEelHV?(RzO zc7IuXG!qC{x>Hwywc7|BDGOEQvgcWVI)hUc+DEJvml?&kOCdv+?0B`ZJR5dv8sJLv8hV=WShC?711`@NKypE(m)AF1zKR(+qrLayeMTLT?bCo0Di{bV^r_esVo=QyvX<44+MI^)Igsw0&}H8 z1jyJUp9_PCC@)!B?OHWcH8pi&JGD&4vdp66fVAbaXfe)=rZ23&u|ILIBk&^^=mSZjH}922 zt62)vyBJ&g+Y84bO{H^Y05CEV)^TeStqiYeM;b^;=1N<# z7LW96=fNfiyEY~N8m%6{{Q*qP;rxU4)QpYrbGmo4JOO?SP1amiTr5BXs0N;c3aAr* zP$yZjYb~y0j~lPGf5exYzIcRByB`za`CS|F;y5ccpx?-U=gjZIYH=N@pR%O30^iF- z4yR9uYE)`%Td?@G1}Nt%f>ASht|Hf^Sv6`sy#Yi;gM|Sss}%~o>@~DhD5`1DN&!Ye z76hBfQI1NVhI~d4bm)?1?N?wA)O`hRTKRVvDfO z_spu6GV10?&I3hHE8hripSxK;)0?R%OjM%_Dn&RZaC9nC#%17h;U;3VQWB7ifAzl# z)YjDGmd5L>Qm!-M8u7sntqPOnsJ8|w)u%Y#IVRh*iLD&!{al!*v&+--qNu|w_V@0! zPT}Np3U2(Uv3`=4I=8oh@SBPPtwLfgMSLwH2wd*$7#~g~*@bCAUaLw9xep1xS=(