2 changes: 1 addition & 1 deletion config/default_config.yml
@@ -106,7 +106,7 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"],
"same_strategy_per_batch": false
}

num_epochs: 32
num_epochs: 128
samples_per_epoch: 4096
samples_per_validation: 512
shuffle: True
3 changes: 2 additions & 1 deletion config/streams/era5_1deg/era5.yml
@@ -9,7 +9,8 @@

ERA5 :
type : anemoi
filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6.zarr']
filenames : ['era5-o96-1979-2023-6h-v8.zarr']
#filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6.zarr']
source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp']
target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp']
loss_weight : 1.
49 changes: 33 additions & 16 deletions src/weathergen/train/trainer_base.py
@@ -20,7 +20,7 @@

from weathergen.common.config import Config
from weathergen.train.utils import str_to_tensor, tensor_to_str
from weathergen.utils.distributed import is_root
from weathergen.utils.distributed import is_root, get_rank_from_env, get_size_from_env

_logger = logging.getLogger(__name__)

@@ -54,15 +54,20 @@ def init_torch(use_cuda=True, num_accs_per_task=1, multiprocessing_method="fork"
if not use_cuda:
return torch.device("cpu")

local_id_node = os.environ.get("SLURM_LOCALID", "-1")
if local_id_node == "-1":
#local_id_node = os.environ.get("SLURM_LOCALID", "-1")
rank = get_rank_from_env(default=-1)
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
num_gpus = len(cuda_visible_devices.split(",")) if cuda_visible_devices else 0
rank_local = rank % num_gpus if num_gpus > 0 else 0

if rank == -1:
devices = ["cuda"]
else:
devices = [
f"cuda:{int(local_id_node) * num_accs_per_task + i}"
f"cuda:{int(rank_local) * num_accs_per_task + i}"
for i in range(num_accs_per_task)
]
torch.cuda.set_device(int(local_id_node) * num_accs_per_task)
torch.cuda.set_device(int(rank_local) * num_accs_per_task)

return devices

@@ -82,31 +87,39 @@ def init_ddp(cf):
_logger.info(f"rank: {rank} has run_id: {cf.run_id}")
return

local_rank = int(os.environ.get("SLURM_LOCALID"))
ranks_per_node = int(os.environ.get("SLURM_TASKS_PER_NODE", "1")[0])
rank = int(os.environ.get("SLURM_NODEID")) * ranks_per_node + local_rank
num_ranks = int(os.environ.get("SLURM_NTASKS"))
# local_rank = int(os.environ.get("SLURM_LOCALID"))
# ranks_per_node = int(os.environ.get("SLURM_TASKS_PER_NODE", "1")[0])
# rank = int(os.environ.get("SLURM_NODEID")) * ranks_per_node + local_rank
# num_ranks = int(os.environ.get("SLURM_NTASKS"))
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
num_gpus = len(cuda_visible_devices.split(",")) if cuda_visible_devices else 0

rank = get_rank_from_env(default=0)
num_ranks = get_size_from_env(default=1)
rank_local = rank % num_gpus if num_gpus > 0 else 0

_logger.info(
f"DDP initialization: local_rank={local_rank}, ranks_per_node={ranks_per_node}, "
f"DDP initialization: local_rank={rank_local}, ranks_per_node={num_gpus}, "
f"rank={rank}, num_ranks={num_ranks}"
)

master_port = int(os.environ.get("MASTER_PORT", "1345"))
if rank == 0:
# Check that port 1345 is available, raise an error if not
# Check that port master_port is available, raise an error if not
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind((master_node, 1345))
s.bind((master_node, master_port))
except OSError as e:
if e.errno == errno.EADDRINUSE:
_logger.error(
(
f"Port 1345 is already in use on {master_node}.",
f"Port {master_port} is already in use on {master_node}.",
" Please check your network configuration.",
)
)
raise
else:
_logger.error(f"Error while binding to port 1345 on {master_node}: {e}")
_logger.error(f"Error while binding to port {master_port} on {master_node}: {e}")
raise

_logger.info(
@@ -115,17 +128,21 @@ def init_ddp(cf):

dist.init_process_group(
backend="nccl",
init_method="tcp://" + master_node + ":1345",
init_method="tcp://" + master_node + f":{master_port}",
timeout=datetime.timedelta(seconds=240),
world_size=num_ranks,
rank=rank,
device_id=torch.device("cuda", local_rank),
#device_id=torch.device("cuda", rank_local),
)
if is_root():
_logger.info("DDP initialized: root.")
# Wait for all ranks to reach this point
dist.barrier()

_logger.info(
f"TORCH DISTRIBUTED INFO: rank:{dist.get_rank()} world_size:{dist.get_world_size()}"
)

# communicate run id to all nodes
len_run_id = len(cf.run_id)
run_id_int = torch.zeros(len_run_id, dtype=torch.int32).cuda()
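For reference, a minimal standalone sketch of the device-selection logic introduced in init_torch above: the global rank comes from the scheduler environment, the number of GPUs visible on the node from CUDA_VISIBLE_DEVICES, and the local device index is the rank folded onto those GPUs. The environment values here are illustrative assumptions, not taken from a real job.

    import os

    import torch

    # Illustrative job: rank 5 of 8, on a node exposing 4 GPUs (assumed values).
    os.environ.setdefault("SLURM_PROCID", "5")
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1,2,3")

    rank = int(os.environ["SLURM_PROCID"])
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    num_gpus = len(cuda_visible_devices.split(",")) if cuda_visible_devices else 0

    # Global rank folded onto this node's GPUs: 5 % 4 -> local index 1.
    rank_local = rank % num_gpus if num_gpus > 0 else 0
    device = torch.device("cuda", rank_local) if num_gpus > 0 else torch.device("cpu")
    print(rank, rank_local, device)

The pre-flight socket bind in init_ddp serves a related purpose: binding a throwaway socket to MASTER_PORT on the master node fails immediately with EADDRINUSE if the port is already taken, instead of letting init_process_group hang until its timeout.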
52 changes: 52 additions & 0 deletions src/weathergen/utils/distributed.py
Expand Up @@ -10,6 +10,7 @@

import torch
import torch.distributed as dist
import os

SYNC_TIMEOUT_SEC = 60 * 60 # 1 hour

@@ -112,3 +113,54 @@ def all_gather_vdim(tensor: torch.Tensor, group=None) -> list[torch.Tensor]:
dist.all_to_all(outputs, inputs, group=group)

return outputs

# --------------------------------------------------------------

def get_rank_from_env(default=0):
"""
Attempts to determine the rank (i.e., the process ID in a parallel job)
by checking common environment variables used by various HPC schedulers.

Parameters:
default (int): Default rank value to return if no environment variable is found.

Returns:
int: The detected rank or the default value.
"""
var_list = [
"SLURM_PROCID", # SLURM
"PMI_RANK", # Intel MPI
"OMPI_COMM_WORLD_RANK", # Open MPI
#"MP_CHILD"
#"RANK",
]
for var in var_list:
value = os.getenv(var)
if value is not None:
return int(value)
return int(default)

def get_size_from_env(default=1):
"""
Attempts to determine the total number of processes (world size)
by checking common environment variables used by various HPC schedulers.

Parameters:
default (int): Default size value to return if no environment variable is found.

Returns:
int: The detected world size or the default value.
"""
var_list = [
"SLURM_NTASKS", # SLURM
"PMI_SIZE", # Intel MPI
"OMPI_COMM_WORLD_SIZE", # Open MPI
#"MP_PROCS"
#"SIZE",
#"WORLD_SIZE",
]
for var in var_list:
value = os.getenv(var)
if value is not None:
return int(value)
return int(default)
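
A short usage sketch for these helpers (assuming the weathergen package is importable); the environment values are made up for illustration. Under srun, SLURM_PROCID and SLURM_NTASKS are picked up; under Open MPI's mpirun, OMPI_COMM_WORLD_RANK and OMPI_COMM_WORLD_SIZE; with no launcher at all, the defaults are returned, so single-process runs still work.

    import os

    from weathergen.utils.distributed import get_rank_from_env, get_size_from_env

    # No scheduler variables set (e.g. a plain interactive shell): fall back to defaults.
    assert get_rank_from_env(default=0) == 0
    assert get_size_from_env(default=1) == 1

    # Simulate a SLURM job step such as `srun --ntasks=8 ...` (values assumed).
    os.environ["SLURM_PROCID"] = "3"
    os.environ["SLURM_NTASKS"] = "8"
    assert get_rank_from_env() == 3
    assert get_size_from_env() == 8

Note that RANK and WORLD_SIZE remain commented out in the lookup lists, so a torchrun-style launch that only sets those variables would currently fall back to the defaults.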