From 9d483f82b1ec51dfe51d9e12b6c60bc2bfd0ab14 Mon Sep 17 00:00:00 2001 From: Benjamin Bolte Date: Sun, 28 Jan 2024 00:32:35 -0800 Subject: [PATCH] fix tests, add some core stuff --- pyproject.toml | 2 +- setup.cfg | 2 + setup.py | 6 + xax/core/__init__.py | 0 xax/core/conf.py | 206 +++++++++++++++++++++++++ xax/core/state.py | 67 ++++++++ xax/requirements.txt | 1 + xax/utils/__init__.py | 0 xax/utils/text.py | 350 ++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 633 insertions(+), 1 deletion(-) create mode 100644 xax/core/__init__.py create mode 100644 xax/core/conf.py create mode 100644 xax/core/state.py create mode 100644 xax/utils/__init__.py create mode 100644 xax/utils/text.py diff --git a/pyproject.toml b/pyproject.toml index 3aee4b2..65db9ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ warn_unused_ignores = true warn_redundant_casts = true incremental = true -namespace_packages = false +explicit_package_bases = true [[tool.mypy.overrides]] diff --git a/setup.cfg b/setup.cfg index 05d396d..b66faab 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,4 +5,6 @@ packages = find: [options.packages.find] exclude = + .vscode + .github tests diff --git a/setup.py b/setup.py index b0ac591..11126da 100644 --- a/setup.py +++ b/setup.py @@ -35,4 +35,10 @@ install_requires=requirements, tests_require=requirements_dev, extras_require={"dev": requirements_dev}, + package_data={ + "mlfab": [ + "py.typed", + "requirements*.txt", + ], + }, ) diff --git a/xax/core/__init__.py b/xax/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/xax/core/conf.py b/xax/core/conf.py new file mode 100644 index 0000000..c0d9946 --- /dev/null +++ b/xax/core/conf.py @@ -0,0 +1,206 @@ +"""Defines base configuration functions and utilities.""" + +import functools +import os +from dataclasses import dataclass, field as field_base +from pathlib import Path +from typing import Any, cast + +import jax.numpy as jnp +from omegaconf import II, MISSING, Container as OmegaConfContainer, OmegaConf + +from xax.utils.text import show_error + +FieldType = Any + + +def field(value: FieldType, **kwargs: str) -> FieldType: + """Short-hand function for getting a config field. + + Args: + value: The current field's default value. + kwargs: Additional metadata fields to supply. + + Returns: + The dataclass field. + """ + metadata: dict[str, Any] = {} + metadata.update(kwargs) + + if hasattr(value, "__call__"): + return field_base(default_factory=value, metadata=metadata) + if value.__class__.__hash__ is None: + return field_base(default_factory=lambda: value, metadata=metadata) + return field_base(default=value, metadata=metadata) + + +def is_missing(cfg: Any, key: str) -> bool: # noqa: ANN401 + """Utility function for checking if a config key is missing. + + This is for cases when you are using a raw dataclass rather than an + OmegaConf container but want to treat them the same way. + + Args: + cfg: The config to check + key: The key to check + + Returns: + Whether or not the key is missing a value in the config + """ + if isinstance(cfg, OmegaConfContainer): + if OmegaConf.is_missing(cfg, key): + return True + if OmegaConf.is_interpolation(cfg, key): + try: + getattr(cfg, key) + return False + except Exception: + return True + if getattr(cfg, key) is MISSING: + return True + return False + + +@dataclass +class ErrorHandling: + enabled: bool = field(True, help="Is error handling enabled?") + maximum_exceptions: int = field(10, help="Maximum number of errors to encounter") + backoff_after: int = field(5, help="Start to do a sleeping backoff after this many exceptions") + sleep_backoff: float = field(0.1, help="Sleep backoff amount") + sleep_backoff_power: float = field(2.0, help="How much to multiply backoff for each successive exception") + log_full_exception: bool = field(False, help="Log the full exception message for each exception") + flush_exception_summary_every: int = field(500, help="How often to flush exception summary") + report_top_n_exception_types: int = field(5, help="Number of exceptions to summarize") + exception_location_traceback_depth: int = field(3, help="Traceback length for the exception location") + + +@dataclass +class Logging: + hide_third_party_logs: bool = field(True, help="If set, hide third-party logs") + log_level: str = field("INFO", help="The logging level to use") + + +@dataclass +class Device: + cpu: bool = field(True, help="Whether to use the CPU") + gpu: bool = field(II("oc.env:USE_GPU,1"), help="Whether to use the GPU") + metal: bool = field(II("oc.env:USE_METAL,1"), help="Whether to use the Apple Silicon accelerator") + use_fp64: bool = field(False, help="Always use the 64-bit floating point type") + use_fp32: bool = field(False, help="Always use the 32-bit floating point type") + use_bf16: bool = field(False, help="Always use the 16-bit bfloat type") + use_fp16: bool = field(False, help="Always use the 16-bit floating point type") + + +def parse_dtype(cfg: Device) -> jnp.dtype | None: + if cfg.use_fp64: + return jnp.float64 + if cfg.use_fp32: + return jnp.float32 + if cfg.use_bf16: + return jnp.bfloat16 + if cfg.use_fp16: + return jnp.float16 + return None + + +@dataclass +class Triton: + use_triton_if_available: bool = field(True, help="Use Triton if available") + + +@dataclass +class Experiment: + default_random_seed: int = field(1337, help="The default random seed to use") + + +@dataclass +class Directories: + run: str = field(II("oc.env:RUN_DIR"), help="The run directory") + data: str = field(II("oc.env:DATA_DIR"), help="The data directory") + pretrained_models: str = field(II("oc.env:MODEL_DIR"), help="The models directory") + + +@dataclass +class SlurmPartition: + partition: str = field(MISSING, help="The partition name") + num_nodes: int = field(1, help="The number of nodes to use") + + +@dataclass +class Slurm: + launch: dict[str, SlurmPartition] = field({}, help="The available launch configurations") + + +@dataclass +class UserConfig: + error_handling: ErrorHandling = field(ErrorHandling) + logging: Logging = field(Logging) + device: Device = field(Device) + triton: Triton = field(Triton) + experiment: Experiment = field(Experiment) + directories: Directories = field(Directories) + slurm: Slurm = field(Slurm) + + +def user_config_path() -> Path: + xaxrc_path_raw = os.environ.get("XAXRC_PATH", "~/.xax.yml") + xaxrc_path = Path(xaxrc_path_raw).expanduser() + return xaxrc_path + + +@functools.lru_cache(maxsize=None) +def _load_user_config_cached() -> UserConfig: + xaxrc_path = user_config_path() + base_cfg = OmegaConf.structured(UserConfig) + + # Writes the config file. + if xaxrc_path.exists(): + cfg = OmegaConf.merge(base_cfg, OmegaConf.load(xaxrc_path)) + else: + show_error(f"No config file was found in {xaxrc_path}; writing one...", important=True) + OmegaConf.save(base_cfg, xaxrc_path) + cfg = base_cfg + + # Looks in the current directory for a config file. + local_cfg_path = Path("xax.yml") + if local_cfg_path.exists(): + cfg = OmegaConf.merge(cfg, OmegaConf.load(local_cfg_path)) + + return cast(UserConfig, cfg) + + +def load_user_config() -> UserConfig: + """Loads the ``~/.xax.yml`` configuration file. + + Returns: + The loaded configuration. + """ + return _load_user_config_cached() + + +def get_run_dir() -> Path | None: + config = load_user_config().directories + if is_missing(config, "run"): + return None + (run_dir := Path(config.run)).mkdir(parents=True, exist_ok=True) + return run_dir + + +def get_data_dir() -> Path: + config = load_user_config().directories + if is_missing(config, "data"): + raise RuntimeError( + "The data directory has not been set! You should set it in your config file " + f"in {user_config_path()} or set the DATA_DIR environment variable." + ) + return Path(config.data) + + +def get_pretrained_models_dir() -> Path: + config = load_user_config().directories + if is_missing(config, "pretrained_models"): + raise RuntimeError( + "The data directory has not been set! You should set it in your config file " + f"in {user_config_path()} or set the MODEL_DIR environment variable." + ) + return Path(config.pretrained_models) diff --git a/xax/core/state.py b/xax/core/state.py new file mode 100644 index 0000000..457a53f --- /dev/null +++ b/xax/core/state.py @@ -0,0 +1,67 @@ +"""Defines a dataclass for keeping track of the current training state.""" + +import time +from dataclasses import dataclass +from typing import Literal, cast, get_args + +from omegaconf import MISSING + +from xax.core.conf import field + +Phase = Literal["train", "valid", "test"] + + +def cast_phase(raw_phase: str) -> Phase: + args = get_args(Phase) + assert raw_phase in args, f"Invalid phase: '{raw_phase}' Valid options are {args}" + return cast(Phase, raw_phase) + + +@dataclass +class State: + num_epochs: int = field(MISSING, help="Number of epochs so far") + num_steps: int = field(MISSING, help="Number of steps so far") + num_epoch_steps: int = field(MISSING, help="Number of steps in the current epoch") + num_samples: int = field(MISSING, help="Number of sample so far") + num_epoch_samples: int = field(MISSING, help="Number of samples in the current epoch") + num_valid_steps: int = field(MISSING, help="Number of validation steps so far") + num_test_steps: int = field(MISSING, help="Number of test steps so far") + start_time_s: float = field(MISSING, help="Start time of training") + elapsed_time_s: float = field(MISSING, help="Total elapsed time so far") + raw_phase: str = field(MISSING, help="Current training phase") + + @property + def phase(self) -> Phase: + return cast_phase(self.raw_phase) + + @phase.setter + def phase(self, new_phase: Phase) -> None: + self.raw_phase = new_phase + + @classmethod + def init_state(cls) -> "State": + return cls( + num_epochs=0, + num_steps=0, + num_epoch_steps=0, + num_samples=0, + num_epoch_samples=0, + num_valid_steps=0, + num_test_steps=0, + start_time_s=time.time(), + elapsed_time_s=0.0, + raw_phase="train", + ) + + @property + def training(self) -> bool: + return self.phase == "train" + + def num_phase_steps(self, phase: Phase) -> int: + match phase: + case "train": + return self.num_steps + case "valid": + return self.num_valid_steps + case "test": + return self.num_test_steps diff --git a/xax/requirements.txt b/xax/requirements.txt index 58d75d8..c365670 100644 --- a/xax/requirements.txt +++ b/xax/requirements.txt @@ -2,3 +2,4 @@ jax jaxtyping +omegaconf diff --git a/xax/utils/__init__.py b/xax/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/xax/utils/text.py b/xax/utils/text.py new file mode 100644 index 0000000..9d81bf3 --- /dev/null +++ b/xax/utils/text.py @@ -0,0 +1,350 @@ +"""Defines helper functions for displaying text in the terminal.""" + +import datetime +import itertools +import re +import sys +from typing import Literal + +RESET_SEQ = "\033[0m" +REG_COLOR_SEQ = "\033[%dm" +BOLD_COLOR_SEQ = "\033[1;%dm" +BOLD_SEQ = "\033[1m" + +Color = Literal[ + "black", + "red", + "green", + "yellow", + "blue", + "magenta", + "cyan", + "white", + "grey", + "light-red", + "light-green", + "light-yellow", + "light-blue", + "light-magenta", + "light-cyan", +] + +COLOR_INDEX: dict[Color, int] = { + "black": 30, + "red": 31, + "green": 32, + "yellow": 33, + "blue": 34, + "magenta": 35, + "cyan": 36, + "white": 37, + "grey": 90, + "light-red": 91, + "light-green": 92, + "light-yellow": 93, + "light-blue": 94, + "light-magenta": 95, + "light-cyan": 96, +} + + +def color_parts(color: Color, bold: bool = False) -> tuple[str, str]: + if bold: + return BOLD_COLOR_SEQ % COLOR_INDEX[color], RESET_SEQ + return REG_COLOR_SEQ % COLOR_INDEX[color], RESET_SEQ + + +def uncolored(s: str) -> str: + return re.sub(r"\033\[[\d;]+m", "", s) + + +def colored(s: str, color: Color | None = None, bold: bool = False) -> str: + if color is None: + return s + start, end = color_parts(color, bold=bold) + return start + s + end + + +def wrapped( + s: str, + length: int | None = None, + space: str = " ", + spaces: str | re.Pattern = r" ", + newlines: str | re.Pattern = r"[\n\r]", + too_long_suffix: str = "...", +) -> list[str]: + strings = [] + lines = re.split(newlines, s.strip(), flags=re.MULTILINE | re.UNICODE) + for line in lines: + cur_string = [] + cur_length = 0 + for part in re.split(spaces, line.strip(), flags=re.MULTILINE | re.UNICODE): + if length is None: + cur_string.append(part) + cur_length += len(space) + len(part) + else: + if len(part) > length: + part = part[: length - len(too_long_suffix)] + too_long_suffix + if cur_length + len(part) > length: + strings.append(space.join(cur_string)) + cur_string = [part] + cur_length = len(part) + else: + cur_string.append(part) + cur_length += len(space) + len(part) + if cur_length > 0: + strings.append(space.join(cur_string)) + return strings + + +def outlined( + s: str, + inner: Color | None = None, + side: Color | None = None, + bold: bool = False, + max_length: int | None = None, + space: str = " ", + spaces: str | re.Pattern = r" ", + newlines: str | re.Pattern = r"[\n\r]", +) -> str: + strs = wrapped(uncolored(s), max_length, space, spaces, newlines) + max_len = max(len(s) for s in strs) + strs = [f"{s}{' ' * (max_len - len(s))}" for s in strs] + strs = [colored(s, inner, bold=bold) for s in strs] + strs_with_sides = [f"{colored('│', side)} {s} {colored('│', side)}" for s in strs] + top = colored("┌─" + "─" * max_len + "─┐", side) + bottom = colored("└─" + "─" * max_len + "─┘", side) + return "\n".join([top] + strs_with_sides + [bottom]) + + +def show_info(s: str, important: bool = False) -> None: + if important: + s = outlined(s, inner="light-cyan", side="cyan", bold=True) + else: + s = colored(s, "light-cyan", bold=False) + sys.stdout.write(s) + sys.stdout.write("\n") + sys.stdout.flush() + + +def show_error(s: str, important: bool = False) -> None: + if important: + s = outlined(s, inner="light-red", side="red", bold=True) + else: + s = colored(s, "light-red", bold=False) + sys.stdout.write(s) + sys.stdout.write("\n") + sys.stdout.flush() + + +def show_warning(s: str, important: bool = False) -> None: + if important: + s = outlined(s, inner="light-yellow", side="yellow", bold=True) + else: + s = colored(s, "light-yellow", bold=False) + sys.stdout.write(s) + sys.stdout.write("\n") + sys.stdout.flush() + + +class TextBlock: + def __init__( + self, + text: str, + color: Color | None = None, + bold: bool = False, + width: int | None = None, + space: str = " ", + spaces: str | re.Pattern = r" ", + newlines: str | re.Pattern = r"[\n\r]", + too_long_suffix: str = "...", + no_sep: bool = False, + center: bool = False, + ) -> None: + super().__init__() + + self.width = width + self.lines = wrapped(uncolored(text), width, space, spaces, newlines, too_long_suffix) + self.color = color + self.bold = bold + self.no_sep = no_sep + self.center = center + + +def render_text_blocks( + blocks: list[list[TextBlock]], + newline: str = "\n", + align_all_blocks: bool = False, + padding: int = 0, +) -> str: + """Renders a collection of blocks into a single string. + + Args: + blocks: The blocks to render. + newline: The string to use as a newline separator. + align_all_blocks: If set, aligns the widths for all blocks. + padding: The amount of padding to add to each block. + + Returns: + The rendered blocks. + """ + if align_all_blocks: + if any(len(row) != len(blocks[0]) for row in blocks): + raise ValueError("All rows must have the same number of blocks in order to align them") + widths = [[max(len(line) for line in i.lines) if i.width is None else i.width for i in r] for r in blocks] + row_widths = [max(i) for i in zip(*widths)] + for row in blocks: + for i, block in enumerate(row): + block.width = row_widths[i] + + def get_widths(row: list[TextBlock], n: int = 0) -> list[int]: + return [ + (max(len(line) for line in block.lines) if block.width is None else block.width) + n + padding + for block in row + ] + + def get_acc_widths(row: list[TextBlock], n: int = 0) -> list[int]: + return list(itertools.accumulate(get_widths(row, n))) + + def get_height(row: list[TextBlock]) -> int: + return max(len(block.lines) for block in row) + + def pad(s: str, width: int, center: bool) -> str: + swidth = len(s) + if center: + lpad, rpad = (width - swidth) // 2, (width - swidth + 1) // 2 + else: + lpad, rpad = 0, width - swidth + return " " * lpad + s + " " * rpad + + lines = [] + prev_row: list[TextBlock] | None = None + for row in blocks: + if prev_row is None: + lines += ["┌─" + "─┬─".join(["─" * width for width in get_widths(row)]) + "─┐"] + elif not all(block.no_sep for block in row): + ins, outs = get_acc_widths(prev_row, 3), get_acc_widths(row, 3) + segs = sorted([(i, False) for i in ins] + [(i, True) for i in outs]) + line = ["├"] + + c = 1 + for i, (s, is_out) in enumerate(segs): + if i > 0 and segs[i - 1][0] == s: + continue + is_in_out = i < len(segs) - 1 and segs[i + 1][0] == s + is_last = i == len(segs) - 2 if is_in_out else i == len(segs) - 1 + + line += "─" * (s - c) + if is_last: + if is_in_out: + line += "┤" + elif is_out: + line += "┐" + else: + line += "┘" + else: # noqa: PLR5501 + if is_in_out: + line += "┼" + elif is_out: + line += "┬" + else: + line += "┴" + c = s + 1 + + lines += ["".join(line)] + + for i in range(get_height(row)): + lines += [ + "│ " + + " │ ".join( + [ + ( + " " * width + if i >= len(block.lines) + else colored(pad(block.lines[i], width, block.center), block.color, bold=block.bold) + ) + for block, width in zip(row, get_widths(row)) + ] + ) + + " │" + ] + + prev_row = row + if prev_row is not None: + lines += ["└─" + "─┴─".join(["─" * width for width in get_widths(prev_row)]) + "─┘"] + + return newline.join(lines) + + +def format_timedelta(timedelta: datetime.timedelta, short: bool = False) -> str: + """Formats a delta time to human-readable format. + + Args: + timedelta: The delta to format + short: If set, uses a shorter format + + Returns: + The human-readable time delta + """ + parts = [] + if timedelta.days > 0: + if short: + parts += [f"{timedelta.days}d"] + else: + parts += [f"{timedelta.days} day" if timedelta.days == 1 else f"{timedelta.days} days"] + + seconds = timedelta.seconds + + if seconds > 60 * 60: + hours, seconds = seconds // (60 * 60), seconds % (60 * 60) + if short: + parts += [f"{hours}h"] + else: + parts += [f"{hours} hour" if hours == 1 else f"{hours} hours"] + + if seconds > 60: + minutes, seconds = seconds // 60, seconds % 60 + if short: + parts += [f"{minutes}m"] + else: + parts += [f"{minutes} minute" if minutes == 1 else f"{minutes} minutes"] + + if short: + parts += [f"{seconds}s"] + else: + parts += [f"{seconds} second" if seconds == 1 else f"{seconds} seconds"] + + return ", ".join(parts) + + +def format_datetime(dt: datetime.datetime) -> str: + """Formats a datetime to human-readable format. + + Args: + dt: The datetime to format + + Returns: + The human-readable datetime + """ + return dt.strftime("%Y-%m-%d %H:%M:%S") + + +def camelcase_to_snakecase(s: str) -> str: + return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s).lower() + + +def snakecase_to_camelcase(s: str) -> str: + return "".join(word.title() for word in s.split("_")) + + +def highlight_exception_message(s: str) -> str: + s = re.sub(r"(\w+Error)", r"\033[1;31m\1\033[0m", s) + s = re.sub(r"(\w+Exception)", r"\033[1;31m\1\033[0m", s) + s = re.sub(r"(\w+Warning)", r"\033[1;33m\1\033[0m", s) + s = re.sub(r"\^+", r"\033[1;35m\g<0>\033[0m", s) + s = re.sub(r"File \"(.+?)\"", r'File "\033[36m\1\033[0m"', s) + return s + + +def is_interactive_session() -> bool: + return hasattr(sys, "ps1") or hasattr(sys, "ps2") or sys.stdout.isatty() or sys.stderr.isatty()