Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions jasmine/train_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dataclasses import dataclass, field
import itertools
from typing import cast, Optional
from typing import cast, Optional, Literal

import einops
from jax.sharding import Mesh, PartitionSpec, NamedSharding
Expand All @@ -31,6 +31,10 @@
print_compiled_cost_analysis,
)

# Maps the CLI-facing dtype names accepted by Args onto concrete JAX dtypes.
DTYPE_MAP = dict(
    float32=jnp.float32,
    bfloat16=jnp.bfloat16,
)

@dataclass
class Args:
Expand All @@ -53,7 +57,7 @@ class Args:
20_000 # NOTE: wsd_decay_steps will only be used when using a wsd-schedule
)
warmup_steps: int = 5000
lr_schedule: str = "wsd" # supported options: wsd, cos
lr_schedule: Literal["wsd", "cos"] = "wsd"
# Tokenizer
tokenizer_dim: int = 512
tokenizer_ffn_dim: int = 2048
Expand Down Expand Up @@ -81,8 +85,8 @@ class Args:
dropout: float = 0.0
mask_limit: float = 0.5
z_loss_weight: float = 0.0
param_dtype = jnp.float32
dtype = jnp.bfloat16
param_dtype: Literal["float32", "bfloat16"] = "float32"
dtype: Literal["float32", "bfloat16"] = "bfloat16"
use_flash_attention: bool = True
use_gt_actions: bool = False
# Logging
Expand Down Expand Up @@ -443,7 +447,7 @@ def main(args: Args) -> None:
)
wandb.init(**wandb_init_kwargs)

wandb.config.update({"model_param_count": param_counts})
wandb.config.update({"model_param_count": param_counts, "param_dtype": args.param_dtype, "dtype": args.dtype})

print("Parameter counts:")
print(param_counts)
Expand Down Expand Up @@ -818,4 +822,12 @@ def calculate_validation_metrics(val_dataloader, genie, rng):

if __name__ == "__main__":
    args = tyro.cli(Args)

    # Resolve the CLI dtype names into concrete JAX dtype objects before
    # handing the config to main().
    args.dtype = DTYPE_MAP[args.dtype]
    args.param_dtype = DTYPE_MAP[args.param_dtype]

    # NOTE(review): flash attention is force-disabled under float32 —
    # presumably the kernel only supports bfloat16; confirm upstream.
    if args.dtype == jnp.float32:
        print("Using float32, disabling flash attention")
        args.use_flash_attention = False

    main(args)
22 changes: 17 additions & 5 deletions jasmine/train_lam.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dataclasses import dataclass, field
import itertools
from typing import cast, Optional
from typing import cast, Optional, Literal

import einops
from jax.sharding import Mesh, PartitionSpec, NamedSharding
Expand All @@ -30,6 +30,10 @@
print_compiled_cost_analysis,
)

# Maps the CLI-facing dtype names accepted by Args onto concrete JAX dtypes.
DTYPE_MAP = dict(
    float32=jnp.float32,
    bfloat16=jnp.bfloat16,
)

@dataclass
class Args:
Expand All @@ -53,7 +57,7 @@ class Args:
20_000 # NOTE: wsd_decay_steps will only be used when using a wsd-schedule
)
warmup_steps: int = 5000
lr_schedule: str = "wsd" # supported options: wsd, cos
lr_schedule: Literal["wsd", "cos"] = "wsd"
vq_reset_thresh: int = 50
# LAM
model_dim: int = 512
Expand All @@ -65,8 +69,8 @@ class Args:
num_heads: int = 8
dropout: float = 0.0
codebook_dropout: float = 0.0
param_dtype = jnp.float32
dtype = jnp.bfloat16
param_dtype: Literal["float32", "bfloat16"] = "float32"
dtype: Literal["float32", "bfloat16"] = "bfloat16"
use_flash_attention: bool = True
# Logging
log: bool = True
Expand Down Expand Up @@ -317,7 +321,7 @@ def main(args: Args) -> None:
)
wandb.init(**wandb_init_kwargs)

wandb.config.update({"model_param_count": param_counts})
wandb.config.update({"model_param_count": param_counts, "param_dtype": args.param_dtype, "dtype": args.dtype})

print("Parameter counts:")
print(param_counts)
Expand Down Expand Up @@ -595,4 +599,12 @@ def calculate_validation_metrics(val_dataloader, lam):

if __name__ == "__main__":
    args = tyro.cli(Args)

    # Resolve the CLI dtype names into concrete JAX dtype objects before
    # handing the config to main().
    args.dtype = DTYPE_MAP[args.dtype]
    args.param_dtype = DTYPE_MAP[args.param_dtype]

    # NOTE(review): flash attention is force-disabled under float32 —
    # presumably the kernel only supports bfloat16; confirm upstream.
    if args.dtype == jnp.float32:
        print("Using float32, disabling flash attention")
        args.use_flash_attention = False

    main(args)
22 changes: 17 additions & 5 deletions jasmine/train_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.98")

from dataclasses import dataclass, field
from typing import cast, Optional
from typing import cast, Optional, Literal

import einops
import itertools
Expand All @@ -30,6 +30,10 @@
print_compiled_cost_analysis,
)

# Maps the CLI-facing dtype names accepted by Args onto concrete JAX dtypes.
DTYPE_MAP = dict(
    float32=jnp.float32,
    bfloat16=jnp.bfloat16,
)

@dataclass
class Args:
Expand All @@ -52,7 +56,7 @@ class Args:
wsd_decay_steps: int = (
30_000 # NOTE: wsd_decay_steps will only be used when using a wsd-schedule
)
lr_schedule: str = "wsd" # supported options: wsd, cos
lr_schedule: Literal["wsd", "cos"] = "wsd"
warmup_steps: int = 10000
# Tokenizer
model_dim: int = 512
Expand All @@ -64,8 +68,8 @@ class Args:
num_heads: int = 8
dropout: float = 0.0
codebook_dropout: float = 0.01
param_dtype = jnp.float32
dtype = jnp.bfloat16
param_dtype: Literal["float32", "bfloat16"] = "float32"
dtype: Literal["float32", "bfloat16"] = "bfloat16"
use_flash_attention: bool = True
# Logging
log: bool = True
Expand Down Expand Up @@ -308,7 +312,7 @@ def main(args: Args) -> None:
)
wandb.init(**wandb_init_kwargs)

wandb.config.update({"model_param_count": param_counts})
wandb.config.update({"model_param_count": param_counts, "param_dtype": args.param_dtype, "dtype": args.dtype})

print("Parameter counts:")
print(param_counts)
Expand Down Expand Up @@ -571,4 +575,12 @@ def calculate_validation_metrics(val_dataloader, tokenizer):

if __name__ == "__main__":
    args = tyro.cli(Args)

    # Resolve the CLI dtype names into concrete JAX dtype objects before
    # handing the config to main().
    args.dtype = DTYPE_MAP[args.dtype]
    args.param_dtype = DTYPE_MAP[args.param_dtype]

    # NOTE(review): flash attention is force-disabled under float32 —
    # presumably the kernel only supports bfloat16; confirm upstream.
    if args.dtype == jnp.float32:
        print("Using float32, disabling flash attention")
        args.use_flash_attention = False

    main(args)
Loading