|
51 | 51 | }, |
52 | 52 | { |
53 | 53 | "cell_type": "code", |
54 | | - "execution_count": 1, |
| 54 | + "execution_count": null, |
55 | 55 | "id": "881fd001", |
56 | 56 | "metadata": {}, |
57 | 57 | "outputs": [], |
|
65 | 65 | }, |
66 | 66 | { |
67 | 67 | "cell_type": "code", |
68 | | - "execution_count": 2, |
| 68 | + "execution_count": 8, |
69 | 69 | "id": "d5284a38", |
70 | 70 | "metadata": {}, |
71 | | - "outputs": [ |
72 | | - { |
73 | | - "name": "stderr", |
74 | | - "output_type": "stream", |
75 | | - "text": [ |
76 | | - "WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.\n" |
77 | | - ] |
78 | | - } |
79 | | - ], |
| 71 | + "outputs": [], |
80 | 72 | "source": [ |
81 | 73 | "import jax\n", |
82 | 74 | "import jax.numpy as jnp\n", |
|
87 | 79 | }, |
88 | 80 | { |
89 | 81 | "cell_type": "code", |
90 | | - "execution_count": 3, |
| 82 | + "execution_count": 9, |
91 | 83 | "id": "a4d1cfdc", |
92 | 84 | "metadata": {}, |
93 | 85 | "outputs": [], |
|
181 | 173 | }, |
182 | 174 | { |
183 | 175 | "cell_type": "code", |
184 | | - "execution_count": 4, |
| 176 | + "execution_count": 10, |
185 | 177 | "id": "8b44649d", |
186 | 178 | "metadata": {}, |
187 | 179 | "outputs": [], |
|
202 | 194 | }, |
203 | 195 | { |
204 | 196 | "cell_type": "code", |
205 | | - "execution_count": null, |
| 197 | + "execution_count": 11, |
206 | 198 | "id": "e44ed26d", |
207 | 199 | "metadata": {}, |
208 | 200 | "outputs": [ |
209 | 201 | { |
210 | 202 | "name": "stdout", |
211 | 203 | "output_type": "stream", |
212 | 204 | "text": [ |
213 | | - "BasicTransformerLayer initialized successfully!\n", |
| 205 | + "Pure Flax BasicTransformerLayer initialized successfully!\n", |
214 | 206 | "Parameter shapes: {'params': {'BasicMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n" |
215 | 207 | ] |
216 | 208 | } |
|
232 | 224 | }, |
233 | 225 | { |
234 | 226 | "cell_type": "code", |
235 | | - "execution_count": 6, |
| 227 | + "execution_count": 12, |
236 | 228 | "id": "de91af7a", |
237 | 229 | "metadata": {}, |
238 | 230 | "outputs": [ |
|
258 | 250 | }, |
259 | 251 | { |
260 | 252 | "cell_type": "code", |
261 | | - "execution_count": 7, |
| 253 | + "execution_count": 13, |
262 | 254 | "id": "037bc8d9", |
263 | 255 | "metadata": {}, |
264 | 256 | "outputs": [ |
265 | 257 | { |
266 | 258 | "name": "stdout", |
267 | 259 | "output_type": "stream", |
268 | 260 | "text": [ |
269 | | - "Mean time: 27.949795722961426 ms\n" |
| 261 | + "Mean time: 27.269372940063477 ms\n" |
270 | 262 | ] |
271 | 263 | } |
272 | 264 | ], |
|
316 | 308 | }, |
317 | 309 | { |
318 | 310 | "cell_type": "code", |
319 | | - "execution_count": 8, |
| 311 | + "execution_count": 14, |
320 | 312 | "id": "bed20d6b", |
321 | 313 | "metadata": {}, |
322 | 314 | "outputs": [], |
|
336 | 328 | }, |
337 | 329 | { |
338 | 330 | "cell_type": "code", |
339 | | - "execution_count": 9, |
| 331 | + "execution_count": 15, |
340 | 332 | "id": "56105579", |
341 | 333 | "metadata": {}, |
342 | 334 | "outputs": [], |
|
357 | 349 | " hidden_size: int\n", |
358 | 350 | " ffn_hidden_size: int \n", |
359 | 351 | " num_attention_heads: int \n", |
360 | | - " layernorm_eps: int = 1e-5\n", |
| 352 | + " layernorm_eps: float = 1e-5\n", |
361 | 353 | " attention_dropout: float = 0.1 \n", |
362 | 354 | " hidden_dropout: float = 0.1\n", |
363 | 355 | "\n", |
|
433 | 425 | }, |
434 | 426 | { |
435 | 427 | "cell_type": "code", |
436 | | - "execution_count": null, |
| 428 | + "execution_count": 16, |
437 | 429 | "id": "5146cd99", |
438 | 430 | "metadata": {}, |
439 | 431 | "outputs": [ |
440 | 432 | { |
441 | 433 | "name": "stdout", |
442 | 434 | "output_type": "stream", |
443 | 435 | "text": [ |
444 | | - "Basic parameter shapes: {'BasicMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}\n", |
445 | | - "TE parameter shapes: {'BasicTEMLP_0': {'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(16384,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 16384), names=(), mesh=None, rules=None)}, 'DenseGeneral_1': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(16384, 4096), names=(), mesh=None, rules=None)}}, 'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(12288,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 12288), names=(), mesh=None, rules=None)}, 'DenseGeneral_1': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=(), mesh=None, rules=None)}, 'LayerNorm_0': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}, 'LayerNorm_1': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}}\n", |
446 | | - "Input shape: (2048, 4, 4096)\n", |
447 | | - "Output shape: (2048, 4, 4096)\n", |
448 | | - "Output dtype: bfloat16\n", |
449 | | - "Forward pass completed successfully!\n", |
450 | | - "Mean time: 17.249975204467773 ms\n" |
| 436 | + "Basic TE parameter shapes: {'BasicTEMLP_0': {'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(16384,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 16384), names=(), mesh=None, rules=None)}, 'DenseGeneral_1': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(16384, 4096), names=(), mesh=None, rules=None)}}, 'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(12288,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 12288), names=(), mesh=None, rules=None)}, 'DenseGeneral_1': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=(), mesh=None, rules=None)}, 'LayerNorm_0': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}, 'LayerNorm_1': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}}\n", |
| 437 | + "Mean time: 17.397570610046387 ms\n" |
451 | 438 | ] |
452 | 439 | } |
453 | 440 | ], |
|
516 | 503 | }, |
517 | 504 | { |
518 | 505 | "cell_type": "code", |
519 | | - "execution_count": null, |
| 506 | + "execution_count": 17, |
520 | 507 | "id": "11203785", |
521 | 508 | "metadata": {}, |
522 | 509 | "outputs": [], |
|
525 | 512 | " hidden_size: int\n", |
526 | 513 | " ffn_hidden_size: int \n", |
527 | 514 | " num_attention_heads: int \n", |
528 | | - " layernorm_eps: int = 1e-5\n", |
| 515 | + " layernorm_eps: float = 1e-5\n", |
529 | 516 | " attention_dropout: float = 0.1 \n", |
530 | 517 | " hidden_dropout: float = 0.1\n", |
531 | 518 | "\n", |
|
586 | 573 | }, |
587 | 574 | { |
588 | 575 | "cell_type": "code", |
589 | | - "execution_count": null, |
| 576 | + "execution_count": 18, |
590 | 577 | "id": "114de14f", |
591 | 578 | "metadata": {}, |
592 | 579 | "outputs": [ |
593 | 580 | { |
594 | 581 | "name": "stdout", |
595 | 582 | "output_type": "stream", |
596 | 583 | "text": [ |
597 | | - "Basic parameter shapes: {'BasicMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}\n", |
598 | 584 | "Fused TE parameter shapes: {'DenseGeneral_0': {'bias': LogicallyPartitioned(value=(4096,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=(), mesh=None, rules=None)}, 'LayerNormDenseGeneral_0': {'bias': LogicallyPartitioned(value=(12288,), names=(), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 12288), names=(), mesh=None, rules=None), 'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None)}, 'LayerNormMLP_0': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'wi_bias': LogicallyPartitioned(value=(1, 16384), names=('act', 'mlp'), mesh=None, rules=None), 'wi_kernel': LogicallyPartitioned(value=(4096, 1, 16384), names=('embed', 'act', 'mlp'), mesh=None, rules=None), 'wo_bias': LogicallyPartitioned(value=(4096,), names=('embed',), mesh=None, rules=None), 'wo_kernel': LogicallyPartitioned(value=(16384, 4096), names=('mlp', 'embed'), mesh=None, rules=None)}}\n" |
599 | 585 | ] |
600 | 586 | } |
|
617 | 603 | }, |
618 | 604 | { |
619 | 605 | "cell_type": "code", |
620 | | - "execution_count": null, |
| 606 | + "execution_count": 19, |
621 | 607 | "id": "6b0c705e", |
622 | 608 | "metadata": {}, |
623 | 609 | "outputs": [ |
624 | 610 | { |
625 | 611 | "name": "stdout", |
626 | 612 | "output_type": "stream", |
627 | 613 | "text": [ |
628 | | - "Input shape: (2048, 4, 4096)\n", |
629 | | - "Output shape: (2048, 4, 4096)\n", |
630 | | - "Output dtype: bfloat16\n", |
631 | | - "Forward pass completed successfully!\n", |
632 | | - "Mean time: 17.9510498046875 ms\n" |
| 614 | + "Mean time: 18.0991792678833 ms\n" |
633 | 615 | ] |
634 | 616 | } |
635 | 617 | ], |
|
660 | 642 | }, |
661 | 643 | { |
662 | 644 | "cell_type": "code", |
663 | | - "execution_count": null, |
| 645 | + "execution_count": 20, |
664 | 646 | "id": "7496b159", |
665 | 647 | "metadata": {}, |
666 | 648 | "outputs": [ |
667 | 649 | { |
668 | 650 | "name": "stdout", |
669 | 651 | "output_type": "stream", |
670 | 652 | "text": [ |
671 | | - "Basic parameter shapes: {'BasicMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}\n", |
672 | 653 | "TE TransformerLayer parameter shapes: {'attention': {'out': {'bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 4096), names=('nvte_w_tp', 'nvte_w_fsdp'), mesh=None, rules=None)}, 'qkv': {'bias': LogicallyPartitioned(value=(3, 4096), names=('nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'kernel': LogicallyPartitioned(value=(4096, 3, 4096), names=('nvte_w_fsdp', 'nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'ln_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None)}}, 'mlp': {'ln_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'scale': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'wi_bias': LogicallyPartitioned(value=(1, 16384), names=('nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'wi_kernel': LogicallyPartitioned(value=(4096, 1, 16384), names=('nvte_w_fsdp', 'nvte_w_joined', 'nvte_w_tp'), mesh=None, rules=None), 'wo_bias': LogicallyPartitioned(value=(4096,), names=('nvte_w_no_shard',), mesh=None, rules=None), 'wo_kernel': LogicallyPartitioned(value=(16384, 4096), names=('nvte_w_tp', 'nvte_w_fsdp'), mesh=None, rules=None)}, 'relpos_bias': {'rel_embedding': LogicallyPartitioned(value=(32, 32), names=('heads', 'relpos_buckets'), mesh=None, rules=None)}}\n" |
673 | 654 | ] |
674 | 655 | } |
|
692 | 673 | }, |
693 | 674 | { |
694 | 675 | "cell_type": "code", |
695 | | - "execution_count": null, |
| 676 | + "execution_count": 21, |
696 | 677 | "id": "6ec0f60e", |
697 | 678 | "metadata": {}, |
698 | 679 | "outputs": [ |
699 | 680 | { |
700 | 681 | "name": "stdout", |
701 | 682 | "output_type": "stream", |
702 | 683 | "text": [ |
703 | | - "Input shape: (2048, 4, 4096)\n", |
704 | | - "Output shape: (2048, 4, 4096)\n", |
705 | | - "Output dtype: bfloat16\n", |
706 | | - "Forward pass completed successfully!\n", |
707 | | - "Mean time: 11.953592300415039 ms\n" |
| 684 | + "Mean time: 11.84274673461914 ms\n" |
708 | 685 | ] |
709 | 686 | } |
710 | 687 | ], |
|
743 | 720 | "\n", |
744 | 721 | "</div>\n", |
745 | 722 | "\n", |
746 | | - "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Note that fp8_autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) (currently only available in PyTorch) for a detailed explanation of FP8 recipes and the supported options.\n", |
| 723 | + "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](.../api/jax.rst#transformer_engine.jax.fp8_autocast) context manager. Note that fp8_autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) (currently only available in PyTorch) for a detailed explanation of FP8 recipes and the supported options.\n", |
747 | 724 | "\n", |
748 | 725 | "<div class=\"alert alert-warning\">\n", |
749 | 726 | "\n", |
|
756 | 733 | }, |
757 | 734 | { |
758 | 735 | "cell_type": "code", |
759 | | - "execution_count": null, |
| 736 | + "execution_count": 22, |
760 | 737 | "id": "b2aaa8ef", |
761 | 738 | "metadata": {}, |
762 | 739 | "outputs": [ |
|
803 | 780 | }, |
804 | 781 | { |
805 | 782 | "cell_type": "code", |
806 | | - "execution_count": null, |
| 783 | + "execution_count": 23, |
807 | 784 | "id": "b9cdbf22", |
808 | 785 | "metadata": {}, |
809 | 786 | "outputs": [ |
810 | 787 | { |
811 | 788 | "name": "stdout", |
812 | 789 | "output_type": "stream", |
813 | 790 | "text": [ |
814 | | - "Mean time: 7.995896339416504 ms\n" |
| 791 | + "Mean time: 7.96757698059082 ms\n" |
815 | 792 | ] |
816 | 793 | } |
817 | 794 | ], |
|
834 | 811 | "display_name": "Python 3 (ipykernel)", |
835 | 812 | "language": "python", |
836 | 813 | "name": "python3" |
| 814 | + }, |
| 815 | + "language_info": { |
| 816 | + "codemirror_mode": { |
| 817 | + "name": "ipython", |
| 818 | + "version": 3 |
| 819 | + }, |
| 820 | + "file_extension": ".py", |
| 821 | + "mimetype": "text/x-python", |
| 822 | + "name": "python", |
| 823 | + "nbconvert_exporter": "python", |
| 824 | + "pygments_lexer": "ipython3", |
| 825 | + "version": "3.12.3" |
837 | 826 | } |
838 | 827 | }, |
839 | 828 | "nbformat": 4, |
|
0 commit comments