Merge branch 'main' into datasets_feature
pab-vmware authored Apr 4, 2021
2 parents 08d3012 + decb875 commit aa3aefa
Showing 5 changed files with 36 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Changed

+- Sanity checks in the `GradientDescentTrainer` can now be turned off by setting the `run_sanity_checks` parameter to `False`.
 - Allow the order of examples in the task cards to be specified explicitly
 - `histogram_interval` parameter is now deprecated in `TensorboardWriter`, please use `distribution_interval` instead.
 - Memory usage is not logged in tensorboard during training now. `ConsoleLoggerCallback` should be used instead.
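As a quick illustration of the flag documented in the changelog entry above, here is a minimal sketch of constructing the trainer with sanity checks disabled. `model`, `optimizer`, and `data_loader` are assumed to already exist, and the import path and `serialization_dir` value are illustrative rather than taken from this diff:

```python
from allennlp.training import GradientDescentTrainer

# `model`, `optimizer`, and `data_loader` are assumed to have been built elsewhere
# in the usual AllenNLP way; they are not defined in this sketch.
trainer = GradientDescentTrainer(
    model=model,
    optimizer=optimizer,
    data_loader=data_loader,
    num_epochs=1,
    serialization_dir="/tmp/trainer_demo",  # illustrative path
    run_sanity_checks=False,  # new in this commit: the sanity-check callback is not attached
)
trainer.train()
```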
5 changes: 3 additions & 2 deletions allennlp/sanity_checks/normalization_bias_verification.py
@@ -1,6 +1,7 @@
 """
-Code based almost entirely on
-https://github.com/awaelchli/pytorch-lightning-snippets/commit/7db53f774715d635c59ef56f21a17634d246b2c5
+Code based almost entirely from the [pytorch-lightning-snippets]
+(https://github.com/awaelchli/pytorch-lightning-snippets/commit/7db53f774715d635c59ef56f21a17634d246b2c5)
+repository.
 """

 import torch
46 changes: 30 additions & 16 deletions allennlp/training/trainer.py
@@ -6,7 +6,7 @@
 import time
 import traceback
 from contextlib import contextmanager
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Type

 from allennlp.common.util import int_to_device

@@ -118,10 +118,10 @@ class GradientDescentTrainer(Trainer):
     stopping. There are many other bells and whistles as well.

     Registered as a `Trainer` with the name "gradient_descent" (and is also the default `Trainer`).
-    The constructor that is registered is `from_partial_objects` - see the arguments to that
-    function for the exact keys that should be used, if you are using a configuration file. They
-    largely match the arguments to `__init__`, and we don't repeat their docstrings in
-    `from_partial_objects`.
+    The constructor that is registered is [`from_partial_objects`](#from_partial_objects) -
+    see the arguments to that function for the exact keys that should be used, if you are using
+    a configuration file. They largely match the arguments to `__init__`, and we don't repeat their
+    docstrings in `from_partial_objects`.

     [0]: https://tinyurl.com/y5mv44fw
@@ -248,6 +248,16 @@ class GradientDescentTrainer(Trainer):
     use_amp : `bool`, optional, (default = `False`)
         If `True`, we'll train using [Automatic Mixed Precision](https://pytorch.org/docs/stable/amp.html).

+    enable_default_callbacks : `bool`, optional (default = `True`)
+        When `True`, the [`DEFAULT_CALLBACKS`](#default_callbacks) will be used in
+        addition to any other callbacks listed in the `callbacks` parameter.
+        When set to `False`, `DEFAULT_CALLBACKS` are not used.
+
+    run_sanity_checks : `bool`, optional (default = `True`)
+        Determines whether model sanity checks, such as
+        [`NormalizationBiasVerification`](../../sanity_checks/normalization_bias_verification/),
+        are ran.
     """

     def __init__(
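A small self-contained sketch (not library code) of how the two documented flags combine, mirroring the selection logic added to `__init__` later in this diff: `enable_default_callbacks` controls only `DEFAULT_CALLBACKS`, while `run_sanity_checks` independently controls the sanity-check callback.

```python
from typing import List


def default_callback_names(enable_default_callbacks: bool, run_sanity_checks: bool) -> List[str]:
    """Mirror (by class name only) of the default-callback selection added to __init__."""
    selected = ["ConsoleLoggerCallback"] if enable_default_callbacks else []
    if run_sanity_checks:
        selected.append("SanityChecksCallback")
    return selected


# Turning off the default callbacks does not, by itself, disable the sanity checks:
assert default_callback_names(False, True) == ["SanityChecksCallback"]
assert default_callback_names(True, False) == ["ConsoleLoggerCallback"]
```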
@@ -273,6 +283,8 @@ def __init__(
         world_size: int = 1,
         num_gradient_accumulation_steps: int = 1,
         use_amp: bool = False,
+        enable_default_callbacks: bool = True,
+        run_sanity_checks: bool = True,
     ) -> None:
         super().__init__(serialization_dir, cuda_device, distributed, local_rank, world_size)

@@ -316,6 +328,15 @@ def __init__(
         self._moving_average = moving_average

         self._callbacks = callbacks or []
+        default_callbacks = list(DEFAULT_CALLBACKS) if enable_default_callbacks else []
+        if run_sanity_checks:
+            default_callbacks.append(SanityChecksCallback)
+        for callback_cls in default_callbacks:
+            for callback in self._callbacks:
+                if callback.__class__ == callback_cls:
+                    break
+            else:
+                self._callbacks.append(callback_cls(self._serialization_dir))

         self._batch_num_total = 0
         self._last_log = 0.0  # time of last logging
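The block added above uses Python's `for ... else` to avoid attaching a default callback when the user already supplied one of the same class. Below is a standalone sketch of that deduplication pattern, using hypothetical stand-in names rather than the real callback classes:

```python
class ConsoleLogger:
    """Stand-in for a callback class; hypothetical, not the AllenNLP class."""

    def __init__(self, serialization_dir: str) -> None:
        self.serialization_dir = serialization_dir


user_callbacks = [ConsoleLogger("/tmp/run")]  # the user already passed one in
default_callback_classes = [ConsoleLogger]

for callback_cls in default_callback_classes:
    for callback in user_callbacks:
        if callback.__class__ == callback_cls:
            break  # a callback of this class is already present; skip the default
    else:
        # the `else` of a `for` loop runs only when the loop finished without `break`
        user_callbacks.append(callback_cls("/tmp/run"))

assert len(user_callbacks) == 1  # no duplicate ConsoleLogger was added
```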
@@ -970,6 +991,7 @@ def from_partial_objects(
         checkpointer: Lazy[Checkpointer] = Lazy(Checkpointer),
         callbacks: List[Lazy[TrainerCallback]] = None,
         enable_default_callbacks: bool = True,
+        run_sanity_checks: bool = True,
     ) -> "Trainer":
         """
         This method exists so that we can have a documented method to construct this class using
@@ -1037,13 +1059,6 @@ def from_partial_objects(
         callbacks_: List[TrainerCallback] = []
         for callback_ in callbacks or []:
             callbacks_.append(callback_.construct(serialization_dir=serialization_dir))
-        if enable_default_callbacks:
-            for callback_cls in DEFAULT_CALLBACKS:
-                for callback in callbacks_:
-                    if callback.__class__ == callback_cls:
-                        break
-                else:
-                    callbacks_.append(callback_cls(serialization_dir))

         return cls(
             model,
@@ -1067,13 +1082,12 @@
             world_size=world_size,
             num_gradient_accumulation_steps=num_gradient_accumulation_steps,
             use_amp=use_amp,
+            enable_default_callbacks=enable_default_callbacks,
+            run_sanity_checks=run_sanity_checks,
         )


-DEFAULT_CALLBACKS = (
-    SanityChecksCallback,
-    ConsoleLoggerCallback,
-)
+DEFAULT_CALLBACKS: Tuple[Type[TrainerCallback]] = (ConsoleLoggerCallback,)
 """
 The default callbacks used by `GradientDescentTrainer`.
 """
2 changes: 1 addition & 1 deletion dev-requirements.txt
@@ -39,7 +39,7 @@ nr.databind.core<0.0.17
 nr.interface<0.0.4

 mkdocs==1.1.2
-mkdocs-material>=5.5.0,<7.1.0
+mkdocs-material>=5.5.0,<7.2.0
 markdown-include==0.6.0

 #### PACKAGE-UPLOAD PACKAGES ####
2 changes: 1 addition & 1 deletion tests/training/trainer_test.py
@@ -847,7 +847,7 @@ def test_sanity_check_default(self):
             serialization_dir=self.TEST_DIR,
             data_loader=data_loader,
             num_epochs=1,
-            enable_default_callbacks=False,
+            run_sanity_checks=False,
         )

         # Check is not run, so no failure.
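A hedged follow-up to the test above (not part of the diff): since `SanityChecksCallback` is only appended in `__init__` when `run_sanity_checks` is `True`, the same test could additionally assert that no such callback ended up on the trainer's private `_callbacks` list. `trainer` here refers to the object the truncated test constructs.

```python
assert not any(
    type(callback).__name__ == "SanityChecksCallback" for callback in trainer._callbacks
)
```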