diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index 18de92296..040dc12d6 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -1135,56 +1135,63 @@ def validate_types(self): At runtime, checks types are actually the type specified. """ for field_name, field_def in self.__dataclass_fields__.items(): - actual_value = getattr(self, field_name) - if actual_value is None: - continue # we allow for some values not to be configured - - if self.autotuning is not None and actual_value == "auto": - continue + try: + actual_value = getattr(self, field_name) + if actual_value is None: + continue # we allow for some values not to be configured - actual_type = type(actual_value) - if actual_type != field_def.type: - if ( - actual_type == int and field_def.type == float - ): # floats should be able to be configured as ints + if self.autotuning is not None and actual_value == "auto": continue - # for typing.Literal (i.e a list of choices) - checks that actual value is in accepted values - elif field_def.type.__origin__ == Literal: - accepted_values = field_def.type.__args__ - if actual_value in accepted_values: + actual_type = type(actual_value) + if actual_type != field_def.type: + if ( + actual_type == int and field_def.type == float + ): # floats should be able to be configured as ints continue - elif type(actual_value) == str: - # case insensitive checking - lowercase_accepted_values = [ - i.lower() for i in accepted_values if isinstance(i, str) - ] - if actual_value.lower() in lowercase_accepted_values: + + # for typing.Literal (i.e a list of choices) - checks that actual value is in accepted values + elif field_def.type.__origin__ == Literal: + accepted_values = field_def.type.__args__ + if actual_value in accepted_values: continue - logging.error( - self.__class__.__name__ - + ".validate_types() " - + f"{field_name}: '{actual_value}' Not in accepted values: '{accepted_values}'" - ) - return False - elif field_def.type.__origin__ == Union: - accepted_types = field_def.type.__args__ - if actual_type in accepted_types: - continue - else: + elif type(actual_value) == str: + # case insensitive checking + lowercase_accepted_values = [ + i.lower() for i in accepted_values if isinstance(i, str) + ] + if actual_value.lower() in lowercase_accepted_values: + continue logging.error( self.__class__.__name__ + ".validate_types() " - + f"{field_name}: '{actual_type}' not in {accepted_types}" + + f"{field_name}: '{actual_value}' Not in accepted values: '{accepted_values}'" ) return False + elif field_def.type.__origin__ == Union: + accepted_types = field_def.type.__args__ + if actual_type in accepted_types: + continue + else: + logging.error( + self.__class__.__name__ + + ".validate_types() " + + f"{field_name}: '{actual_type}' not in {accepted_types}" + ) + return False + + logging.error( + self.__class__.__name__ + + ".validate_types() " + + f"{field_name}: '{actual_type}' instead of '{field_def.type}'" + ) + return False + except Exception as e: logging.error( - self.__class__.__name__ - + ".validate_types() " - + f"{field_name}: '{actual_type}' instead of '{field_def.type}'" + f"Found an error for configuration {field_name=} with {field_def=}, see below for more details;" ) - return False + raise e # validate deepspeed dicts for field_name in ["optimizer", "scheduler"]: