Commit b742244

Add Copyright to utils and fix some more greptiles complaints

Signed-off-by: tdophung <[email protected]>
1 parent 733d61b

2 files changed, +47 -54 lines


docs/examples/quickstart_jax.ipynb (40 additions, 51 deletions)
@@ -51,7 +51,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "id": "881fd001",
    "metadata": {},
    "outputs": [],
@@ -65,18 +65,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 8,
    "id": "d5284a38",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import jax\n",
     "import jax.numpy as jnp\n",
@@ -87,7 +79,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 9,
    "id": "a4d1cfdc",
    "metadata": {},
    "outputs": [],
@@ -181,7 +173,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 10,
    "id": "8b44649d",
    "metadata": {},
    "outputs": [],
@@ -202,15 +194,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 11,
    "id": "e44ed26d",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "BasicTransformerLayer initialized successfully!\n",
+      "Pure Flax BasicTransformerLayer initialized successfully!\n",
       "Parameter shapes: {'params': {'BasicMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n"
      ]
     }
@@ -232,7 +224,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 12,
    "id": "de91af7a",
    "metadata": {},
    "outputs": [
@@ -258,15 +250,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 13,
    "id": "037bc8d9",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Mean time: 27.949795722961426 ms\n"
+      "Mean time: 27.269372940063477 ms\n"
      ]
     }
    ],
@@ -316,7 +308,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 14,
    "id": "bed20d6b",
    "metadata": {},
    "outputs": [],
@@ -336,7 +328,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 15,
    "id": "56105579",
    "metadata": {},
    "outputs": [],
@@ -357,7 +349,7 @@
    "    hidden_size: int\n",
    "    ffn_hidden_size: int \n",
    "    num_attention_heads: int \n",
-   "    layernorm_eps: int = 1e-5\n",
+   "    layernorm_eps: float = 1e-5\n",
    "    attention_dropout: float = 0.1 \n",
    "    hidden_dropout: float = 0.1\n",
    "\n",
@@ -433,21 +425,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 16,
    "id": "5146cd99",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Basic parameter shapes: {'BasicMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}\n",
-      "TE parameter shapes: {'BasicTEMLP_0': {'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(16384,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 16384), names=(), mesh=None, rules=None)}, 'DenseGeneral_1': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(16384, 4096), names=(), mesh=None, rules=None)}}, 'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(12288,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 12288), names=(), mesh=None, rules=None)}, 'DenseGeneral_1': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=(), mesh=None, rules=None)}, 'LayerNorm_0': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}, 'LayerNorm_1': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}}\n",
-      "Input shape: (2048, 4, 4096)\n",
-      "Output shape: (2048, 4, 4096)\n",
-      "Output dtype: bfloat16\n",
-      "Forward pass completed successfully!\n",
-      "Mean time: 17.249975204467773 ms\n"
+      "Basic TE parameter shapes: {'BasicTEMLP_0': {'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(16384,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 16384), names=(), mesh=None, rules=None)}, 'DenseGeneral_1': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(16384, 4096), names=(), mesh=None, rules=None)}}, 'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(12288,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 12288), names=(), mesh=None, rules=None)}, 'DenseGeneral_1': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=(), mesh=None, rules=None)}, 'LayerNorm_0': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}, 'LayerNorm_1': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}}\n",
+      "Mean time: 17.397570610046387 ms\n"
      ]
     }
    ],
@@ -516,7 +503,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 17,
    "id": "11203785",
    "metadata": {},
    "outputs": [],
@@ -525,7 +512,7 @@
    "    hidden_size: int\n",
    "    ffn_hidden_size: int \n",
    "    num_attention_heads: int \n",
-   "    layernorm_eps: int = 1e-5\n",
+   "    layernorm_eps: float = 1e-5\n",
    "    attention_dropout: float = 0.1 \n",
    "    hidden_dropout: float = 0.1\n",
    "\n",
@@ -586,15 +573,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 18,
    "id": "114de14f",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Basic parameter shapes: {'BasicMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}\n",
       "Fused TE parameter shapes: {'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=(), mesh=None, rules=None)}, 'LayerNormDenseGeneral_0': {'bias': LogicallyPartitioned(value=(12288,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 12288), names=(), mesh=None, rules=None), 'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}, 'LayerNormMLP_0': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'wi_bias': LogicallyPartitioned(value=(1, 16384), names=('act', 'mlp'), mesh=None, rules=None), 'wi_kernel': LogicallyPartitioned(value=(4096, 1, 16384), names=('embed', 'act', 'mlp'), mesh=None, rules=None), 'wo_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'wo_kernel': LogicallyPartitioned(value=(16384, 4096), names=('mlp', 'embed'), mesh=None, rules=None)}}\n"
      ]
     }
@@ -617,19 +603,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 19,
    "id": "6b0c705e",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Input shape: (2048, 4, 4096)\n",
-      "Output shape: (2048, 4, 4096)\n",
-      "Output dtype: bfloat16\n",
-      "Forward pass completed successfully!\n",
-      "Mean time: 17.9510498046875 ms\n"
+      "Mean time: 18.0991792678833 ms\n"
      ]
     }
    ],
@@ -660,15 +642,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 20,
    "id": "7496b159",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Basic parameter shapes: {'BasicMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}\n",
       "TE TransformerLayer parameter shapes: {'attention': {'out': {'bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=('nvte_w_tp', 'nvte_w_fsdp'), mesh=None, rules=None)}, 'qkv': {'bias': LogicallyPartitioned(value=(3, 4096), names=('nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 3, 4096), names=('nvte_w_fsdp', 'nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'ln_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None)}}, 'mlp': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'wi_bias': LogicallyPartitioned(value=(1, 16384), names=('nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'wi_kernel': LogicallyPartitioned(value=(4096, 1, 16384), names=('nvte_w_fsdp', 'nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'wo_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'wo_kernel': LogicallyPartitioned(value=(16384, 4096), names=('nvte_w_tp', 'nvte_w_fsdp'), mesh=None, rules=None)}, 'relpos_bias': {'rel_embedding': LogicallyPartitioned(value=(32, 32), names=('heads', 'relpos_buckets'), mesh=None, rules=None)}}\n"
      ]
     }
@@ -692,19 +673,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 21,
    "id": "6ec0f60e",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Input shape: (2048, 4, 4096)\n",
-      "Output shape: (2048, 4, 4096)\n",
-      "Output dtype: bfloat16\n",
-      "Forward pass completed successfully!\n",
-      "Mean time: 11.953592300415039 ms\n"
+      "Mean time: 11.84274673461914 ms\n"
      ]
     }
    ],
@@ -743,7 +720,7 @@
    "\n",
    "</div>\n",
    "\n",
-   "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Note that fp8_autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) (currently only available in PyTorch) for a detailed explanation of FP8 recipes and the supported options.\n",
+   "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](.../api/jax.rst#transformer_engine.jax.fp8_autocast) context manager. Note that fp8_autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) (currently only available in PyTorch) for a detailed explanation of FP8 recipes and the supported options.\n",
    "\n",
    "<div class=\"alert alert-warning\">\n",
    "\n",
@@ -756,7 +733,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 22,
    "id": "b2aaa8ef",
    "metadata": {},
    "outputs": [
@@ -803,15 +780,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 23,
    "id": "b9cdbf22",
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Mean time: 7.995896339416504 ms\n"
+      "Mean time: 7.96757698059082 ms\n"
      ]
     }
    ],
@@ -834,6 +811,18 @@
    "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.3"
   }
  },
 "nbformat": 4,

docs/examples/quickstart_jax_utils.py (7 additions, 3 deletions)
@@ -1,3 +1,7 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
 import jax
 import jax.numpy as jnp
 import time
@@ -35,15 +39,15 @@ def speedometer(
     for _ in range(warmup_iters):
         key, step_key = jax.random.split(key)
         loss, (param_grads, other_grads) = train_step_fn(
-            variables, input, output_grad, step_key, key
+            variables, input, output_grad, step_key
         )

     # Timing runs
     start = time.time()
     for _ in range(timing_iters):
         key, step_key = jax.random.split(key)
         loss, (param_grads, other_grads) = train_step_fn(
-            variables, input, output_grad, step_key, key
+            variables, input, output_grad, step_key
        )
     end = time.time()

@@ -98,7 +102,7 @@ def create_train_step_fn_vjp(
     def train_step_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
         """Compute forward pass and VJP in one step"""

-        # Define forward function that closes over grad_target and dropout_key
+        # Define forward function that closes over dropout_key
         def forward_fn(variables: Any, inp: jnp.ndarray):
             """Pure forward function for VJP computation"""
             rngs = {"dropout": dropout_key}
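The comment fix reflects what forward_fn actually captures: only dropout_key is closed over, while grad_target enters later as the cotangent passed to the VJP pullback. A minimal sketch of that pattern, assuming the names from this file (the loss computation shown is a placeholder, and `model` is assumed in scope):

    import jax
    import jax.numpy as jnp
    from typing import Any

    def train_step_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
        # forward_fn closes over dropout_key only
        def forward_fn(variables: Any, inp: jnp.ndarray):
            rngs = {"dropout": dropout_key}
            return model.apply(variables, inp, rngs=rngs)

        out, vjp_fn = jax.vjp(forward_fn, variables, inp)
        param_grads, input_grads = vjp_fn(grad_target)  # grad_target as the cotangent
        loss = jnp.mean(out)  # placeholder scalar; the real loss lives elsewhere
        return loss, (param_grads, input_grads)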
