enable ruff formatter (#853)
Former-commit-id: 999c51bb1c4e4842c2c27e3e15e6f1a46e650b64
misko authored Sep 17, 2024
1 parent 623a6e6 commit f41b231
Showing 20 changed files with 188 additions and 114 deletions.
.github/workflows/lint.yml: 3 changes (2 additions, 1 deletion)
@@ -29,4 +29,5 @@ jobs:
- name: ruff
run: |
ruff --version
ruff check src
ruff check src # tests has a lot of issues , TODO
ruff format --check src # tests
ruff.toml: 2 changes (1 addition, 1 deletion)
@@ -1,4 +1,4 @@
include = ["src/fairchem/core/**/*.py", "src/fairchem/data/oc/**/*.py"]
include = ["src/fairchem/core/**/*.py", "src/fairchem/data/oc/**/*.py", "tests/**/*.py"]
line-length = 88

[lint]
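Taken together, these two changes drive the rest of the commit: CI now runs ruff format --check over src alongside ruff check, and the formatter rewraps any call that overflows the 88-character limit configured above. The snippet below is a hedged illustration of that rewrite using a made-up logging call (not code from the repo); note that ruff wraps the call arguments but leaves the string literal itself intact, exactly as in the hunks that follow.

import logging

logging.basicConfig(level=logging.INFO)
local_rank, visible_devices = 0, "0,1,2,3"

# Before formatting: one call whose line exceeds the 88-character limit.
logging.info(f"local rank: {local_rank}, visible devices: {visible_devices}, checkpoint saved to /tmp/ckpt.pt")

# After ruff format with line-length = 88: the argument moves onto its own
# indented line and the closing parenthesis gets its own line.
logging.info(
    f"local rank: {local_rank}, visible devices: {visible_devices}, checkpoint saved to /tmp/ckpt.pt"
)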
src/fairchem/core/_cli.py: 20 changes (13 additions, 7 deletions)
@@ -49,15 +49,19 @@ def checkpoint(self, *args, **kwargs):
self.config["timestamp_id"] = self.trainer.timestamp_id
if self.trainer.logger is not None:
self.trainer.logger.mark_preempting()
logging.info(f'Checkpointing callback is triggered, checkpoint saved to: {self.config["checkpoint"]}, timestamp_id: {self.config["timestamp_id"]}')
logging.info(
f'Checkpointing callback is triggered, checkpoint saved to: {self.config["checkpoint"]}, timestamp_id: {self.config["timestamp_id"]}'
)
return DelayedSubmission(new_runner, self.config)


def runner_wrapper(config: dict):
Runner()(config)


def main(args: argparse.Namespace | None = None, override_args: list[str] | None = None):
def main(
args: argparse.Namespace | None = None, override_args: list[str] | None = None
):
"""Run the main fairchem program."""
setup_logging()

@@ -66,7 +70,9 @@ def main(args: argparse.Namespace | None = None, override_args: list[str] | None
args, override_args = parser.parse_known_args()

# TODO: rename num_gpus -> num_ranks everywhere
assert args.num_gpus > 0, "num_gpus is used to determine number ranks, so it must be at least 1"
assert (
args.num_gpus > 0
), "num_gpus is used to determine number ranks, so it must be at least 1"
config = build_config(args, override_args)

if args.submit: # Run on cluster
@@ -98,9 +104,7 @@ def main(args: argparse.Namespace | None = None, override_args: list[str] | None

else: # Run locally on a single node, n-processes
if args.num_gpus > 1:
logging.info(
f"Running in local mode with {args.num_gpus} ranks"
)
logging.info(f"Running in local mode with {args.num_gpus} ranks")
# HACK to disable multiprocess dataloading in local mode
# there is an open issue where LMDB's environment cannot be pickled and used
# during torch multiprocessing https://github.com/pytorch/examples/issues/526
@@ -119,7 +123,9 @@ def main(args: argparse.Namespace | None = None, override_args: list[str] | None
)
elastic_launch(launch_config, runner_wrapper)(config)
else:
logging.info("Running in local mode without elastic launch (single gpu only)")
logging.info(
"Running in local mode without elastic launch (single gpu only)"
)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
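The single-gpu branch above bootstraps torch.distributed by hand through environment variables rather than going through elastic_launch. Below is a minimal standalone sketch of that pattern, assuming the gloo backend so it runs without a GPU; fairchem's real backend, port, and world size come from its config.

import os

import torch.distributed as dist


def init_single_rank_process_group() -> None:
    # Pretend to be rank 0 of a world of size 1, as the local fallback does.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")  # any free port works
    os.environ.setdefault("LOCAL_RANK", "0")
    os.environ.setdefault("RANK", "0")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)


if __name__ == "__main__":
    init_single_rank_process_group()
    print("initialized:", dist.is_initialized(), "rank:", dist.get_rank())
    dist.destroy_process_group()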
src/fairchem/core/common/distutils.py: 16 changes (13 additions, 3 deletions)
@@ -22,6 +22,7 @@
T = TypeVar("T")
DISTRIBUTED_PORT = 13356


def os_environ_get_or_throw(x: str) -> str:
if x not in os.environ:
raise RuntimeError(f"Could not find {x} in ENV variables")
@@ -68,7 +69,9 @@ def setup(config) -> None:
)

# ensures GPU0 does not have extra context/higher peak memory
logging.info(f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}")
logging.info(
f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}"
)
torch.cuda.set_device(config["local_rank"])

dist.init_process_group(
@@ -104,13 +107,20 @@ def setup(config) -> None:
)
else:
if not os.environ.get("MASTER_ADDR"):
assert config["world_size"] == 1, "Can only setup master address and port at this point for a single rank, otherwise we assume the processes and the comm addr/port have already been setup"
assert (
config["world_size"] == 1
), "Can only setup master address and port at this point for a single rank, otherwise we assume the processes and the comm addr/port have already been setup"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
config["local_rank"] = int(os.environ.get("LOCAL_RANK"))
dist.init_process_group(backend=config["distributed_backend"], rank=int(os.environ.get("RANK")), world_size=config["world_size"], timeout=timeout)
dist.init_process_group(
backend=config["distributed_backend"],
rank=int(os.environ.get("RANK")),
world_size=config["world_size"],
timeout=timeout,
)


def cleanup() -> None:
src/fairchem/core/common/logger.py: 2 changes (2 additions, 0 deletions)
@@ -60,6 +60,7 @@ def log_summary(self, summary_dict: dict[str, Any]) -> None:
def log_artifact(self, name: str, type: str, file_location: str) -> None:
pass


@registry.register_logger("wandb")
class WandBLogger(Logger):
def __init__(self, config) -> None:
@@ -115,6 +116,7 @@ def log_artifact(self, name: str, type: str, file_location: str) -> None:
art.add_file(file_location)
art.save()


@registry.register_logger("tensorboard")
class TensorboardLogger(Logger):
def __init__(self, config) -> None:
src/fairchem/core/common/profiler_utils.py: 7 changes (6 additions, 1 deletion)
@@ -10,6 +10,7 @@
if TYPE_CHECKING:
from fairchem.core.common.logger import Logger


def get_default_profiler_handler(run_id: str, output_dir: str, logger: Logger):
"""Get a standard callback handle for the pytorch profiler"""

@@ -20,9 +21,13 @@ def trace_handler(p):
print(f"Saving trace in {output_path}")
p.export_chrome_trace(output_path)
if logger:
logger.log_artifact(name=trace_name, type="profile", file_location=output_path)
logger.log_artifact(
name=trace_name, type="profile", file_location=output_path
)

return trace_handler


def get_profile_schedule(wait: int = 5, warmup: int = 5, active: int = 2):
"""Get a profile schedule and total number of steps to run
check pytorch docs on the meaning of these paramters:
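For context, the handler and schedule in this file plug into torch.profiler. The sketch below is a hedged usage example: wait/warmup/active mirror the defaults shown in get_profile_schedule, and the handler simply prints a summary table instead of exporting and logging a chrome trace as the real one does.

import torch
from torch.profiler import ProfilerActivity, profile, schedule


def trace_handler(prof: profile) -> None:
    # Stand-in for the real handler, which exports a chrome trace and logs it
    # as an artifact; here we only print a short summary.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))


wait_steps, warmup_steps, active_steps = 5, 5, 2
with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=wait_steps, warmup=warmup_steps, active=active_steps),
    on_trace_ready=trace_handler,
) as prof:
    for _ in range(wait_steps + warmup_steps + active_steps):
        torch.matmul(torch.randn(128, 128), torch.randn(128, 128))
        prof.step()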
src/fairchem/core/common/slurm.py: 8 changes (6 additions, 2 deletions)
@@ -6,13 +6,17 @@
from submitit.core.utils import JobPaths


def add_timestamp_id_to_submission_pickle(slurm_folder: str, slurm_job_id: str, timestamp_id: str):
def add_timestamp_id_to_submission_pickle(
slurm_folder: str, slurm_job_id: str, timestamp_id: str
):
# Try to put the timestamp-id into the original submission pickle's config
# so that if the node crashes, it can be pick up the correct run to resume
#
# we need to do this after the job has started because the timestamp-id is generated at runtime
# instead a-priori before the submission starts (ie: if we had a db to store a global job unique job)
submission_pickle_path = JobPaths(folder=slurm_folder, job_id=slurm_job_id).submitted_pickle
submission_pickle_path = JobPaths(
folder=slurm_folder, job_id=slurm_job_id
).submitted_pickle
try:
with open(str(submission_pickle_path), "rb") as f:
pkl = pickle.load(f)
src/fairchem/core/common/test_utils.py: 1 change (1 addition, 0 deletions)
@@ -131,6 +131,7 @@ def spawn_multi_process(

return [mp_output_dict[i] for i in range(config.world_size)]


def init_local_distributed_process_group(backend="nccl"):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
src/fairchem/core/datasets/lmdb_dataset.py: 4 changes (3 additions, 1 deletion)
@@ -211,7 +211,9 @@ def sample_property_metadata(self, num_samples: int = 100):
}


def data_list_collater(data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False) -> BaseData | dict[str, torch.Tensor]:
def data_list_collater(
data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False
) -> BaseData | dict[str, torch.Tensor]:
batch = Batch.from_data_list(data_list)

if not otf_graph:
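data_list_collater is a thin wrapper over torch_geometric batching; Batch.from_data_list is the first thing it calls, with otf_graph and to_dict controlling extra processing not shown in this hunk. A small usage sketch with toy inputs follows; the attributes are illustrative only, since real fairchem samples carry many more fields.

import torch
from torch_geometric.data import Batch, Data

# Two toy "structures" with 3 and 5 atoms respectively.
samples = [
    Data(pos=torch.randn(3, 3), natoms=torch.tensor([3])),
    Data(pos=torch.randn(5, 3), natoms=torch.tensor([5])),
]

batch = Batch.from_data_list(samples)
print(batch.num_graphs, batch.pos.shape)  # 2 torch.Size([8, 3])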
src/fairchem/core/models/base.py: 42 changes (30 additions, 12 deletions)
@@ -254,9 +254,15 @@ def __init__(
# if finetune_config is provided, then attempt to load the model from the given finetune checkpoint
starting_model = None
if finetune_config is not None:
starting_model: HydraModel = load_model_and_weights_from_checkpoint(finetune_config["starting_checkpoint"])
logging.info(f"Found and loaded fine-tuning checkpoint: {finetune_config['starting_checkpoint']} (Note we are NOT loading the training state from this checkpoint, only parts of the model and weights)")
assert isinstance(starting_model, HydraModel), "Can only finetune starting from other hydra models!"
starting_model: HydraModel = load_model_and_weights_from_checkpoint(
finetune_config["starting_checkpoint"]
)
logging.info(
f"Found and loaded fine-tuning checkpoint: {finetune_config['starting_checkpoint']} (Note we are NOT loading the training state from this checkpoint, only parts of the model and weights)"
)
assert isinstance(
starting_model, HydraModel
), "Can only finetune starting from other hydra models!"

if backbone is not None:
backbone = copy.deepcopy(backbone)
@@ -268,17 +274,23 @@ def __init__(
)
elif starting_model is not None:
self.backbone = starting_model.backbone
logging.info(f"User did not specify a backbone, using the backbone from the starting checkpoint {self.backbone}")
logging.info(
f"User did not specify a backbone, using the backbone from the starting checkpoint {self.backbone}"
)
else:
raise RuntimeError("Backbone not specified and not found in the starting checkpoint")
raise RuntimeError(
"Backbone not specified and not found in the starting checkpoint"
)

if heads is not None:
heads = copy.deepcopy(heads)
# Iterate through outputs_cfg and create heads
self.output_heads: dict[str, HeadInterface] = {}

head_names_sorted = sorted(heads.keys())
assert len(set(head_names_sorted)) == len(head_names_sorted), "Head names must be unique!"
assert len(set(head_names_sorted)) == len(
head_names_sorted
), "Head names must be unique!"
for head_name in head_names_sorted:
head_config = heads[head_name]
if "module" not in head_config:
@@ -295,15 +307,23 @@ def __init__(
self.output_heads = torch.nn.ModuleDict(self.output_heads)
elif starting_model is not None:
self.output_heads = starting_model.output_heads
logging.info(f"User did not specify heads, using the output heads from the starting checkpoint {self.output_heads}")
logging.info(
f"User did not specify heads, using the output heads from the starting checkpoint {self.output_heads}"
)
else:
raise RuntimeError("Heads not specified and not found in the starting checkpoint")
raise RuntimeError(
"Heads not specified and not found in the starting checkpoint"
)

def forward(self, data: Batch):
# lazily get device from input to use with amp, at least one input must be a tensor to figure out it's device
if not self.device:
device_from_tensors = {x.device.type for x in data.values() if isinstance(x, torch.Tensor)}
assert len(device_from_tensors) == 1, f"all inputs must be on the same device, found the following devices {device_from_tensors}"
device_from_tensors = {
x.device.type for x in data.values() if isinstance(x, torch.Tensor)
}
assert (
len(device_from_tensors) == 1
), f"all inputs must be on the same device, found the following devices {device_from_tensors}"
self.device = device_from_tensors.pop()

emb = self.backbone(data)
@@ -319,5 +339,3 @@ def forward(self, data: Batch):
out[k] = self.output_heads[k](data, emb)

return out


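The device check added to forward is a small pattern worth isolating: the model stays device-agnostic until its first batch, then infers the device from whichever inputs are tensors. The helper below is a standalone sketch of the same idea, not fairchem API.

import torch


def infer_device(data: dict) -> str:
    # Collect the device type of every tensor value and require that they
    # agree, mirroring the assert in HydraModel.forward.
    device_types = {
        v.device.type for v in data.values() if isinstance(v, torch.Tensor)
    }
    assert (
        len(device_types) == 1
    ), f"all inputs must be on the same device, found {device_types}"
    return device_types.pop()


print(infer_device({"pos": torch.randn(4, 3), "natoms": torch.tensor([4])}))  # cpu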
src/fairchem/core/models/dimenet_plus_plus.py: 34 changes (20 additions, 14 deletions)
@@ -352,13 +352,16 @@ def forward(
)
}
if self.regress_forces:
outputs["forces"] = -1 * (
torch.autograd.grad(
outputs["energy"],
data.pos,
grad_outputs=torch.ones_like(outputs["energy"]),
create_graph=True,
)[0]
outputs["forces"] = (
-1
* (
torch.autograd.grad(
outputs["energy"],
data.pos,
grad_outputs=torch.ones_like(outputs["energy"]),
create_graph=True,
)[0]
)
)
return outputs

@@ -465,13 +468,16 @@ def forward(self, data):
outputs = {"energy": energy}

if self.regress_forces:
forces = -1 * (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
forces = (
-1
* (
torch.autograd.grad(
energy,
data.pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
)
outputs["forces"] = forces

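Both hunks in this file keep the same physics: forces are the negative gradient of the predicted energy with respect to atomic positions, and create_graph=True keeps the graph so force errors can themselves be backpropagated during training. Below is a self-contained toy version with a quadratic stand-in for the energy, so the result can be checked analytically.

import torch

pos = torch.randn(4, 3, requires_grad=True)
energy = (pos**2).sum()  # stand-in for a model's predicted energy

# forces = -dE/dpos, the same pattern as the diff above.
forces = -1 * torch.autograd.grad(
    energy,
    pos,
    grad_outputs=torch.ones_like(energy),
    create_graph=True,
)[0]

torch.testing.assert_close(forces, -2 * pos)  # dE/dpos = 2 * pos for this toy energy
print(forces.shape)  # torch.Size([4, 3])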
src/fairchem/core/models/equiformer_v2/equiformer_v2.py: 7 changes (4 additions, 3 deletions)
@@ -610,7 +610,7 @@ def no_weight_decay(self) -> set:

@registry.register_model("equiformer_v2_energy_head")
class EquiformerV2EnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone, reduce: str="sum"):
def __init__(self, backbone, reduce: str = "sum"):
super().__init__()
self.reduce = reduce
self.avg_num_nodes = backbone.avg_num_nodes
@@ -645,8 +645,9 @@ def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
elif self.reduce == "mean":
return {"energy": energy / data.natoms}
else:
raise ValueError(f"reduce can only be sum or mean, user provided: {self.reduce}")

raise ValueError(
f"reduce can only be sum or mean, user provided: {self.reduce}"
)


@registry.register_model("equiformer_v2_force_head")
src/fairchem/core/models/escn/escn.py: 6 changes (4 additions, 2 deletions)
@@ -537,7 +537,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:

@registry.register_model("escn_energy_head")
class eSCNEnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone, reduce = "sum"):
def __init__(self, backbone, reduce="sum"):
super().__init__()
backbone.energy_block = None
self.reduce = reduce
@@ -558,7 +558,9 @@ def forward(
elif self.reduce == "mean":
return {"energy": energy / data.natoms}
else:
raise ValueError(f"reduce can only be sum or mean, user provided: {self.reduce}")
raise ValueError(
f"reduce can only be sum or mean, user provided: {self.reduce}"
)


@registry.register_model("escn_force_head")
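The energy heads touched by this commit (EquiformerV2EnergyHead and eSCNEnergyHead above) share the same reduce option: "sum" adds per-atom contributions within each structure, while "mean" divides that sum by the structure's atom count, as in energy / data.natoms. The snippet below is an illustrative reduction on a toy batch of two structures, not the heads' actual implementation.

import torch

per_atom_energy = torch.randn(8)  # contributions from 8 atoms in total
batch_index = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1])  # structure index per atom
natoms = torch.tensor([3, 5])

# reduce="sum": accumulate per-atom contributions into per-structure energies.
energy_sum = torch.zeros(2).index_add_(0, batch_index, per_atom_energy)

# reduce="mean": divide by the number of atoms in each structure.
energy_mean = energy_sum / natoms

print(energy_sum, energy_mean)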
(Diffs for the remaining changed files were not loaded and are not shown.)