diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh
index 31206898b6..eb09df1a84 100644
--- a/qa/L1_jax_distributed_unittest/test.sh
+++ b/qa/L1_jax_distributed_unittest/test.sh
@@ -5,4 +5,11 @@
 set -xe
 
 : ${TE_PATH:=/opt/transformerengine}
-pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
+
+# Skip ring attention tests since they need fixed environment vars
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* -k 'not test_context_parallel_ring_attn'
+
+# Test ring attention with and without scan loop
+NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
+NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 XLA_FLAGS="--xla_experimental_ignore_channel_id" \
+    pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index 4ef56b1b10..7ef0d68474 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -35,7 +35,9 @@
     get_qkv_format,
     reorder_causal_load_balancing,
     inverse_reorder_causal_load_balancing,
+    CPStrategy,
 )
+from transformer_engine.jax.sharding import MeshResource
 
 # We will use the golden reference model from our non distributed attention test fixture.
 from test_fused_attn import general_dot_product_attention, make_mask
@@ -333,6 +335,36 @@ def ref_func(query, kv, mask):
             )
 
 
+@pytest.mark.parametrize(
+    "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
+)
+@pytest.mark.parametrize(
+    "data_shape",
+    [
+        pytest.param([2, 512, 12, 128], id="2-512-12-128"),
+        pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
+    ],
+)
+@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
+@pytest.mark.parametrize(
+    "attn_mask_type",
+    [
+        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
+        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
+    ],
+)
+@pytest.mark.parametrize("dtype", [jnp.bfloat16])
+@pytest.mark.parametrize(
+    "qkv_layout",
+    [
+        pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
+        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
+    ],
+)
+@pytest.mark.parametrize(
+    "load_balanced",
+    [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
+)
 class TestDistributedContextParallelSelfAttn:
 
     def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype):
@@ -370,37 +402,7 @@ def qkv_to_layout(self, q, k, v, qkv_layout):
                 raise ValueError(f"Unsupported {qkv_layout=}")
         return qkv_args
 
-    @pytest.mark.parametrize(
-        "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
-    )
-    @pytest.mark.parametrize(
-        "data_shape",
-        [
-            pytest.param([2, 512, 12, 128], id="2-512-12-128"),
-            pytest.param([4, 1024, 16, 64], id="4-1024-16-64"),
-        ],
-    )
-    @pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16])
-    @pytest.mark.parametrize(
-        "attn_mask_type",
-        [
-            pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
-            pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
-        ],
-    )
-    @pytest.mark.parametrize("dtype", [jnp.bfloat16])
-    @pytest.mark.parametrize(
-        "qkv_layout",
-        [
-            pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"),
-            pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
-        ],
-    )
-    @pytest.mark.parametrize(
-        "load_balanced",
-        [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")],
-    )
-    def test_contex_parallel_self_attn(
+    def impl_test_contex_parallel_attn(
         self,
         device_count,
         mesh_shape,
@@ -412,6 +414,7 @@ def test_contex_parallel_self_attn(
         dtype,
         qkv_layout,
         load_balanced,
+        cp_strategy,
     ):
         attn_bias_type = AttnBiasType.NO_BIAS
         dropout_prob = 0.0
@@ -469,6 +472,7 @@ def target_func(q, k, v, mask):
                 scaling_factor=scaling_factor,
                 dropout_probability=dropout_prob,
                 is_training=is_training,
+                context_parallel_strategy=cp_strategy,
                 context_parallel_causal_load_balanced=load_balanced,
                 context_parallel_axis="cp",
             ).astype(dtype)
@@ -574,6 +578,60 @@ def grad_func(func, *args, **kwargs):
 
                 assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)
 
+    def test_contex_parallel_allgather_attn(
+        self,
+        device_count,
+        mesh_shape,
+        mesh_axes,
+        mesh_resource,
+        data_shape,
+        kv_groups,
+        attn_mask_type,
+        dtype,
+        qkv_layout,
+        load_balanced,
+    ):
+        return self.impl_test_contex_parallel_attn(
+            device_count,
+            mesh_shape,
+            mesh_axes,
+            mesh_resource,
+            data_shape,
+            kv_groups,
+            attn_mask_type,
+            dtype,
+            qkv_layout,
+            load_balanced,
+            CPStrategy.ALL_GATHER,
+        )
+
+    def test_context_parallel_ring_attn(
+        self,
+        device_count,
+        mesh_shape,
+        mesh_axes,
+        mesh_resource,
+        data_shape,
+        kv_groups,
+        attn_mask_type,
+        dtype,
+        qkv_layout,
+        load_balanced,
+    ):
+        return self.impl_test_contex_parallel_attn(
+            device_count,
+            mesh_shape,
+            mesh_axes,
+            mesh_resource,
+            data_shape,
+            kv_groups,
+            attn_mask_type,
+            dtype,
+            qkv_layout,
+            load_balanced,
+            CPStrategy.RING,
+        )
+
 
 class TestReorderCausalLoadBalancing:
     @pytest.mark.parametrize("cp_size", [2, 4, 8])
diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index a48a38e6a9..af05538ef5 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -7,6 +7,7 @@
 from functools import partial
 from math import sqrt
 from typing import Tuple, Optional
+import random
 
 import jax
 import jax.numpy as jnp
diff --git a/tests/jax/test_misc.py b/tests/jax/test_misc.py
new file mode 100644
index 0000000000..67145daf63
--- /dev/null
+++ b/tests/jax/test_misc.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import pytest
+from functools import partial
+import os
+
+from transformer_engine.jax.cpp_extensions.misc import get_xla_flag
+
+
+@pytest.fixture(autouse=True, scope="function")
+def preserve_xla_flags():
+    """Ensures the XLA flags environment variable is restored after any tests in this file run."""
+    old_flags = os.getenv("XLA_FLAGS")
+    yield
+    if old_flags is not None:
+        os.environ["XLA_FLAGS"] = old_flags
+
+
+def test_get_xla_flag(request):
+    os.environ["XLA_FLAGS"] = ""
+    assert get_xla_flag("") is None
+    assert get_xla_flag("--foo") is None
+    assert get_xla_flag("--bar=1") is None
+
+    os.environ["XLA_FLAGS"] = "--foo --bar=1 --baz=biz"
+    assert get_xla_flag("--foo") == True
+    assert get_xla_flag("--bar") == "1"
+    assert get_xla_flag("--bar", cast=int) == 1
+    assert get_xla_flag("--bar", cast=bool) == True
+    assert get_xla_flag("--baz") == "biz"
+    with pytest.raises(ValueError):
+        # cast will fail
+        assert get_xla_flag("--baz", cast=int)
+    assert get_xla_flag("--xla") is None
+
+    os.environ["XLA_FLAGS"] = "--xla_abc --xla_abb"
+    assert get_xla_flag("--xla_abc") == True
+    assert get_xla_flag("--xla_abb") == True
diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py
index 218fd174de..3ecc9bcd75 100644
--- a/transformer_engine/jax/attention.py
+++ b/transformer_engine/jax/attention.py
@@ -79,6 +79,19 @@ class QKVFormat(Enum):
     THD = NVTE_QKV_Format.NVTE_THD
 
 
+class CPStrategy(Enum):
+    """Defines the context parallel strategies of Jax fused attention.
+
+    DEFAULT: Default strategy will choose automatically if context parallel axis is sharded.
+    ALL_GATHER: All-gather/reduce scatter implementation.
+    RING: Ring attention implementation (https://arxiv.org/abs/2310.01889).
+    """
+
+    DEFAULT = 0
+    ALL_GATHER = 1
+    RING = 2
+
+
 def get_qkv_format(qkv_layout):
     """
     Get qkv_format from qkv_layout
@@ -260,6 +273,7 @@ def fused_attn(
     dropout_probability: float,
     is_training: bool,
     window_size: Optional[Tuple[int, int]] = None,
+    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
 ):
@@ -347,6 +361,7 @@ def fused_attn(
         is_training=is_training,
         max_segments_per_seq=1,
         window_size=window_size,
+        context_parallel_strategy=context_parallel_strategy,
         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
         context_parallel_axis=context_parallel_axis,
     )
@@ -370,6 +385,7 @@ def fused_attn_thd(
     is_training: bool,
     max_segments_per_seq: int = 1,
     window_size: Optional[Tuple[int, int]] = None,
+    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
 ):
@@ -470,6 +486,7 @@ def fused_attn_thd(
         is_training=is_training,
         max_segments_per_seq=max_segments_per_seq,
         window_size=window_size,
+        context_parallel_strategy=context_parallel_strategy,
         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
         context_parallel_axis=context_parallel_axis,
     )
@@ -477,7 +494,7 @@ def fused_attn_thd(
     return output
 
 
-@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16))
+@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
 def _fused_attn(
     qkv: Tuple[jnp.ndarray, ...],
     bias: Optional[jnp.ndarray],
@@ -494,6 +511,7 @@ def _fused_attn(
     is_training: bool,
     max_segments_per_seq: int,
     window_size: Optional[Tuple[int, int]],
+    context_parallel_strategy: CPStrategy,
     context_parallel_causal_load_balanced: bool,
     context_parallel_axis: str,
 ):
@@ -513,6 +531,7 @@ def _fused_attn(
         is_training,
         max_segments_per_seq,
         window_size,
+        context_parallel_strategy,
         context_parallel_causal_load_balanced,
         context_parallel_axis,
     )
@@ -535,6 +554,7 @@ def _fused_attn_fwd_rule(
     is_training,
     max_segments_per_seq,
     window_size,
+    context_parallel_strategy,
     context_parallel_causal_load_balanced,
     context_parallel_axis,
 ):
@@ -554,6 +574,7 @@ def _fused_attn_fwd_rule(
         is_training=is_training,
         max_segments_per_seq=max_segments_per_seq,
         window_size=window_size,
+        context_parallel_strategy=context_parallel_strategy,
         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
         context_parallel_axis=context_parallel_axis,
     )
@@ -582,6 +603,7 @@ def _fused_attn_bwd_rule(
     is_training,
     max_segments_per_seq,
     window_size,
+    context_parallel_strategy,
     context_parallel_causal_load_balanced,
     context_parallel_axis,
     ctx,
@@ -617,6 +639,7 @@ def _fused_attn_bwd_rule(
         is_training=is_training,
         max_segments_per_seq=max_segments_per_seq,
         window_size=window_size,
+        context_parallel_strategy=context_parallel_strategy,
         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
         context_parallel_axis=context_parallel_axis,
     )
diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index 2832a20ab7..6c4f518189 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -17,6 +17,8 @@
 from jax.sharding import PartitionSpec, NamedSharding
 from jax.extend import ffi
 
+from transformer_engine.jax.attention import CPStrategy
+
 from transformer_engine import transformer_engine_jax
 from transformer_engine.transformer_engine_jax import (
     NVTE_Bias_Type,
@@ -35,6 +37,7 @@
     get_padded_spec,
     get_cudnn_version,
     is_ffi_enabled,
+    get_xla_flag,
 )
 from ..sharding import (
     global_mesh_resource,
@@ -1032,7 +1035,7 @@ def check_supported(self):
         if self.config.qkv_layout not in allowed_layouts:
             raise ValueError(
                 f"{header} only supports layouts:"
-                f" {','.join([str(x) for x in allowed_layouts])} got: {self.config.qkv_layout}"
+                f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
             )
 
         if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
@@ -1042,7 +1045,7 @@ def check_supported(self):
         if self.config.attn_mask_type not in allowed_masks:
             raise ValueError(
                 f"{header} only supports masking types: "
-                f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}"
+                f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
             )
 
         if self.config.max_segments_per_seq != 1:
@@ -1411,6 +1414,503 @@ def _cross_attn_bwd(
 register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)
 
 
+@dataclass(frozen=True)
+class _FusedAttnCPWithP2PHelper:
+    """Helper class to assist with running the P2P ring strategy for CP attention."""
+
+    mesh: jax.sharding.Mesh
+    config: _FusedAttnConfig
+
+    @staticmethod
+    def use_scanloop():
+        """Returns true if the implementation will use a scan loop for iteration."""
+        use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")))
+
+        # nvbug(4675071): Disable the HLO verifier for channel ID checks.
+        # A WAR was added to XLA: https://github.com/openxla/xla/pull/16779
+        def truthy(val):
+            return val.lower() in ["1", "true"]
+
+        x = use_scan and get_xla_flag(
+            "--xla_experimental_ignore_channel_id", default=False, cast=truthy
+        )
+        return x
+
+    def check_supported(self):
+        """Checks if the context parallel implementation is supported by the given arguments."""
+        header = "Context parallel fused ring attention"
+
+        allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD]
+        if self.config.qkv_layout not in allowed_layouts:
+            raise ValueError(
+                f"{header} only supports layouts:"
+                f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
+            )
+
+        if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
+            raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
+
+        allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK]
+        if self.config.attn_mask_type not in allowed_masks:
+            raise ValueError(
+                f"{header} only supports masking types: "
+                f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
+            )
+
+        if self.config.max_segments_per_seq != 1:
+            raise ValueError(
+                f"{header} only supports max_segments_per_seq == 1 got:"
+                f" {self.config.max_segments_per_seq}"
+            )
+
+        if self.config.dropout_probability != 0.0:
+            raise ValueError(f"{header} does not support dropout")
+
+        # We want to encourage use of scan loop to minimize unrolling and ensure more
+        # predictable scheduling from XLA. The unrolled flavor will be supported but
+        # not the prefered implementation.
+        if not self.use_scanloop():
+            warnings.warn(
+                "Scan loop is disabled for fused ring attention. To enable set"
+                " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment and"
+                " add --xla_experimental_ignore_channel_id=true to XLA_FLAGS."
+            )
+
+    def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
+        """Returns a _FusedAttnConfig for single CP step call to fused attention."""
+        return _FusedAttnConfig(
+            attn_bias_type=self.config.attn_bias_type,
+            attn_mask_type=attn_mask_type,
+            qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
+            scaling_factor=self.config.scaling_factor,
+            dropout_probability=self.config.dropout_probability,
+            is_training=self.config.is_training,
+            max_segments_per_seq=self.config.max_segments_per_seq,
+            window_size=self.config.window_size,
+            context_parallel_load_balanced=self.config.context_parallel_load_balanced,
+            cp_axis=self.config.cp_axis,
+        )
+
+    def stack_kv(self, k, v):
+        """Stacks k and v tensors if not stacked."""
+        _not_used = jnp.zeros(0, dtype=k.dtype)
+        match self.config.qkv_layout:
+            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
+                return k
+            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
+                return jnp.stack([k, v], axis=2)
+        return _not_used
+
+    def unstack_kv(self, kv):
+        """Un-stacks k and v tensors if not stacked."""
+        _not_used = jnp.zeros(0, dtype=kv.dtype)
+        match self.config.qkv_layout:
+            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
+                return kv, _not_used
+            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
+                return jnp.unstack(kv, axis=2)
+        return _not_used, _not_used  # fall through
+
+    def permute_kv(self, kv, cp_perm):
+        """Permutes kv around the ring as described by cp_perm."""
+        return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm)
+
+    def correct_softmax_aux(self, softmax_aux, softmax_aux_per_step):
+        """Apply soft max correction after an attention step."""
+        max_scale = jnp.maximum(softmax_aux, softmax_aux_per_step)
+        min_scale = jnp.minimum(softmax_aux, softmax_aux_per_step)
+        new_softmax_aux = max_scale + jnp.log(1 + jnp.exp(min_scale - max_scale))
+        return new_softmax_aux
+
+    def adjust_seqlen(self, seqlen, max_seqlen, idx):
+        """Adjust the sequence length per step."""
+        seqlen_of_curr_step = seqlen - max_seqlen * idx
+        seqlen_of_curr_step = jnp.where(seqlen_of_curr_step < 0, 0, seqlen_of_curr_step)
+        seqlen_per_step = jnp.where(
+            seqlen_of_curr_step < max_seqlen, seqlen_of_curr_step, max_seqlen
+        )
+        return seqlen_per_step
+
+
+class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
+    """
+    Fused Ring Attention Forward Primitive
+    """
+
+    @staticmethod
+    def partition(config, mesh, arg_infos, result_infos):
+        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
+        assert (
+            not is_context_parallel or config.window_size[0] == -1
+        ), "Sliding window attention is not supported when context parallelism is enabled"
+        if not is_context_parallel:
+            return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
+
+        helper = _FusedAttnCPWithP2PHelper(mesh, config)
+        helper.check_supported()
+
+        out_sharding = result_infos[0].sharding
+        softmax_aux_sharding = result_infos[1].sharding
+        rng_state_sharding = seed_sharding = NamedSharding(
+            mesh, PartitionSpec(get_all_mesh_axes(), None)
+        )
+        arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
+        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
+
+        def ring_attn_fwd_impl(
+            q,
+            k,
+            v,
+            bias,
+            q_seqlen,
+            kv_seqlen,
+            q_seq_offsets,
+            k_seq_offsets,
+            seed,
+        ):
+            _not_used = jnp.zeros(0, dtype=v.dtype)
+
+            # Combine KV tensors if separate for better permute scheduling and performance.
+            # Eventually XLA should perform this automatically.
+            kv = helper.stack_kv(k, v)
+
+            batch, q_max_seqlen, head, _ = q.shape
+            kv_max_seqlen = k.shape[1]
+
+            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
+            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
+            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
+
+            output_per_steps = jnp.zeros((cp_size, *q.shape), dtype=q.dtype)
+            softmax_aux_per_steps = jnp.zeros(
+                (cp_size, batch, head, q_max_seqlen, 1), dtype=jnp.float32
+            )
+            softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32)
+
+            # RNG shape should be the shared shape. This is unused for ring attention as we do not
+            # support dropout currently.
+            rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:])
+            rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)
+
+            def scan_kv_block(idx, carry):
+                kv, softmax_aux, output_per_steps, softmax_aux_per_steps = carry
+
+                # Send KV block to next step so we can overlap compute.
+                kv_next = helper.permute_kv(kv, cp_perm)
+
+                def mask_compute(attn_mask_type):
+                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
+                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
+                    output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
+                        q,
+                        kv,
+                        _not_used,
+                        bias,
+                        q_seqlen_per_step,
+                        kv_seqlen_per_step,
+                        q_seq_offsets,
+                        k_seq_offsets,
+                        seed,
+                        helper.get_step_config(attn_mask_type),
+                    )
+                    return output_per_step, softmax_aux_per_step
+
+                causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK)
+                no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK)
+
+                def half_kv_no_mask_compute():
+                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
+                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
+                    kv_part = lax.slice_in_dim(kv, 0, kv.shape[1] // 2, axis=1)
+                    output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
+                        q,
+                        kv_part,
+                        _not_used,
+                        bias,
+                        q_seqlen_per_step,
+                        kv_seqlen_per_step,
+                        q_seq_offsets,
+                        k_seq_offsets,
+                        seed,
+                        config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
+                    )
+                    return output_per_step, softmax_aux_per_step
+
+                def half_q_no_mask_compute():
+                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2
+                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
+                    q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1)
+                    output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
+                        q_part,
+                        kv,
+                        _not_used,
+                        bias,
+                        q_seqlen_per_step,
+                        kv_seqlen_per_step,
+                        q_seq_offsets,
+                        k_seq_offsets,
+                        seed,
+                        config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
+                    )
+                    output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1)
+                    softmax_aux_per_step = jnp.concat(
+                        [
+                            jnp.full_like(softmax_aux_per_step, -jnp.inf),
+                            softmax_aux_per_step,
+                        ],
+                        axis=2,
+                    )
+                    return output_per_step, softmax_aux_per_step
+
+                def skip_compute():
+                    output_per_step = jnp.zeros_like(q)
+                    softmax_aux_per_step = jnp.full(
+                        (batch, head, q.shape[1], 1), -jnp.inf, dtype=jnp.float32
+                    )
+                    return output_per_step, softmax_aux_per_step
+
+                if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK:
+                    # This is for nested jax.lax.cond
+                    def jax_cond_wrap():
+                        if config.context_parallel_load_balanced:
+                            return lax.cond(
+                                (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
+                            )
+                        return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)
+
+                    output_per_step, softmax_aux_per_step = lax.cond(
+                        idx == 0, causal_mask_compute, jax_cond_wrap
+                    )
+                else:
+                    output_per_step, softmax_aux_per_step = no_mask_compute()
+
+                softmax_aux = helper.correct_softmax_aux(softmax_aux, softmax_aux_per_step)
+                output_per_steps = output_per_steps.at[idx].set(output_per_step)
+                softmax_aux_per_steps = softmax_aux_per_steps.at[idx].set(softmax_aux_per_step)
+
+                return (kv_next, softmax_aux, output_per_steps, softmax_aux_per_steps)
+
+            carry = (kv, softmax_aux, output_per_steps, softmax_aux_per_steps)
+            if helper.use_scanloop():
+                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
+            else:
+                for i in range(0, cp_size):
+                    carry = scan_kv_block(i, carry)
+            (kv, softmax_aux, output_per_steps, softmax_aux_per_steps) = carry
+
+            output = jnp.zeros(q.shape).astype(jnp.float32)
+            for idx in range(cp_size):
+                output = output + output_per_steps[idx].astype(jnp.float32) * jnp.exp(
+                    softmax_aux_per_steps[idx] - softmax_aux
+                ).transpose(0, 2, 1, 3)
+            output = output.astype(q.dtype)
+            return output, softmax_aux, rng_state
+
+        return mesh, ring_attn_fwd_impl, out_shardings, arg_shardings
+
+
+register_primitive(FusedRingAttnFwdPrimitive)
+
+
+class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
+    """
+    Fused Ring Attention Backward Primitive
+    """
+
+    @staticmethod
+    def partition(config, mesh, arg_infos, result_infos):
+        is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
+        assert (
+            not is_context_parallel or config.window_size[0] == -1
+        ), "Sliding window attention is not supported when context parallelism is enabled"
+        if not is_context_parallel:
+            return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
+
+        del result_infos
+        q_spec = get_padded_spec(arg_infos[0])
+        k_spec = get_padded_spec(arg_infos[1])
+        v_spec = get_padded_spec(arg_infos[2])
+        bias_spec = get_padded_spec(arg_infos[3])
+        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
+        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
+        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
+        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
+        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
+        out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
+
+        helper = _FusedAttnCPWithP2PHelper(mesh, config)
+        helper.check_supported()
+
+        def ring_attn_bwd_impl(
+            q,
+            k,
+            v,
+            bias,
+            softmax_aux,
+            rng_state,
+            output,
+            doutput,
+            q_seqlen,
+            kv_seqlen,
+            q_seq_offsets,
+            k_seq_offsets,
+        ):
+            _not_used = jnp.zeros(0, dtype=output.dtype)
+
+            # Combine KV tensors if separate for better permute scheduling and performance.
+            # Eventually XLA should perform this automatically.
+            kv = helper.stack_kv(k, v)
+
+            q_max_seqlen = q.shape[1]
+            kv_max_seqlen = k.shape[1]
+
+            cp_size = get_mesh_axis_size(config.cp_axis, mesh)
+            cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
+            cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
+
+            dq = jnp.zeros_like(q)
+            dk_dv = helper.stack_kv(jnp.zeros_like(k), jnp.zeros_like(v))
+            dbias = jnp.zeros_like(bias)
+
+            def scan_kv_block(idx, carry):
+
+                kv, dq, dk_dv, dbias = carry
+
+                # Start communication that feeds the next iteraton.
+                # We further combine the tensors to improve overlap.
+
+                kv_dk_dv = jnp.stack([kv, dk_dv])
+                kv_dk_dv = helper.permute_kv(kv_dk_dv, cp_perm)
+
+                def mask_compute(attn_mask_type):
+                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
+                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
+                    dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
+                        q,
+                        kv,
+                        _not_used,
+                        bias,
+                        softmax_aux,
+                        rng_state,
+                        output,
+                        doutput,
+                        q_seqlen_per_step,
+                        kv_seqlen_per_step,
+                        q_seq_offsets,
+                        k_seq_offsets,
+                        config=helper.get_step_config(attn_mask_type),
+                    )
+                    return dq_per_step, dk_dv_per_step, dbias_per_step
+
+                causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK)
+                no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK)
+
+                def half_kv_no_mask_compute():
+                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
+                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
+                    kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1)
+                    dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
+                        q,
+                        kv_part,
+                        _not_used,
+                        bias,
+                        softmax_aux,
+                        rng_state,
+                        output,
+                        doutput,
+                        q_seqlen_per_step,
+                        kv_seqlen_per_step,
+                        q_seq_offsets,
+                        k_seq_offsets,
+                        config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
+                    )
+                    dk_dv_per_step = jnp.concat(
+                        [dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1
+                    )
+                    return dq_per_step, dk_dv_per_step, dbias_per_step
+
+                def half_q_no_mask_compute():
+                    q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2
+                    kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
+
+                    q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1)
+                    doutput_part = lax.slice_in_dim(
+                        doutput, q_max_seqlen // 2, q_max_seqlen, axis=1
+                    )
+                    output_part = lax.slice_in_dim(output, q_max_seqlen // 2, q_max_seqlen, axis=1)
+
+                    softmax_aux_part = lax.slice_in_dim(
+                        softmax_aux, q_max_seqlen // 2, q_max_seqlen, axis=2
+                    )
+
+                    dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
+                        q_part,
+                        kv,
+                        _not_used,
+                        bias,
+                        softmax_aux_part,
+                        rng_state,
+                        output_part,
+                        doutput_part,
+                        q_seqlen_per_step,
+                        kv_seqlen_per_step,
+                        q_seq_offsets,
+                        k_seq_offsets,
+                        config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK),
+                    )
+                    dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1)
+                    return dq_per_step, dk_dv_per_step, dbias_per_step
+
+                def skip_compute():
+                    return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias)
+
+                if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK:
+                    # This is for nested jax.lax.cond
+                    def jax_cond_wrap():
+                        if config.context_parallel_load_balanced:
+                            return lax.cond(
+                                (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
+                            )
+                        return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)
+
+                    dq_per_step, dk_dv_per_step, dbias_per_step = lax.cond(
+                        idx == 0, causal_mask_compute, jax_cond_wrap
+                    )
+                else:
+                    dq_per_step, dk_dv_per_step, dbias_per_step = no_mask_compute()
+
+                kv_next, dk_dv = jnp.unstack(kv_dk_dv)
+                dq = dq + dq_per_step
+                dk_dv = dk_dv + dk_dv_per_step
+                if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
+                    dbias = dbias + dbias_per_step
+
+                return (kv_next, dq, dk_dv, dbias)
+
+            carry = (kv, dq, dk_dv, dbias)
+            if helper.use_scanloop():
+                carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
+            else:
+                for i in range(0, cp_size):
+                    carry = scan_kv_block(i, carry)
+            (kv, dq, dk_dv, dbias) = carry
+
+            # Final permute to put gradients back to their final resting place.
+            dk_dv = helper.permute_kv(dk_dv, cp_perm)
+
+            global_dbias = dbias
+            if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
+                global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
+
+            dk, dv = helper.unstack_kv(dk_dv)
+            return dq, dk, dv, global_dbias
+
+        return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings
+
+
+register_primitive(FusedRingAttnBwdPrimitive)
+
+
 def _maybe_context_parallel_axis(cp_axis: str):
     if not cp_axis:
         gmr = global_mesh_resource()
@@ -1437,6 +1937,7 @@ def fused_attn_fwd(
     is_training: bool,
     max_segments_per_seq: int,
     window_size: Optional[Tuple[int, int]] = None,
+    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
 ) -> jnp.ndarray:
@@ -1519,7 +2020,14 @@ def fused_attn_fwd(
         cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
     )
 
-    return FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive.bind(
+    primative = None
+    match context_parallel_strategy:
+        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
+            primative = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
+        case CPStrategy.RING:
+            primative = FusedRingAttnFwdPrimitive.outer_primitive
+
+    return primative.bind(
         *qkv_for_primitive,
         bias,
         q_seqlen,
@@ -1550,6 +2058,7 @@ def fused_attn_bwd(
     is_training: bool,
     max_segments_per_seq: int,
     window_size: Optional[Tuple[int, int]] = None,
+    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
 ):
@@ -1636,7 +2145,14 @@ def fused_attn_bwd(
         cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
     )
 
-    *qkv_grads, bias_grad = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive.bind(
+    primative = None
+    match context_parallel_strategy:
+        case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
+            primative = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
+        case CPStrategy.RING:
+            primative = FusedRingAttnBwdPrimitive.outer_primitive
+
+    *qkv_grads, bias_grad = primative.bind(
         *qkv_for_primitive,
         bias,
         softmax_aux,
diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py
index d3df614ac9..1f13484b98 100644
--- a/transformer_engine/jax/cpp_extensions/misc.py
+++ b/transformer_engine/jax/cpp_extensions/misc.py
@@ -167,3 +167,25 @@ def is_ffi_enabled():
     is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
     assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
     return is_supported and is_enabled
+
+
+def get_xla_flag(flag: str, default=None, cast=str):
+    """
+    Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value.
+    """
+    xla_flags = []
+    if xla_flags_env := os.getenv("XLA_FLAGS"):
+        xla_flags.extend(xla_flags_env.split())
+
+    for flag_i in sorted(xla_flags):
+        if "=" in flag_i:
+            # option like --xla_abc=foo
+            name, val = flag_i.split("=", 2)
+            if name == flag:
+                return val if cast is None else cast(val)
+        else:
+            # flag like --xla_enable_foo
+            name, val = flag_i, None
+            if name == flag:
+                return True
+    return default
diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py
index a14a8384cf..f2da288be5 100644
--- a/transformer_engine/jax/sharding.py
+++ b/transformer_engine/jax/sharding.py
@@ -197,7 +197,7 @@ class MeshResource:
         The axis name in Mesh used to split the batch and weights along.
         If it is None, then full-sharded data parallelism is disabled.
     pp_resource : str, default = None
-        The axis name in Mesh used to split model layers. along.
+        The axis name in Mesh used to split model layers along.
         If it is None, then pipeline parallelism is disabled.
     cp_resource : str, default = None
         The axis name in Mesh used to split sequence (context) dimensions along