Skip to content

Commit

Permalink
fix tests, add some core stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
codekansas committed Jan 28, 2024
1 parent bec3a19 commit 9d483f8
Show file tree
Hide file tree
Showing 9 changed files with 633 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ warn_unused_ignores = true
warn_redundant_casts = true

incremental = true
namespace_packages = false
explicit_package_bases = true

[[tool.mypy.overrides]]

Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ packages = find:
[options.packages.find]

exclude =
.vscode
.github
tests
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,10 @@
install_requires=requirements,
tests_require=requirements_dev,
extras_require={"dev": requirements_dev},
package_data={
"mlfab": [
"py.typed",
"requirements*.txt",
],
},
)
Empty file added xax/core/__init__.py
Empty file.
206 changes: 206 additions & 0 deletions xax/core/conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
"""Defines base configuration functions and utilities."""

import functools
import os
from dataclasses import dataclass, field as field_base
from pathlib import Path
from typing import Any, cast

import jax.numpy as jnp
from omegaconf import II, MISSING, Container as OmegaConfContainer, OmegaConf

from xax.utils.text import show_error

FieldType = Any


def field(value: FieldType, **kwargs: str) -> FieldType:
    """Short-hand function for getting a config field.

    The default is registered in one of three ways:

    * callables (including classes) are used directly as the default factory;
    * unhashable values (lists, dicts, ...) are wrapped in a factory which
      hands every instance its own deep copy, so instances never share a
      mutable default;
    * everything else is stored as a plain default value.

    Args:
        value: The current field's default value.
        kwargs: Additional metadata fields to supply.

    Returns:
        The dataclass field.
    """
    import copy

    metadata: dict[str, Any] = dict(kwargs)

    if callable(value):
        return field_base(default_factory=value, metadata=metadata)
    if value.__class__.__hash__ is None:
        # Copy on every instantiation; returning ``value`` itself would make
        # all instances mutate the same object.
        return field_base(default_factory=lambda: copy.deepcopy(value), metadata=metadata)
    return field_base(default=value, metadata=metadata)


def is_missing(cfg: Any, key: str) -> bool:  # noqa: ANN401
    """Checks whether a config key has no usable value.

    Works for both OmegaConf containers and raw dataclass instances so that
    callers can treat them the same way.

    Args:
        cfg: The config to check
        key: The key to check

    Returns:
        True if the key is marked missing, if it is an interpolation that
        cannot be resolved, or if its value is the ``MISSING`` sentinel.
    """
    if isinstance(cfg, OmegaConfContainer):
        if OmegaConf.is_missing(cfg, key):
            return True
        if OmegaConf.is_interpolation(cfg, key):
            # An interpolation counts as missing when resolving it fails
            # for any reason.
            try:
                getattr(cfg, key)
            except Exception:
                return True
            else:
                return False
    return getattr(cfg, key) is MISSING


@dataclass
class ErrorHandling:
    """Error-handling options; each field's ``help`` metadata describes it."""

    enabled: bool = field(True, help="Is error handling enabled?")
    maximum_exceptions: int = field(10, help="Maximum number of errors to encounter")

    # Backoff options.
    backoff_after: int = field(
        5,
        help="Start to do a sleeping backoff after this many exceptions",
    )
    sleep_backoff: float = field(0.1, help="Sleep backoff amount")
    sleep_backoff_power: float = field(
        2.0,
        help="How much to multiply backoff for each successive exception",
    )

    # Logging and summary options.
    log_full_exception: bool = field(
        False,
        help="Log the full exception message for each exception",
    )
    flush_exception_summary_every: int = field(
        500,
        help="How often to flush exception summary",
    )
    report_top_n_exception_types: int = field(
        5,
        help="Number of exceptions to summarize",
    )
    exception_location_traceback_depth: int = field(
        3,
        help="Traceback length for the exception location",
    )


@dataclass
class Logging:
    """Logging options; each field's ``help`` metadata describes it."""

    hide_third_party_logs: bool = field(
        True,
        help="If set, hide third-party logs",
    )
    log_level: str = field(
        "INFO",
        help="The logging level to use",
    )


@dataclass
class Device:
    """Device selection and floating point precision flags.

    ``gpu`` and ``metal`` default to OmegaConf interpolations over the
    ``USE_GPU`` and ``USE_METAL`` environment variables, each with ``1`` as
    the fallback.
    """

    # Which devices may be used.
    cpu: bool = field(True, help="Whether to use the CPU")
    gpu: bool = field(II("oc.env:USE_GPU,1"), help="Whether to use the GPU")
    metal: bool = field(
        II("oc.env:USE_METAL,1"),
        help="Whether to use the Apple Silicon accelerator",
    )

    # Flags forcing a floating point type; ``parse_dtype`` reads them in this order.
    use_fp64: bool = field(False, help="Always use the 64-bit floating point type")
    use_fp32: bool = field(False, help="Always use the 32-bit floating point type")
    use_bf16: bool = field(False, help="Always use the 16-bit bfloat type")
    use_fp16: bool = field(False, help="Always use the 16-bit floating point type")


def parse_dtype(cfg: Device) -> jnp.dtype | None:
if cfg.use_fp64:
return jnp.float64
if cfg.use_fp32:
return jnp.float32
if cfg.use_bf16:
return jnp.bfloat16
if cfg.use_fp16:
return jnp.float16
return None


@dataclass
class Triton:
    """Triton options; the field's ``help`` metadata describes it."""

    use_triton_if_available: bool = field(
        True,
        help="Use Triton if available",
    )


@dataclass
class Experiment:
    """Experiment options; the field's ``help`` metadata describes it."""

    default_random_seed: int = field(
        1337,
        help="The default random seed to use",
    )


@dataclass
class Directories:
    """Directory locations, each defaulting to an environment variable.

    Every default is an OmegaConf interpolation over an environment variable
    (``RUN_DIR``, ``DATA_DIR``, ``MODEL_DIR``) with no fallback value.
    """

    run: str = field(
        II("oc.env:RUN_DIR"),
        help="The run directory",
    )
    data: str = field(
        II("oc.env:DATA_DIR"),
        help="The data directory",
    )
    pretrained_models: str = field(
        II("oc.env:MODEL_DIR"),
        help="The models directory",
    )


@dataclass
class SlurmPartition:
    """A single Slurm launch target; ``partition`` has no default value."""

    partition: str = field(
        MISSING,
        help="The partition name",
    )
    num_nodes: int = field(
        1,
        help="The number of nodes to use",
    )


@dataclass
class Slurm:
    """Slurm options: the launch configurations, keyed by a string name."""

    # ``dict`` is passed as a factory rather than a ``{}`` literal so that
    # each instance starts with its own empty mapping instead of sharing one.
    launch: dict[str, SlurmPartition] = field(dict, help="The available launch configurations")


@dataclass
class UserConfig:
    """Top-level user configuration, grouping every config section.

    Each section class is passed to ``field`` as the default factory, so a
    fresh section instance is created for every ``UserConfig``.
    """

    # Runtime behaviour.
    error_handling: ErrorHandling = field(ErrorHandling)
    logging: Logging = field(Logging)

    # Hardware and kernels.
    device: Device = field(Device)
    triton: Triton = field(Triton)

    # Experiment defaults and file locations.
    experiment: Experiment = field(Experiment)
    directories: Directories = field(Directories)

    # Cluster launching.
    slurm: Slurm = field(Slurm)


def user_config_path() -> Path:
    """Returns the location of the user config file.

    The path comes from the ``XAXRC_PATH`` environment variable, falling
    back to ``~/.xax.yml``; a leading ``~`` is expanded.

    Returns:
        The path to the user config file.
    """
    return Path(os.environ.get("XAXRC_PATH", "~/.xax.yml")).expanduser()


@functools.lru_cache(maxsize=None)
def _load_user_config_cached() -> UserConfig:
    """Builds the user configuration; the result is cached after the first call.

    The structured ``UserConfig`` defaults are merged with the user config
    file, when it exists. Otherwise the defaults are written to that
    location so the user has a file to edit. A ``xax.yml`` file in the
    current working directory, if present, is merged on top.

    Returns:
        The merged configuration.
    """
    xaxrc_path = user_config_path()
    base_cfg = OmegaConf.structured(UserConfig)

    # Loads the user config file, or writes the defaults if there is none.
    if xaxrc_path.exists():
        cfg = OmegaConf.merge(base_cfg, OmegaConf.load(xaxrc_path))
    else:
        show_error(f"No config file was found in {xaxrc_path}; writing one...", important=True)
        # The path may point into a directory that does not exist yet
        # (for example through XAXRC_PATH), so create it before saving.
        xaxrc_path.parent.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(base_cfg, xaxrc_path)
        cfg = base_cfg

    # Looks in the current directory for a config file.
    local_cfg_path = Path("xax.yml")
    if local_cfg_path.exists():
        cfg = OmegaConf.merge(cfg, OmegaConf.load(local_cfg_path))

    return cast(UserConfig, cfg)


def load_user_config() -> UserConfig:
    """Returns the user configuration.

    Delegates to ``_load_user_config_cached``, which is wrapped in an LRU
    cache, so the configuration is only built on the first call.

    Returns:
        The loaded configuration.
    """
    user_config = _load_user_config_cached()
    return user_config


def get_run_dir() -> Path | None:
    """Returns the configured run directory, creating it if needed.

    Returns:
        The run directory, or None if it is not set in the config.
    """
    directories = load_user_config().directories
    if is_missing(directories, "run"):
        return None
    run_dir = Path(directories.run)
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir


def get_data_dir() -> Path:
    """Returns the configured data directory.

    Returns:
        The data directory.

    Raises:
        RuntimeError: If the data directory is not set in the config.
    """
    directories = load_user_config().directories
    if not is_missing(directories, "data"):
        return Path(directories.data)

    config_path = user_config_path()
    raise RuntimeError(
        "The data directory has not been set! You should set it in your config file "
        f"in {config_path} or set the DATA_DIR environment variable."
    )


def get_pretrained_models_dir() -> Path:
    """Returns the configured pretrained models directory.

    Returns:
        The pretrained models directory.

    Raises:
        RuntimeError: If the pretrained models directory is not set in the
            config.
    """
    config = load_user_config().directories
    if is_missing(config, "pretrained_models"):
        # Name the right directory: this message used to say "data
        # directory", which sent users to the wrong setting.
        raise RuntimeError(
            "The pretrained models directory has not been set! You should set it in your config file "
            f"in {user_config_path()} or set the MODEL_DIR environment variable."
        )
    return Path(config.pretrained_models)
67 changes: 67 additions & 0 deletions xax/core/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Defines a dataclass for keeping track of the current training state."""

import time
from dataclasses import dataclass
from typing import Literal, cast, get_args

from omegaconf import MISSING

from xax.core.conf import field

# The closed set of phases a run can be in.
Phase = Literal["train", "valid", "test"]


def cast_phase(raw_phase: str) -> Phase:
    """Narrows a plain string to the ``Phase`` literal type.

    Args:
        raw_phase: The phase name to check.

    Returns:
        The same string, typed as a ``Phase``.

    Raises:
        AssertionError: If the string is not one of the ``Phase`` options.
    """
    valid_phases = get_args(Phase)
    # NOTE: this is a plain ``assert``, so the check is skipped when Python
    # runs with optimizations enabled (``-O``).
    assert raw_phase in valid_phases, f"Invalid phase: '{raw_phase}' Valid options are {valid_phases}"
    return cast(Phase, raw_phase)


@dataclass
class State:
    """Counters and timing values describing the progress of a run.

    The phase is stored as the plain string ``raw_phase``; the ``phase``
    property validates it through ``cast_phase`` on every read.
    """

    # Epoch, step and sample counters.
    num_epochs: int = field(MISSING, help="Number of epochs so far")
    num_steps: int = field(MISSING, help="Number of steps so far")
    num_epoch_steps: int = field(MISSING, help="Number of steps in the current epoch")
    num_samples: int = field(MISSING, help="Number of sample so far")
    num_epoch_samples: int = field(MISSING, help="Number of samples in the current epoch")
    num_valid_steps: int = field(MISSING, help="Number of validation steps so far")
    num_test_steps: int = field(MISSING, help="Number of test steps so far")

    # Timing.
    start_time_s: float = field(MISSING, help="Start time of training")
    elapsed_time_s: float = field(MISSING, help="Total elapsed time so far")

    # Current phase, kept as a plain string.
    raw_phase: str = field(MISSING, help="Current training phase")

    @property
    def phase(self) -> Phase:
        """The current phase, validated against the ``Phase`` options."""
        return cast_phase(self.raw_phase)

    @phase.setter
    def phase(self, new_phase: Phase) -> None:
        """Stores the new phase as the raw string."""
        self.raw_phase = new_phase

    @classmethod
    def init_state(cls) -> "State":
        """Creates a fresh state for the start of training.

        Returns:
            A state with every counter at zero, the start time set to the
            current time, and the phase set to "train".
        """
        started_at = time.time()
        return cls(
            num_epochs=0,
            num_steps=0,
            num_epoch_steps=0,
            num_samples=0,
            num_epoch_samples=0,
            num_valid_steps=0,
            num_test_steps=0,
            start_time_s=started_at,
            elapsed_time_s=0.0,
            raw_phase="train",
        )

    @property
    def training(self) -> bool:
        """Whether the current phase is "train"."""
        return self.phase == "train"

    def num_phase_steps(self, phase: Phase) -> int:
        """Returns the step counter belonging to a phase.

        Args:
            phase: The phase to look up.

        Returns:
            ``num_steps`` for "train", ``num_valid_steps`` for "valid" and
            ``num_test_steps`` for "test".
        """
        if phase == "train":
            return self.num_steps
        if phase == "valid":
            return self.num_valid_steps
        if phase == "test":
            return self.num_test_steps
1 change: 1 addition & 0 deletions xax/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

jax
jaxtyping
omegaconf
Empty file added xax/utils/__init__.py
Empty file.
Loading

0 comments on commit 9d483f8

Please sign in to comment.