[DO NOT MERGE] Pathways v6e large scale runs #288

Draft: wants to merge 15 commits into main
18 changes: 18 additions & 0 deletions src/xpk/commands/workload.py
@@ -195,9 +195,16 @@
template:
spec:
terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
initContainers: # privileged init container: raises the node's TCP receive-buffer sysctl
- name: network-init
image: {args.docker_image}
command: ["bash", "-c", "echo '4096 41943040 314572800' > /proc/sys/net/ipv4/tcp_rmem"]
securityContext:
privileged: true
containers:
- args:
{pathways_worker_args}
env:
image: {args.server_image}
imagePullPolicy: Always
name: pathways-worker
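
Note on the new network-init container: it runs privileged (and the pod uses hostNetwork further down), so the sysctl write applies node-wide before the Pathways worker starts. A minimal sketch of the arithmetic behind the three tcp_rmem values, illustrative only:

# min / default / max TCP receive buffer sizes, in bytes
tcp_rmem = (4096, 41943040, 314572800)
_, default_b, max_b = tcp_rmem
assert default_b == 40 * 1024 * 1024  # 40 MiB default buffer
assert max_b == 300 * 1024 * 1024     # 300 MiB ceiling for large transfers
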
@@ -263,6 +270,10 @@
volumeMounts:
- mountPath: /tmp
name: shared-tmp
resources:
limits:
cpu: "30"
memory: 120G
nodeSelector:
cloud.google.com/gke-nodepool: cpu-rm-np
hostNetwork: true
@@ -287,11 +298,18 @@
containers:
- args:
{pathways_proxy_args}
env:
- name: XLA_FLAGS
value: "--xla_dump_to={args.pathways_gcs_location}/xla_dump"
image: {args.proxy_server_image}
imagePullPolicy: Always
name: pathways-proxy
ports:
- containerPort: 29000
resources:
limits:
cpu: "30"
memory: 120G
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
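
Note: the proxy gains an XLA_FLAGS env var pointing HLO dumps at the Pathways scratch bucket, alongside the same cpu: "30" / memory: 120G limits applied to the resource manager above. A sketch of how the dump path comes together, with a hypothetical bucket standing in for args.pathways_gcs_location:

pathways_gcs_location = "gs://my-bucket/pathways"  # hypothetical stand-in
xla_flags = f"--xla_dump_to={pathways_gcs_location}/xla_dump"
assert xla_flags == "--xla_dump_to=gs://my-bucket/pathways/xla_dump"
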
4 changes: 2 additions & 2 deletions src/xpk/core/cluster_private.py
@@ -54,8 +54,8 @@ def authorize_private_cluster_access_if_necessary(args) -> int:
if add_current_machine_to_networks_return_code != 0:
return add_current_machine_to_networks_return_code

if new_authorized_networks_needed or not is_current_machine_in_network:
return update_cluster_new_authorized_networks(args, authorized_networks)
# if new_authorized_networks_needed or not is_current_machine_in_network:
# return update_cluster_new_authorized_networks(args, authorized_networks)

xpk_print("Current machine's IP adrress is already authorized.")
return 0
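
Note: with the conditional update commented out, authorize_private_cluster_access_if_necessary now always falls through to the success message and returns 0, even when the current machine is not in the cluster's authorized networks. Presumably acceptable for these controlled runs, and consistent with the DO NOT MERGE label.
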
9 changes: 7 additions & 2 deletions src/xpk/core/core.py
@@ -2125,7 +2125,11 @@ def get_volume_mounts(args, system: SystemCharacteristics) -> str:

if args.use_pathways:
volume_mount_yaml = """- mountPath: /tmp
name: shared-tmp"""
name: shared-tmp
- name: gcs-fuse-csi-ephemeral
mountPath: /training-data
- name: dshm
mountPath: /dev/shm"""
elif (
system.accelerator_type == AcceleratorType['TPU']
and args.deploy_stacktrace_sidecar
@@ -2281,7 +2285,8 @@ def get_main_container_resources(
resources_yaml = """cpu: "24"
memory: 100G"""
if args.use_pathways:
return resources_yaml
return ""
# return resources_yaml

gpu_resources_yaml = """nvidia.com/gpu: {system.chips_per_vm}"""
if system.accelerator_type == AcceleratorType['GPU']:
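
Cross-file check: the two mounts added here (/training-data and /dev/shm) must be backed by volumes of the same names, which the pathways.py diff at the bottom of this PR declares; separately, get_main_container_resources now returns an empty string for Pathways, so the user container runs without the cpu: "24" / memory: 100G limits. A small sketch of the mount/volume pairing, names taken from the two diffs:

# Mount names added in core.py vs. volume names declared in pathways.py:
mounts = {"shared-tmp", "gcs-fuse-csi-ephemeral", "dshm"}
volumes = {"shared-tmp", "gke-gcsfuse-cache", "dshm", "gcs-fuse-csi-ephemeral"}
assert mounts <= volumes  # every mountPath has a backing volume
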
57 changes: 50 additions & 7 deletions src/xpk/core/pathways.py
@@ -42,7 +42,12 @@ def get_pathways_worker_args(args) -> str:
"""
yaml = """- --server_port=29001
- --resource_manager_address={rm_address}
- --gcs_scratch_location={args.pathways_gcs_location}"""
- --temporary_flags_for_debugging=temporary_flag_for_debugging_megascale_address_derive_from_megascale_grpc=true
- --megascale_grpc_premap_memory_bytes=17179869184
- --gcs_scratch_location={args.pathways_gcs_location}
- --megascale_graph_within_launch_hang_threshold=5m
- --deepsea_chip_config_name=megachip_tccontrol""" # More flags we can adjust here: https://source.corp.google.com/piper///depot/google3/platforms/xla/megascale/runtime/executor/executor.cc;l=53-81;rcl=705575633

if args.use_pathways:
return yaml.format(args=args, rm_address=get_rm_address(args))
else:
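
Two details worth pinning down in the new worker flags: megascale_grpc_premap_memory_bytes is exactly 16 GiB, and the template is rendered with str.format, which resolves the {args.…} attribute lookups. A minimal sketch, with a SimpleNamespace standing in for the parsed args (bucket name hypothetical):

from types import SimpleNamespace

assert 16 * 1024**3 == 17179869184  # megascale_grpc_premap_memory_bytes

args = SimpleNamespace(pathways_gcs_location="gs://my-bucket/scratch")  # hypothetical
line = "- --gcs_scratch_location={args.pathways_gcs_location}".format(args=args)
assert line == "- --gcs_scratch_location=gs://my-bucket/scratch"
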
@@ -59,7 +64,22 @@ def get_pathways_proxy_args(args) -> str:
"""
yaml = """- --server_port=29000
- --resource_manager_address={rm_address}
- --gcs_scratch_location={args.pathways_gcs_location}"""
- --xla_tpu_scoped_vmem_limit_kib=98304
- --xla_tpu_enable_async_collective_fusion=true
- --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
- --xla_tpu_enable_async_collective_fusion_multiple_steps=true
- --xla_tpu_overlap_compute_collective_tc=true
- --xla_enable_async_all_gather=true
- --xla_tpu_spmd_rng_bit_generator_unsafe=true
- --xla_tpu_enable_sunk_dcn_allreduce_done_with_host_reduction=true
- --gcs_scratch_location={args.pathways_gcs_location}
- --xla_sc_disable_megacore_partitioning=true
- --xla_tpu_enable_all_reduce_offload_tracing=true
- --xla_tpu_use_tc_device_shape_on_sc=true
- --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true
- --xla_sc_enable_instruction_fusion=false
- --xla_sc_disjoint_spmem=false
- --deepsea_chip_config_name=megachip_tccontrol"""

if args.use_pathways:
return yaml.format(args=args, rm_address=get_rm_address(args))
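
Note: most of the new proxy flags toggle async collective fusion and SparseCore collective offload; the one numeric flag is the scoped vmem limit, which is 96 MiB expressed in KiB. A one-line sanity check:

assert 98304 == 96 * 1024  # xla_tpu_scoped_vmem_limit_kib == 96 MiB
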
@@ -105,9 +125,9 @@ def add_pw_resources_to_kueue(args):
- name: cpu-rm
resources:
- name: "cpu"
nominalQuota: 80
nominalQuota: 480
- name: "memory"
nominalQuota: 160G
nominalQuota: 2000G
- name: cpu-proxy
resources:
- name: "cpu"
@@ -201,7 +221,9 @@ def get_pathways_rm_args(args, system: SystemCharacteristics) -> str:
- --gcs_scratch_location={args.pathways_gcs_location}
- --node_type=resource_manager
- --instance_count={instance_count}
- --instance_type={instance_type}"""
- --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol;;;temporary_flag_for_debugging_megascale_address_derive_from_megascale_grpc=true
- --instance_type={instance_type}
- --deepsea_chip_config_name=megachip_tccontrol"""
if args.use_pathways:
return yaml.format(
args=args,
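
The combined temporary_flags_for_debugging value appears to pack two key=value settings with a ";;;" separator (an assumption read off the flag text); splitting it makes the pair easier to scan:

combined = (
    "temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol"
    ";;;"
    "temporary_flag_for_debugging_megascale_address_derive_from_megascale_grpc=true"
)
print(combined.split(";;;"))  # two key=value pairs
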
@@ -237,17 +259,38 @@ def get_user_workload_for_pathways(args, system: SystemCharacteristics) -> str:
completions: 1
parallelism: 1
template:
metadata:
annotations:
gke-gcsfuse/volumes: "true"
gke-gcsfuse/cpu-limit: "0"
gke-gcsfuse/memory-limit: "0"
gke-gcsfuse/ephemeral-storage-limit: "0"
spec:
containers:
{container}
nodeSelector:
cloud.google.com/gke-nodepool: cpu-user-np
cloud.google.com/gke-nodepool: high-mem-pool
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
restartPolicy: OnFailure
volumes:
- hostPath:
path: /tmp
type: DirectoryOrCreate
name: shared-tmp"""
name: shared-tmp
- name: gke-gcsfuse-cache
emptyDir:
medium: Memory
- name: dshm
emptyDir:
medium: Memory
- name: gcs-fuse-csi-ephemeral
csi:
driver: gcsfuse.csi.storage.gke.io
volumeAttributes:
bucketName: trillium-storage-datasets-sr
mountOptions: "debug_fuse,implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
"""
if args.headless:
return ''
else:
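
Finally, the user workload opts into the GKE gcsfuse CSI sidecar via the pod annotations (the "0" limits mean no limit for the sidecar), moves from cpu-user-np to the high-mem-pool node pool, and mounts the trillium-storage-datasets-sr bucket with an aggressive file cache. A sketch that splits the mountOptions string into its individual options (per the gcsfuse docs, -1 disables the respective limit):

opts = ("debug_fuse,implicit-dirs,file-cache:enable-parallel-downloads:true,"
        "file-cache:parallel-downloads-per-file:100,"
        "file-cache:max-parallel-downloads:-1,"
        "file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1")
for opt in opts.split(","):
    print(opt)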