From ef431a1c408bbbe46c62b2860f6b529cd5ed725b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 20 Sep 2025 13:26:35 +0000 Subject: [PATCH 1/7] Initial plan From 9eb1beac599e4b17c86a15d4740e4040dda61ccc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 20 Sep 2025 13:42:24 +0000 Subject: [PATCH 2/7] feat(training): add comprehensive NaN detection with tests and validation Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/pd/train/training.py | 17 + deepmd/pt/train/training.py | 17 + deepmd/tf/train/trainer.py | 10 + deepmd/utils/nan_detector.py | 119 ++ source/3rdparty/implib/implib-gen.py | 1093 ++++++++++--------- source/tests/common/test_nan_detector.py | 167 +++ source/tests/common/test_nan_integration.py | 131 +++ 7 files changed, 1046 insertions(+), 508 deletions(-) create mode 100644 deepmd/utils/nan_detector.py create mode 100644 source/tests/common/test_nan_detector.py create mode 100644 source/tests/common/test_nan_integration.py diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 4e5fea081f..7c084f1084 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -75,6 +75,9 @@ from deepmd.utils.data import ( DataRequirementItem, ) +from deepmd.utils.nan_detector import ( + check_loss_nan, +) from deepmd.utils.path import ( DPH5Path, ) @@ -951,6 +954,20 @@ def log_loss_valid(_task_key="Default"): fout, display_step_id, cur_lr, train_results, valid_results ) + # Check for NaN in loss values before saving checkpoint + # Loss values are already on CPU at this point for display/logging + if self.rank == 0: + if not self.multi_task: + check_loss_nan(display_step_id, train_results) + if valid_results: + check_loss_nan(display_step_id, valid_results) + else: + for task_key in train_results: + if train_results[task_key]: + check_loss_nan(display_step_id, train_results[task_key]) + if valid_results[task_key]: + check_loss_nan(display_step_id, valid_results[task_key]) + if ( ((_step_id + 1) % self.save_freq == 0 and _step_id != self.start_step) or (_step_id + 1) == self.num_steps diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 52d2888081..d9ef2fbd41 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -75,6 +75,9 @@ from deepmd.utils.data import ( DataRequirementItem, ) +from deepmd.utils.nan_detector import ( + check_loss_nan, +) if torch.__version__.startswith("2"): import torch._dynamo @@ -1070,6 +1073,20 @@ def log_loss_valid(_task_key: str = "Default") -> dict: fout, display_step_id, cur_lr, train_results, valid_results ) + # Check for NaN in loss values before saving checkpoint + # Loss values are already on CPU at this point for display/logging + if self.rank == 0: + if not self.multi_task: + check_loss_nan(display_step_id, train_results) + if valid_results: + check_loss_nan(display_step_id, valid_results) + else: + for task_key in train_results: + if train_results[task_key]: + check_loss_nan(display_step_id, train_results[task_key]) + if valid_results[task_key]: + check_loss_nan(display_step_id, valid_results[task_key]) + if ( ( (display_step_id) % self.save_freq == 0 diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index f70c919301..4b6e884a25 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -60,6 +60,9 @@ from deepmd.utils.data import ( DataRequirementItem, ) +from deepmd.utils.nan_detector import ( + check_loss_nan, +) log = logging.getLogger(__name__) @@ -684,6 +687,13 @@ def valid_on_the_fly( cur_batch = self.cur_batch current_lr = run_sess(self.sess, self.learning_rate) + + # Check for NaN in loss values before writing to file and saving checkpoint + # Loss values are already on CPU at this point + check_loss_nan(cur_batch, train_results) + if valid_results is not None: + check_loss_nan(cur_batch, valid_results) + if print_header: self.print_header(fp, train_results, valid_results) self.print_on_training( diff --git a/deepmd/utils/nan_detector.py b/deepmd/utils/nan_detector.py new file mode 100644 index 0000000000..6a2641e327 --- /dev/null +++ b/deepmd/utils/nan_detector.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Utilities for detecting NaN values in loss during training.""" + +import logging +import math +from typing import ( + Any, +) + +import numpy as np + +log = logging.getLogger(__name__) + + +class LossNaNError(Exception): + """Exception raised when NaN is detected in loss during training.""" + + def __init__(self, step: int, loss_dict: dict[str, Any]) -> None: + """Initialize the exception. + + Parameters + ---------- + step : int + The training step where NaN was detected + loss_dict : dict[str, Any] + Dictionary containing the loss values where NaN was found + """ + self.step = step + self.loss_dict = loss_dict + super().__init__(self._format_message()) + + def _format_message(self) -> str: + """Format the error message.""" + nan_losses = [] + for key, value in self.loss_dict.items(): + if self._is_nan(value): + nan_losses.append(f"{key}={value}") + + message = ( + f"NaN detected in loss at training step {self.step}. " + f"Training stopped to prevent wasting time with corrupted parameters. " + f"NaN values found in: {', '.join(nan_losses)}. " + f"This typically indicates unstable training conditions such as " + f"learning rate too high, poor data quality, or numerical instability." + ) + return message + + @staticmethod + def _is_nan(value: Any) -> bool: + """Check if a value is NaN.""" + if value is None: + return False + try: + # Handle various tensor types and Python scalars + if hasattr(value, "item"): + # PyTorch/TensorFlow/PaddlePaddle tensor + return math.isnan(value.item()) + elif isinstance(value, (int, float)): + # Python scalar + return math.isnan(value) + elif isinstance(value, np.ndarray): + # NumPy array + return np.isnan(value).any() + else: + # Try to convert to float and check + return math.isnan(float(value)) + except (TypeError, ValueError): + # If we can't convert to float, assume it's not NaN + return False + + +def check_loss_nan(step: int, loss_dict: dict[str, Any]) -> None: + """Check if any loss values contain NaN and raise an exception if found. + + This function is designed to be called during training after loss values + are computed and available on CPU, typically during the logging/display phase. + + Parameters + ---------- + step : int + Current training step + loss_dict : dict[str, Any] + Dictionary containing loss values to check for NaN + + Raises + ------ + LossNaNError + If any loss value contains NaN + """ + nan_found = False + for key, value in loss_dict.items(): + if LossNaNError._is_nan(value): + nan_found = True + log.error(f"NaN detected in {key} at step {step}: {value}") + + if nan_found: + raise LossNaNError(step, loss_dict) + + +def check_single_loss_nan(step: int, loss_name: str, loss_value: Any) -> None: + """Check if a single loss value contains NaN and raise an exception if found. + + Parameters + ---------- + step : int + Current training step + loss_name : str + Name/identifier of the loss + loss_value : Any + Loss value to check for NaN + + Raises + ------ + LossNaNError + If the loss value contains NaN + """ + if LossNaNError._is_nan(loss_value): + log.error(f"NaN detected in {loss_name} at step {step}: {loss_value}") + raise LossNaNError(step, {loss_name: loss_value}) diff --git a/source/3rdparty/implib/implib-gen.py b/source/3rdparty/implib/implib-gen.py index 86cfa77378..3a51be271d 100755 --- a/source/3rdparty/implib/implib-gen.py +++ b/source/3rdparty/implib/implib-gen.py @@ -22,577 +22,654 @@ me = os.path.basename(__file__) root = os.path.dirname(__file__) + def warn(msg): - """Emits a nicely-decorated warning.""" - sys.stderr.write(f'{me}: warning: {msg}\n') + """Emits a nicely-decorated warning.""" + sys.stderr.write(f"{me}: warning: {msg}\n") + def error(msg): - """Emits a nicely-decorated error and exits.""" - sys.stderr.write(f'{me}: error: {msg}\n') - sys.exit(1) - -def run(args, stdin=''): - """Runs external program and aborts on error.""" - env = os.environ.copy() - # Force English language - env['LC_ALL'] = 'c' - try: - del env["LANG"] - except KeyError: - pass - with subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, env=env) as p: - out, err = p.communicate(input=stdin.encode('utf-8')) - out = out.decode('utf-8') - err = err.decode('utf-8') - if p.returncode != 0 or err: - error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") - return out, err + """Emits a nicely-decorated error and exits.""" + sys.stderr.write(f"{me}: error: {msg}\n") + sys.exit(1) + + +def run(args, stdin=""): + """Runs external program and aborts on error.""" + env = os.environ.copy() + # Force English language + env["LC_ALL"] = "c" + try: + del env["LANG"] + except KeyError: + pass + with subprocess.Popen( + args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) as p: + out, err = p.communicate(input=stdin.encode("utf-8")) + out = out.decode("utf-8") + err = err.decode("utf-8") + if p.returncode != 0 or err: + error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") + return out, err + def make_toc(words, renames=None): - "Make an mapping of words to their indices in list" - renames = renames or {} - toc = {} - for i, n in enumerate(words): - name = renames.get(n, n) - toc[i] = name - return toc + "Make an mapping of words to their indices in list" + renames = renames or {} + toc = {} + for i, n in enumerate(words): + name = renames.get(n, n) + toc[i] = name + return toc + def parse_row(words, toc, hex_keys): - "Make a mapping from column names to values" - vals = {k: (words[i] if i < len(words) else '') for i, k in toc.items()} - for k in hex_keys: - if vals[k]: - vals[k] = int(vals[k], 16) - return vals + "Make a mapping from column names to values" + vals = {k: (words[i] if i < len(words) else "") for i, k in toc.items()} + for k in hex_keys: + if vals[k]: + vals[k] = int(vals[k], 16) + return vals + def collect_syms(f): - """Collect ELF dynamic symtab.""" - - # --dyn-syms does not always work for some reason so dump all symtabs - out, _ = run(['readelf', '-sW', f]) - - toc = None - syms = [] - syms_set = set() - for line in out.splitlines(): - line = line.strip() - if not line: - # Next symtab - toc = None - continue - words = re.split(r' +', line) - if line.startswith('Num'): # Header? - if toc is not None: - error("multiple headers in output of readelf") - # Colons are different across readelf versions so get rid of them. - toc = make_toc(map(lambda n: n.replace(':', ''), words)) - elif toc is not None: - sym = parse_row(words, toc, ['Value']) - name = sym['Name'] - if not name: - continue - if name in syms_set: - continue - syms_set.add(name) - sym['Size'] = int(sym['Size'], 0) # Readelf is inconistent on Size format - if '@' in name: - sym['Default'] = '@@' in name - name, ver = re.split(r'@+', name) - sym['Name'] = name - sym['Version'] = ver - else: - sym['Default'] = True - sym['Version'] = None - syms.append(sym) - - if toc is None: - error(f"failed to analyze symbols in {f}") - - # Also collected demangled names - if syms: - out, _ = run(['c++filt'], '\n'.join((sym['Name'] for sym in syms))) - out = out.rstrip("\n") # Some c++filts append newlines at the end - for i, name in enumerate(out.split("\n")): - syms[i]['Demangled Name'] = name - - return syms + """Collect ELF dynamic symtab.""" + + # --dyn-syms does not always work for some reason so dump all symtabs + out, _ = run(["readelf", "-sW", f]) + + toc = None + syms = [] + syms_set = set() + for line in out.splitlines(): + line = line.strip() + if not line: + # Next symtab + toc = None + continue + words = re.split(r" +", line) + if line.startswith("Num"): # Header? + if toc is not None: + error("multiple headers in output of readelf") + # Colons are different across readelf versions so get rid of them. + toc = make_toc(map(lambda n: n.replace(":", ""), words)) + elif toc is not None: + sym = parse_row(words, toc, ["Value"]) + name = sym["Name"] + if not name: + continue + if name in syms_set: + continue + syms_set.add(name) + sym["Size"] = int(sym["Size"], 0) # Readelf is inconistent on Size format + if "@" in name: + sym["Default"] = "@@" in name + name, ver = re.split(r"@+", name) + sym["Name"] = name + sym["Version"] = ver + else: + sym["Default"] = True + sym["Version"] = None + syms.append(sym) + + if toc is None: + error(f"failed to analyze symbols in {f}") + + # Also collected demangled names + if syms: + out, _ = run(["c++filt"], "\n".join((sym["Name"] for sym in syms))) + out = out.rstrip("\n") # Some c++filts append newlines at the end + for i, name in enumerate(out.split("\n")): + syms[i]["Demangled Name"] = name + + return syms + def collect_relocs(f): - """Collect ELF dynamic relocs.""" - - out, _ = run(['readelf', '-rW', f]) - - toc = None - rels = [] - for line in out.splitlines(): - line = line.strip() - if not line: - toc = None - continue - if line == 'There are no relocations in this file.': - return [] - if re.match(r'^\s*Type[0-9]:', line): # Spurious lines for MIPS - continue - if re.match(r'^\s*Offset', line): # Header? - if toc is not None: - error("multiple headers in output of readelf") - words = re.split(r'\s\s+', line) # "Symbol's Name + Addend" - toc = make_toc(words) - elif toc is not None: - line = re.sub(r' \+ ', '+', line) - words = re.split(r'\s+', line) - rel = parse_row(words, toc, ['Offset', 'Info']) - rels.append(rel) - # Split symbolic representation - sym_name = 'Symbol\'s Name + Addend' - if sym_name not in rel and 'Symbol\'s Name' in rel: - # Adapt to different versions of readelf - rel[sym_name] = rel['Symbol\'s Name'] + '+0' - if rel[sym_name]: - p = rel[sym_name].split('+') - if len(p) == 1: - p = ['', p[0]] - rel[sym_name] = (p[0], int(p[1], 16)) - - if toc is None: - error(f"failed to analyze relocations in {f}") - - return rels + """Collect ELF dynamic relocs.""" + + out, _ = run(["readelf", "-rW", f]) + + toc = None + rels = [] + for line in out.splitlines(): + line = line.strip() + if not line: + toc = None + continue + if line == "There are no relocations in this file.": + return [] + if re.match(r"^\s*Type[0-9]:", line): # Spurious lines for MIPS + continue + if re.match(r"^\s*Offset", line): # Header? + if toc is not None: + error("multiple headers in output of readelf") + words = re.split(r"\s\s+", line) # "Symbol's Name + Addend" + toc = make_toc(words) + elif toc is not None: + line = re.sub(r" \+ ", "+", line) + words = re.split(r"\s+", line) + rel = parse_row(words, toc, ["Offset", "Info"]) + rels.append(rel) + # Split symbolic representation + sym_name = "Symbol's Name + Addend" + if sym_name not in rel and "Symbol's Name" in rel: + # Adapt to different versions of readelf + rel[sym_name] = rel["Symbol's Name"] + "+0" + if rel[sym_name]: + p = rel[sym_name].split("+") + if len(p) == 1: + p = ["", p[0]] + rel[sym_name] = (p[0], int(p[1], 16)) + + if toc is None: + error(f"failed to analyze relocations in {f}") + + return rels + def collect_sections(f): - """Collect section info from ELF.""" - - out, _ = run(['readelf', '-SW', f]) - - toc = None - sections = [] - for line in out.splitlines(): - line = line.strip() - if not line: - continue - line = re.sub(r'\[\s+', '[', line) - words = re.split(r' +', line) - if line.startswith('[Nr]'): # Header? - if toc is not None: - error("multiple headers in output of readelf") - toc = make_toc(words, {'Addr' : 'Address'}) - elif line.startswith('[') and toc is not None: - sec = parse_row(words, toc, ['Address', 'Off', 'Size']) - if 'A' in sec['Flg']: # Allocatable section? - sections.append(sec) - - if toc is None: - error(f"failed to analyze sections in {f}") - - return sections + """Collect section info from ELF.""" + + out, _ = run(["readelf", "-SW", f]) + + toc = None + sections = [] + for line in out.splitlines(): + line = line.strip() + if not line: + continue + line = re.sub(r"\[\s+", "[", line) + words = re.split(r" +", line) + if line.startswith("[Nr]"): # Header? + if toc is not None: + error("multiple headers in output of readelf") + toc = make_toc(words, {"Addr": "Address"}) + elif line.startswith("[") and toc is not None: + sec = parse_row(words, toc, ["Address", "Off", "Size"]) + if "A" in sec["Flg"]: # Allocatable section? + sections.append(sec) + + if toc is None: + error(f"failed to analyze sections in {f}") + + return sections + def read_unrelocated_data(input_name, syms, secs): - """Collect unrelocated data from ELF.""" - data = {} - with open(input_name, 'rb') as f: - def is_symbol_in_section(sym, sec): - sec_end = sec['Address'] + sec['Size'] - is_start_in_section = sec['Address'] <= sym['Value'] < sec_end - is_end_in_section = sym['Value'] + sym['Size'] <= sec_end - return is_start_in_section and is_end_in_section - for name, s in sorted(syms.items(), key=lambda s: s[1]['Value']): - # TODO: binary search (bisect) - sec = [sec for sec in secs if is_symbol_in_section(s, sec)] - if len(sec) != 1: - error(f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})") - sec = sec[0] - f.seek(sec['Off']) - data[name] = f.read(s['Size']) - return data + """Collect unrelocated data from ELF.""" + data = {} + with open(input_name, "rb") as f: + + def is_symbol_in_section(sym, sec): + sec_end = sec["Address"] + sec["Size"] + is_start_in_section = sec["Address"] <= sym["Value"] < sec_end + is_end_in_section = sym["Value"] + sym["Size"] <= sec_end + return is_start_in_section and is_end_in_section + + for name, s in sorted(syms.items(), key=lambda s: s[1]["Value"]): + # TODO: binary search (bisect) + sec = [sec for sec in secs if is_symbol_in_section(s, sec)] + if len(sec) != 1: + error( + f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})" + ) + sec = sec[0] + f.seek(sec["Off"]) + data[name] = f.read(s["Size"]) + return data + def collect_relocated_data(syms, bites, rels, ptr_size, reloc_types): - """Identify relocations for each symbol""" - data = {} - for name, s in sorted(syms.items()): - b = bites.get(name) - assert b is not None - if s['Demangled Name'].startswith('typeinfo name'): - data[name] = [('byte', int(x)) for x in b] - continue - data[name] = [] - for i in range(0, len(b), ptr_size): - val = int.from_bytes(b[i*ptr_size:(i + 1)*ptr_size], byteorder='little') - data[name].append(('offset', val)) - start = s['Value'] - finish = start + s['Size'] - # TODO: binary search (bisect) - for rel in rels: - if rel['Type'] in reloc_types and start <= rel['Offset'] < finish: - i = (rel['Offset'] - start) // ptr_size - assert i < len(data[name]) - data[name][i] = 'reloc', rel - return data + """Identify relocations for each symbol""" + data = {} + for name, s in sorted(syms.items()): + b = bites.get(name) + assert b is not None + if s["Demangled Name"].startswith("typeinfo name"): + data[name] = [("byte", int(x)) for x in b] + continue + data[name] = [] + for i in range(0, len(b), ptr_size): + val = int.from_bytes( + b[i * ptr_size : (i + 1) * ptr_size], byteorder="little" + ) + data[name].append(("offset", val)) + start = s["Value"] + finish = start + s["Size"] + # TODO: binary search (bisect) + for rel in rels: + if rel["Type"] in reloc_types and start <= rel["Offset"] < finish: + i = (rel["Offset"] - start) // ptr_size + assert i < len(data[name]) + data[name][i] = "reloc", rel + return data + def generate_vtables(cls_tables, cls_syms, cls_data): - """Generate code for vtables""" - c_types = { - 'reloc' : 'const void *', - 'byte' : 'unsigned char', - 'offset' : 'size_t' - } - - ss = [] - ss.append('''\ + """Generate code for vtables""" + c_types = {"reloc": "const void *", "byte": "unsigned char", "offset": "size_t"} + + ss = [] + ss.append("""\ #ifdef __cplusplus extern "C" { #endif -''') +""") - # Print externs + # Print externs - printed = set() - for name, data in sorted(cls_data.items()): - for typ, val in data: - if typ != 'reloc': - continue - sym_name, addend = val['Symbol\'s Name + Addend'] - sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? - if sym_name not in cls_syms and sym_name not in printed: - ss.append(f'''\ + printed = set() + for name, data in sorted(cls_data.items()): + for typ, val in data: + if typ != "reloc": + continue + sym_name, addend = val["Symbol's Name + Addend"] + sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? + if sym_name not in cls_syms and sym_name not in printed: + ss.append(f"""\ extern const char {sym_name}[]; -''') +""") - # Collect variable infos + # Collect variable infos - code_info = {} + code_info = {} - for name, s in sorted(cls_syms.items()): - data = cls_data[name] - if s['Demangled Name'].startswith('typeinfo name'): - declarator = 'const unsigned char %s[]' - else: - field_types = (f'{c_types[typ]} field_{i};' for i, (typ, _) in enumerate(data)) - declarator = 'const struct { %s } %%s' % ' '.join(field_types) # pylint: disable=C0209 # consider-using-f-string - vals = [] - for typ, val in data: - if typ != 'reloc': - vals.append(str(val) + 'UL') - else: - sym_name, addend = val['Symbol\'s Name + Addend'] - sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? - vals.append(f'(const char *)&{sym_name} + {addend}') - code_info[name] = (declarator, '{ %s }' % ', '.join(vals)) # pylint: disable= C0209 # consider-using-f-string - - # Print declarations - - for name, (decl, _) in sorted(code_info.items()): - type_name = name + '_type' - type_decl = decl % type_name - ss.append(f'''\ + for name, s in sorted(cls_syms.items()): + data = cls_data[name] + if s["Demangled Name"].startswith("typeinfo name"): + declarator = "const unsigned char %s[]" + else: + field_types = ( + f"{c_types[typ]} field_{i};" for i, (typ, _) in enumerate(data) + ) + declarator = "const struct { %s } %%s" % " ".join(field_types) # pylint: disable=C0209 # consider-using-f-string + vals = [] + for typ, val in data: + if typ != "reloc": + vals.append(str(val) + "UL") + else: + sym_name, addend = val["Symbol's Name + Addend"] + sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? + vals.append(f"(const char *)&{sym_name} + {addend}") + code_info[name] = (declarator, "{ %s }" % ", ".join(vals)) # pylint: disable= C0209 # consider-using-f-string + + # Print declarations + + for name, (decl, _) in sorted(code_info.items()): + type_name = name + "_type" + type_decl = decl % type_name + ss.append(f"""\ typedef {type_decl}; extern __attribute__((weak)) {type_name} {name}; -''') +""") - # Print definitions + # Print definitions - for name, (_, init) in sorted(code_info.items()): - type_name = name + '_type' - ss.append(f'''\ + for name, (_, init) in sorted(code_info.items()): + type_name = name + "_type" + ss.append(f"""\ const {type_name} {name} = {init}; -''') +""") - ss.append('''\ + ss.append("""\ #ifdef __cplusplus } // extern "C" #endif -''') +""") + + return "".join(ss) - return ''.join(ss) def read_soname(f): - """Read ELF's SONAME.""" + """Read ELF's SONAME.""" + + out, _ = run(["readelf", "-d", f]) - out, _ = run(['readelf', '-d', f]) + for line in out.splitlines(): + line = line.strip() + if not line: + continue + # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] + soname_match = re.search(r"\(SONAME\).*\[(.+)\]", line) + if soname_match is not None: + return soname_match[1] - for line in out.splitlines(): - line = line.strip() - if not line: - continue - # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] - soname_match = re.search(r'\(SONAME\).*\[(.+)\]', line) - if soname_match is not None: - return soname_match[1] + return None - return None def main(): - """Driver function""" - parser = argparse.ArgumentParser(description="Generate wrappers for shared library functions.", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=f"""\ + """Driver function""" + parser = argparse.ArgumentParser( + description="Generate wrappers for shared library functions.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f"""\ Examples: $ python3 {me} /usr/lib/x86_64-linux-gnu/libaccountsservice.so.0 Generating libaccountsservice.so.0.tramp.S... Generating libaccountsservice.so.0.init.c... -""") - - parser.add_argument('library', - metavar='LIB', - help="Library to be wrapped.") - parser.add_argument('--verbose', '-v', - help="Print diagnostic info", - action='count', - default=0) - parser.add_argument('--dlopen', - help="Emit dlopen call (default)", - dest='dlopen', action='store_true', default=True) - parser.add_argument('--no-dlopen', - help="Do not emit dlopen call (user must load/unload library himself)", - dest='dlopen', action='store_false') - parser.add_argument('--dlopen-callback', - help="Call user-provided custom callback to load library instead of dlopen", - default='') - parser.add_argument('--dlsym-callback', - help="Call user-provided custom callback to resolve a symbol, " - "instead of dlsym", - default='') - parser.add_argument('--library-load-name', - help="Use custom name for dlopened library (default is SONAME)") - parser.add_argument('--lazy-load', - help="Load library on first call to any of it's functions (default)", - dest='lazy_load', action='store_true', default=True) - parser.add_argument('--no-lazy-load', - help="Load library at program start", - dest='lazy_load', action='store_false') - parser.add_argument('--vtables', - help="Intercept virtual tables (EXPERIMENTAL)", - dest='vtables', action='store_true', default=False) - parser.add_argument('--no-vtables', - help="Do not intercept virtual tables (default)", - dest='vtables', action='store_false') - parser.add_argument('--no-weak-symbols', - help="Don't bind weak symbols", dest='no_weak_symbols', - action='store_true', default=False) - parser.add_argument('--target', - help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " - "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " - "mips/mipsel, mips64/mip64el and e2k are supported)", - default=os.uname()[-1]) - parser.add_argument('--symbol-list', - help="Path to file with symbols that should be present in wrapper " - "(all by default)") - parser.add_argument('--symbol-prefix', - metavar='PFX', - help="Prefix wrapper symbols with PFX", - default='') - parser.add_argument('-q', '--quiet', - help="Do not print progress info", - action='store_true') - parser.add_argument('--outdir', '-o', - help="Path to create wrapper at", - default='./') - - args = parser.parse_args() - - input_name = args.library - verbose = args.verbose - dlopen_callback = args.dlopen_callback - dlsym_callback = args.dlsym_callback - dlopen = args.dlopen - lazy_load = args.lazy_load - if args.target.startswith('arm'): - target = 'arm' # Handle armhf-..., armel-... - elif re.match(r'^i[0-9]86', args.target): - target = 'i386' - elif args.target.startswith('mips64'): - target = 'mips64' # Handle mips64-..., mips64el-..., mips64le-... - elif args.target.startswith('mips'): - target = 'mips' # Handle mips-..., mipsel-..., mipsle-... - else: - target = args.target.split('-')[0] - quiet = args.quiet - outdir = args.outdir - - if args.symbol_list is None: - funs = None - else: - with open(args.symbol_list, 'r') as f: - funs = [] - for line in re.split(r'\r?\n', f.read()): - line = re.sub(r'#.*', '', line) - line = line.strip() - if line: - funs.append(line) +""", + ) + + parser.add_argument("library", metavar="LIB", help="Library to be wrapped.") + parser.add_argument( + "--verbose", "-v", help="Print diagnostic info", action="count", default=0 + ) + parser.add_argument( + "--dlopen", + help="Emit dlopen call (default)", + dest="dlopen", + action="store_true", + default=True, + ) + parser.add_argument( + "--no-dlopen", + help="Do not emit dlopen call (user must load/unload library himself)", + dest="dlopen", + action="store_false", + ) + parser.add_argument( + "--dlopen-callback", + help="Call user-provided custom callback to load library instead of dlopen", + default="", + ) + parser.add_argument( + "--dlsym-callback", + help="Call user-provided custom callback to resolve a symbol, instead of dlsym", + default="", + ) + parser.add_argument( + "--library-load-name", + help="Use custom name for dlopened library (default is SONAME)", + ) + parser.add_argument( + "--lazy-load", + help="Load library on first call to any of it's functions (default)", + dest="lazy_load", + action="store_true", + default=True, + ) + parser.add_argument( + "--no-lazy-load", + help="Load library at program start", + dest="lazy_load", + action="store_false", + ) + parser.add_argument( + "--vtables", + help="Intercept virtual tables (EXPERIMENTAL)", + dest="vtables", + action="store_true", + default=False, + ) + parser.add_argument( + "--no-vtables", + help="Do not intercept virtual tables (default)", + dest="vtables", + action="store_false", + ) + parser.add_argument( + "--no-weak-symbols", + help="Don't bind weak symbols", + dest="no_weak_symbols", + action="store_true", + default=False, + ) + parser.add_argument( + "--target", + help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " + "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " + "mips/mipsel, mips64/mip64el and e2k are supported)", + default=os.uname()[-1], + ) + parser.add_argument( + "--symbol-list", + help="Path to file with symbols that should be present in wrapper " + "(all by default)", + ) + parser.add_argument( + "--symbol-prefix", + metavar="PFX", + help="Prefix wrapper symbols with PFX", + default="", + ) + parser.add_argument( + "-q", "--quiet", help="Do not print progress info", action="store_true" + ) + parser.add_argument( + "--outdir", "-o", help="Path to create wrapper at", default="./" + ) + + args = parser.parse_args() + + input_name = args.library + verbose = args.verbose + dlopen_callback = args.dlopen_callback + dlsym_callback = args.dlsym_callback + dlopen = args.dlopen + lazy_load = args.lazy_load + if args.target.startswith("arm"): + target = "arm" # Handle armhf-..., armel-... + elif re.match(r"^i[0-9]86", args.target): + target = "i386" + elif args.target.startswith("mips64"): + target = "mips64" # Handle mips64-..., mips64el-..., mips64le-... + elif args.target.startswith("mips"): + target = "mips" # Handle mips-..., mipsel-..., mipsle-... + else: + target = args.target.split("-")[0] + quiet = args.quiet + outdir = args.outdir - if args.library_load_name is not None: - load_name = args.library_load_name - else: - load_name = read_soname(input_name) - if load_name is None: - load_name = os.path.basename(input_name) + if args.symbol_list is None: + funs = None + else: + with open(args.symbol_list, "r") as f: + funs = [] + for line in re.split(r"\r?\n", f.read()): + line = re.sub(r"#.*", "", line) + line = line.strip() + if line: + funs.append(line) + + if args.library_load_name is not None: + load_name = args.library_load_name + else: + load_name = read_soname(input_name) + if load_name is None: + load_name = os.path.basename(input_name) - # Collect target info + # Collect target info - target_dir = os.path.join(root, 'arch', target) + target_dir = os.path.join(root, "arch", target) - if not os.path.exists(target_dir): - error(f"unknown architecture '{target}'") + if not os.path.exists(target_dir): + error(f"unknown architecture '{target}'") - cfg = configparser.ConfigParser(inline_comment_prefixes=';') - cfg.read(target_dir + '/config.ini') + cfg = configparser.ConfigParser(inline_comment_prefixes=";") + cfg.read(target_dir + "/config.ini") - ptr_size = int(cfg['Arch']['PointerSize']) - symbol_reloc_types = set(re.split(r'\s*,\s*', cfg['Arch']['SymbolReloc'])) + ptr_size = int(cfg["Arch"]["PointerSize"]) + symbol_reloc_types = set(re.split(r"\s*,\s*", cfg["Arch"]["SymbolReloc"])) - def is_exported(s): - conditions = [ - s['Bind'] != 'LOCAL', - s['Type'] != 'NOTYPE', - s['Ndx'] != 'UND', - s['Name'] not in ['', '_init', '_fini']] - if args.no_weak_symbols: - conditions.append(s['Bind'] != 'WEAK') - return all(conditions) + def is_exported(s): + conditions = [ + s["Bind"] != "LOCAL", + s["Type"] != "NOTYPE", + s["Ndx"] != "UND", + s["Name"] not in ["", "_init", "_fini"], + ] + if args.no_weak_symbols: + conditions.append(s["Bind"] != "WEAK") + return all(conditions) - syms = list(filter(is_exported, collect_syms(input_name))) + syms = list(filter(is_exported, collect_syms(input_name))) - def is_data_symbol(s): - return (s['Type'] == 'OBJECT' + def is_data_symbol(s): + return ( + s["Type"] == "OBJECT" # Allow vtables if --vtables is on - and not (' for ' in s['Demangled Name'] and args.vtables)) - - exported_data = [s['Name'] for s in syms if is_data_symbol(s)] - if exported_data: - # TODO: we can generate wrappers for const data without relocations (or only code relocations) - warn(f"library '{input_name}' contains data symbols which won't be intercepted: " - + ', '.join(exported_data)) - - # Collect functions - # TODO: warn if user-specified functions are missing - - orig_funs = filter(lambda s: s['Type'] == 'FUNC', syms) - - all_funs = set() - warn_versioned = False - for s in orig_funs: - if not s['Default']: - # TODO: support versions - if not warn_versioned: - warn(f"library {input_name} contains versioned symbols which are NYI") - warn_versioned = True - if verbose: - print(f"Skipping versioned symbol {s['Name']}") - continue - all_funs.add(s['Name']) - - if funs is None: - funs = sorted(list(all_funs)) - if not funs and not quiet: - warn(f"no public functions were found in {input_name}") - else: - missing_funs = [name for name in funs if name not in all_funs] - if missing_funs: - warn("some user-specified functions are not present in library: " + ', '.join(missing_funs)) - funs = [name for name in funs if name in all_funs] - - if verbose: - print("Exported functions:") - for i, fun in enumerate(funs): - print(f" {i}: {fun}") - - # Collect vtables - - if args.vtables: - cls_tables = {} - cls_syms = {} - - for s in syms: - m = re.match(r'^(vtable|typeinfo|typeinfo name) for (.*)', s['Demangled Name']) - if m is not None and is_exported(s): - typ, cls = m.groups() - name = s['Name'] - cls_tables.setdefault(cls, {})[typ] = name - cls_syms[name] = s - - if verbose: - print("Exported classes:") - for cls, _ in sorted(cls_tables.items()): - print(f" {cls}") + and not (" for " in s["Demangled Name"] and args.vtables) + ) + + exported_data = [s["Name"] for s in syms if is_data_symbol(s)] + if exported_data: + # TODO: we can generate wrappers for const data without relocations (or only code relocations) + warn( + f"library '{input_name}' contains data symbols which won't be intercepted: " + + ", ".join(exported_data) + ) + + # Collect functions + # TODO: warn if user-specified functions are missing + + orig_funs = filter(lambda s: s["Type"] == "FUNC", syms) + + all_funs = set() + warn_versioned = False + for s in orig_funs: + if not s["Default"]: + # TODO: support versions + if not warn_versioned: + warn(f"library {input_name} contains versioned symbols which are NYI") + warn_versioned = True + if verbose: + print(f"Skipping versioned symbol {s['Name']}") + continue + all_funs.add(s["Name"]) + + if funs is None: + funs = sorted(list(all_funs)) + if not funs and not quiet: + warn(f"no public functions were found in {input_name}") + else: + missing_funs = [name for name in funs if name not in all_funs] + if missing_funs: + warn( + "some user-specified functions are not present in library: " + + ", ".join(missing_funs) + ) + funs = [name for name in funs if name in all_funs] - secs = collect_sections(input_name) if verbose: - print("Sections:") - for sec in secs: - print(f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " - f"at {sec['Off']:x}") + print("Exported functions:") + for i, fun in enumerate(funs): + print(f" {i}: {fun}") - bites = read_unrelocated_data(input_name, cls_syms, secs) + # Collect vtables - rels = collect_relocs(input_name) - if verbose: - print("Relocs:") - for rel in rels: - sym_add = rel['Symbol\'s Name + Addend'] - print(f" {rel['Offset']}: {sym_add}") - - cls_data = collect_relocated_data(cls_syms, bites, rels, ptr_size, symbol_reloc_types) - if verbose: - print("Class data:") - for name, data in sorted(cls_data.items()): - demangled_name = cls_syms[name]['Demangled Name'] - print(f" {name} ({demangled_name}):") - for typ, val in data: - print(" " + str(val if typ != 'reloc' else val['Symbol\'s Name + Addend'])) - - # Generate assembly code - - suffix = os.path.basename(input_name) - lib_suffix = re.sub(r'[^a-zA-Z_0-9]+', '_', suffix) - - tramp_file = f'{suffix}.tramp.S' - with open(os.path.join(outdir, tramp_file), 'w') as f: - if not quiet: - print(f"Generating {tramp_file}...") - with open(target_dir + '/table.S.tpl', 'r') as t: - table_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, - table_size=ptr_size*(len(funs) + 1)) - f.write(table_text) - - with open(target_dir + '/trampoline.S.tpl', 'r') as t: - tramp_tpl = string.Template(t.read()) - - for i, name in enumerate(funs): - tramp_text = tramp_tpl.substitute( - lib_suffix=lib_suffix, - sym=args.symbol_prefix + name, - offset=i*ptr_size, - number=i) - f.write(tramp_text) - - # Generate C code - - init_file = f'{suffix}.init.c' - with open(os.path.join(outdir, init_file), 'w') as f: - if not quiet: - print(f"Generating {init_file}...") - with open(os.path.join(root, 'arch/common/init.c.tpl'), 'r') as t: - if funs: - sym_names = ',\n '.join(f'"{name}"' for name in funs) + ',' - else: - sym_names = '' - init_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, - load_name=load_name, - dlopen_callback=dlopen_callback, - dlsym_callback=dlsym_callback, - has_dlopen_callback=int(bool(dlopen_callback)), - has_dlsym_callback=int(bool(dlsym_callback)), - no_dlopen=int(not dlopen), - lazy_load=int(lazy_load), - sym_names=sym_names) - f.write(init_text) if args.vtables: - vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) - f.write(vtable_text) - -if __name__ == '__main__': - main() + cls_tables = {} + cls_syms = {} + + for s in syms: + m = re.match( + r"^(vtable|typeinfo|typeinfo name) for (.*)", s["Demangled Name"] + ) + if m is not None and is_exported(s): + typ, cls = m.groups() + name = s["Name"] + cls_tables.setdefault(cls, {})[typ] = name + cls_syms[name] = s + + if verbose: + print("Exported classes:") + for cls, _ in sorted(cls_tables.items()): + print(f" {cls}") + + secs = collect_sections(input_name) + if verbose: + print("Sections:") + for sec in secs: + print( + f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " + f"at {sec['Off']:x}" + ) + + bites = read_unrelocated_data(input_name, cls_syms, secs) + + rels = collect_relocs(input_name) + if verbose: + print("Relocs:") + for rel in rels: + sym_add = rel["Symbol's Name + Addend"] + print(f" {rel['Offset']}: {sym_add}") + + cls_data = collect_relocated_data( + cls_syms, bites, rels, ptr_size, symbol_reloc_types + ) + if verbose: + print("Class data:") + for name, data in sorted(cls_data.items()): + demangled_name = cls_syms[name]["Demangled Name"] + print(f" {name} ({demangled_name}):") + for typ, val in data: + print( + " " + + str(val if typ != "reloc" else val["Symbol's Name + Addend"]) + ) + + # Generate assembly code + + suffix = os.path.basename(input_name) + lib_suffix = re.sub(r"[^a-zA-Z_0-9]+", "_", suffix) + + tramp_file = f"{suffix}.tramp.S" + with open(os.path.join(outdir, tramp_file), "w") as f: + if not quiet: + print(f"Generating {tramp_file}...") + with open(target_dir + "/table.S.tpl", "r") as t: + table_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, table_size=ptr_size * (len(funs) + 1) + ) + f.write(table_text) + + with open(target_dir + "/trampoline.S.tpl", "r") as t: + tramp_tpl = string.Template(t.read()) + + for i, name in enumerate(funs): + tramp_text = tramp_tpl.substitute( + lib_suffix=lib_suffix, + sym=args.symbol_prefix + name, + offset=i * ptr_size, + number=i, + ) + f.write(tramp_text) + + # Generate C code + + init_file = f"{suffix}.init.c" + with open(os.path.join(outdir, init_file), "w") as f: + if not quiet: + print(f"Generating {init_file}...") + with open(os.path.join(root, "arch/common/init.c.tpl"), "r") as t: + if funs: + sym_names = ",\n ".join(f'"{name}"' for name in funs) + "," + else: + sym_names = "" + init_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, + load_name=load_name, + dlopen_callback=dlopen_callback, + dlsym_callback=dlsym_callback, + has_dlopen_callback=int(bool(dlopen_callback)), + has_dlsym_callback=int(bool(dlsym_callback)), + no_dlopen=int(not dlopen), + lazy_load=int(lazy_load), + sym_names=sym_names, + ) + f.write(init_text) + if args.vtables: + vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) + f.write(vtable_text) + + +if __name__ == "__main__": + main() diff --git a/source/tests/common/test_nan_detector.py b/source/tests/common/test_nan_detector.py new file mode 100644 index 0000000000..1ce719d944 --- /dev/null +++ b/source/tests/common/test_nan_detector.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Test cases for NaN detection utility.""" + +import math +import unittest + +import numpy as np + +from deepmd.utils.nan_detector import ( + LossNaNError, + check_loss_nan, + check_single_loss_nan, +) + + +class TestNaNDetector(unittest.TestCase): + """Test the NaN detection utility functions.""" + + def test_normal_values_pass(self): + """Test that normal loss values don't trigger NaN detection.""" + # Test with various normal values + normal_losses = { + "energy": 0.5, + "force": 1.0, + "virial": 0.001, + "zero": 0.0, + "negative": -0.5, + } + + # Should not raise any exception + try: + check_loss_nan(100, normal_losses) + except Exception as e: + self.fail(f"Normal values should not raise exception: {e}") + + def test_nan_detection_raises_exception(self): + """Test that NaN values trigger the proper exception.""" + # Test with NaN values + nan_losses = { + "energy": 0.5, + "force": float("nan"), + "virial": 1.0, + } + + with self.assertRaises(LossNaNError) as context: + check_loss_nan(200, nan_losses) + + exception = context.exception + self.assertEqual(exception.step, 200) + self.assertIn("force", str(exception)) + self.assertIn("NaN detected in loss at training step 200", str(exception)) + + def test_single_loss_nan_detection(self): + """Test single loss NaN detection.""" + # Normal value should pass + try: + check_single_loss_nan(50, "test_loss", 0.5) + except Exception as e: + self.fail(f"Normal single loss should not raise exception: {e}") + + # NaN value should raise + with self.assertRaises(LossNaNError) as context: + check_single_loss_nan(50, "test_loss", float("nan")) + + exception = context.exception + self.assertEqual(exception.step, 50) + self.assertIn("test_loss", str(exception)) + + def test_various_nan_representations(self): + """Test detection of various NaN representations.""" + nan_values = [ + float("nan"), + np.nan, + math.nan, + ] + + for i, nan_val in enumerate(nan_values): + with self.assertRaises(LossNaNError): + check_single_loss_nan(i, f"loss_{i}", nan_val) + + def test_tensor_like_objects(self): + """Test that tensor-like objects work with NaN detection.""" + + # Mock tensor-like object with item() method + class MockTensor: + def __init__(self, value): + self._value = value + + def item(self): + return self._value + + # Normal tensor should pass + normal_tensor = MockTensor(0.5) + try: + check_single_loss_nan(10, "tensor_loss", normal_tensor) + except Exception as e: + self.fail(f"Normal tensor should not raise exception: {e}") + + # NaN tensor should raise + nan_tensor = MockTensor(float("nan")) + with self.assertRaises(LossNaNError): + check_single_loss_nan(10, "tensor_loss", nan_tensor) + + def test_error_message_format(self): + """Test that error messages contain useful information.""" + nan_losses = { + "energy": 0.5, + "force": float("nan"), + "virial": float("nan"), + } + + with self.assertRaises(LossNaNError) as context: + check_loss_nan(123, nan_losses) + + error_msg = str(context.exception) + + # Check key information is in the message + self.assertIn("step 123", error_msg) + self.assertIn("force=nan", error_msg) + self.assertIn("virial=nan", error_msg) + self.assertIn("Training stopped", error_msg) + self.assertIn("learning rate too high", error_msg) + + def test_mixed_loss_dict(self): + """Test loss dictionary with mix of normal and NaN values.""" + mixed_losses = { + "energy": 0.5, + "force": float("nan"), + "virial": 1.0, + "dipole": float("nan"), + } + + with self.assertRaises(LossNaNError) as context: + check_loss_nan(99, mixed_losses) + + exception = context.exception + # Should detect both NaN values + error_msg = str(exception) + self.assertIn("force=nan", error_msg) + self.assertIn("dipole=nan", error_msg) + # Should not mention normal values + self.assertNotIn("energy=0.5", error_msg) + self.assertNotIn("virial=1.0", error_msg) + + def test_edge_cases(self): + """Test edge cases for NaN detection.""" + # Empty dict should pass + try: + check_loss_nan(1, {}) + except Exception as e: + self.fail(f"Empty dict should not raise exception: {e}") + + # None values should not trigger NaN detection + try: + check_loss_nan(1, {"test": None}) + except Exception as e: + self.fail(f"None values should not raise exception: {e}") + + # Infinity should not trigger NaN detection (separate issue) + try: + check_loss_nan(1, {"test": float("inf")}) + except Exception as e: + self.fail(f"Infinity should not raise NaN exception: {e}") + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/common/test_nan_integration.py b/source/tests/common/test_nan_integration.py new file mode 100644 index 0000000000..fcd4624f66 --- /dev/null +++ b/source/tests/common/test_nan_integration.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Integration test to verify NaN detection during training. + +This test creates a mock training scenario where loss becomes NaN +and verifies that the training stops with appropriate error message. +""" + +import unittest +from unittest.mock import ( + patch, +) + +from deepmd.utils.nan_detector import ( + LossNaNError, + check_loss_nan, +) + + +class TestNaNDetectionIntegration(unittest.TestCase): + """Integration tests for NaN detection during training.""" + + def test_training_stops_on_nan_loss(self): + """Test that training stops when NaN is detected in loss values.""" + # Simulate a training scenario where loss becomes NaN + train_results = { + "energy_loss": 0.1, + "force_loss": float("nan"), # This should trigger the detection + "virial_loss": 0.05, + } + + valid_results = { + "energy_loss": 0.12, + "force_loss": 0.08, + "virial_loss": 0.06, + } + + # The NaN in train_results should be detected + with self.assertRaises(LossNaNError) as context: + check_loss_nan(100, train_results) + + exception = context.exception + self.assertEqual(exception.step, 100) + self.assertIn("force_loss=nan", str(exception)) + + # Valid results without NaN should pass + try: + check_loss_nan(100, valid_results) + except Exception as e: + self.fail(f"Valid results should not raise exception: {e}") + + def test_multi_task_nan_detection(self): + """Test NaN detection in multi-task training scenario.""" + # Simulate multi-task training results + multi_task_results = { + "task1": { + "energy_loss": 0.1, + "force_loss": 0.05, + }, + "task2": { + "energy_loss": float("nan"), # NaN in task2 + "force_loss": 0.03, + }, + "task3": { + "energy_loss": 0.08, + "force_loss": 0.04, + }, + } + + # Check each task separately (as done in the actual training code) + # task1 and task3 should pass + try: + check_loss_nan(50, multi_task_results["task1"]) + check_loss_nan(50, multi_task_results["task3"]) + except Exception as e: + self.fail(f"Normal tasks should not raise exception: {e}") + + # task2 should fail due to NaN + with self.assertRaises(LossNaNError) as context: + check_loss_nan(50, multi_task_results["task2"]) + + exception = context.exception + self.assertEqual(exception.step, 50) + self.assertIn("energy_loss=nan", str(exception)) + + @patch("deepmd.utils.nan_detector.log") + def test_logging_on_nan_detection(self, mock_log): + """Test that NaN detection logs appropriate error messages.""" + nan_losses = { + "energy": 0.5, + "force": float("nan"), + } + + with self.assertRaises(LossNaNError): + check_loss_nan(200, nan_losses) + + # Verify that error was logged + mock_log.error.assert_called_once() + logged_message = mock_log.error.call_args[0][0] + self.assertIn("NaN detected in force at step 200", logged_message) + + def test_training_simulation_with_checkpoint_prevention(self): + """Simulate the training checkpoint scenario to ensure NaN prevents saving.""" + + def mock_save_checkpoint(): + """Mock function that should not be called when NaN is detected.""" + raise AssertionError("Checkpoint should not be saved when NaN is detected!") + + # Simulate the training flow: check loss, then save checkpoint + step_id = 1000 + loss_results = { + "total_loss": float("nan"), + "energy_loss": 0.1, + "force_loss": 0.05, + } + + # This should raise LossNaNError before checkpoint saving + with self.assertRaises(LossNaNError): + check_loss_nan(step_id, loss_results) + # This line should never be reached + mock_save_checkpoint() + + # Verify the error contains expected information + try: + check_loss_nan(step_id, loss_results) + except LossNaNError as e: + self.assertIn("Training stopped to prevent wasting time", str(e)) + self.assertIn("corrupted parameters", str(e)) + + +if __name__ == "__main__": + unittest.main() From 5a22dfc611b55b9d87b6c09cf6d881438a32dbc0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 20 Sep 2025 16:38:11 +0000 Subject: [PATCH 3/7] fix(training): address PR feedback - simplify NaN detection API and inherit from RuntimeError Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/pd/train/training.py | 18 +- deepmd/pt/train/training.py | 22 +- deepmd/tf/train/trainer.py | 20 +- deepmd/utils/nan_detector.py | 101 +- source/3rdparty/implib/implib-gen.py | 1093 +++++++++---------- source/tests/common/test_nan_detector.py | 146 +-- source/tests/common/test_nan_integration.py | 115 +- 7 files changed, 634 insertions(+), 881 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 7c084f1084..d6ae30a94a 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -76,7 +76,7 @@ DataRequirementItem, ) from deepmd.utils.nan_detector import ( - check_loss_nan, + check_total_loss_nan, ) from deepmd.utils.path import ( DPH5Path, @@ -781,6 +781,8 @@ def step(_step_id, task_key="Default") -> None: label=label_dict, task_key=task_key, ) + # Check for NaN in total loss before backward pass to prevent corrupted training + check_total_loss_nan(_step_id + 1, loss.item()) with nvprof_context(enable_profiling, "Backward pass"): loss.backward() @@ -954,20 +956,6 @@ def log_loss_valid(_task_key="Default"): fout, display_step_id, cur_lr, train_results, valid_results ) - # Check for NaN in loss values before saving checkpoint - # Loss values are already on CPU at this point for display/logging - if self.rank == 0: - if not self.multi_task: - check_loss_nan(display_step_id, train_results) - if valid_results: - check_loss_nan(display_step_id, valid_results) - else: - for task_key in train_results: - if train_results[task_key]: - check_loss_nan(display_step_id, train_results[task_key]) - if valid_results[task_key]: - check_loss_nan(display_step_id, valid_results[task_key]) - if ( ((_step_id + 1) % self.save_freq == 0 and _step_id != self.start_step) or (_step_id + 1) == self.num_steps diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index d9ef2fbd41..cc1858b4b7 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -76,7 +76,7 @@ DataRequirementItem, ) from deepmd.utils.nan_detector import ( - check_loss_nan, + check_total_loss_nan, ) if torch.__version__.startswith("2"): @@ -764,6 +764,8 @@ def step(_step_id: int, task_key: str = "Default") -> None: model_pred, loss, more_loss = self.wrapper( **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key ) + # Check for NaN in total loss before backward pass to prevent corrupted training + check_total_loss_nan(_step_id + 1, loss.item()) loss.backward() if self.gradient_max_norm > 0.0: torch.nn.utils.clip_grad_norm_( @@ -815,6 +817,8 @@ def fake_model() -> dict: int(input_dict["atype"].shape[-1]), learning_rate=pref_lr, ) + # Check for NaN in total loss before continuing training + check_total_loss_nan(_step_id + 1, loss.item()) elif isinstance(self.loss, DenoiseLoss): KFOptWrapper = KFOptimizerWrapper( self.wrapper, @@ -841,6 +845,8 @@ def fake_model() -> dict: input_dict["natoms"], learning_rate=pref_lr, ) + # Check for NaN in total loss before continuing training + check_total_loss_nan(_step_id + 1, loss.item()) else: raise ValueError(f"Not supported optimizer type '{self.opt_type}'") @@ -1073,20 +1079,6 @@ def log_loss_valid(_task_key: str = "Default") -> dict: fout, display_step_id, cur_lr, train_results, valid_results ) - # Check for NaN in loss values before saving checkpoint - # Loss values are already on CPU at this point for display/logging - if self.rank == 0: - if not self.multi_task: - check_loss_nan(display_step_id, train_results) - if valid_results: - check_loss_nan(display_step_id, valid_results) - else: - for task_key in train_results: - if train_results[task_key]: - check_loss_nan(display_step_id, train_results[task_key]) - if valid_results[task_key]: - check_loss_nan(display_step_id, valid_results[task_key]) - if ( ( (display_step_id) % self.save_freq == 0 diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index 4b6e884a25..957b77a274 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -61,7 +61,7 @@ DataRequirementItem, ) from deepmd.utils.nan_detector import ( - check_loss_nan, + check_total_loss_nan, ) log = logging.getLogger(__name__) @@ -688,11 +688,19 @@ def valid_on_the_fly( cur_batch = self.cur_batch current_lr = run_sess(self.sess, self.learning_rate) - # Check for NaN in loss values before writing to file and saving checkpoint - # Loss values are already on CPU at this point - check_loss_nan(cur_batch, train_results) - if valid_results is not None: - check_loss_nan(cur_batch, valid_results) + # Check for NaN in total loss before writing to file and saving checkpoint + # We check the main loss component that represents total training loss + if train_results: + # Look for the main loss key (typically the first loss component) + main_loss_key = next(iter(train_results.keys())) if train_results else None + if main_loss_key and main_loss_key in train_results: + check_total_loss_nan(cur_batch, train_results[main_loss_key]) + + if valid_results: + # Check validation loss as well for consistency + main_loss_key = next(iter(valid_results.keys())) if valid_results else None + if main_loss_key and main_loss_key in valid_results: + check_total_loss_nan(cur_batch, valid_results[main_loss_key]) if print_header: self.print_header(fp, train_results, valid_results) diff --git a/deepmd/utils/nan_detector.py b/deepmd/utils/nan_detector.py index 6a2641e327..7c5095322f 100644 --- a/deepmd/utils/nan_detector.py +++ b/deepmd/utils/nan_detector.py @@ -3,117 +3,52 @@ import logging import math -from typing import ( - Any, -) - -import numpy as np log = logging.getLogger(__name__) -class LossNaNError(Exception): - """Exception raised when NaN is detected in loss during training.""" +class LossNaNError(RuntimeError): + """Exception raised when NaN is detected in total loss during training.""" - def __init__(self, step: int, loss_dict: dict[str, Any]) -> None: + def __init__(self, step: int, total_loss: float) -> None: """Initialize the exception. Parameters ---------- step : int The training step where NaN was detected - loss_dict : dict[str, Any] - Dictionary containing the loss values where NaN was found + total_loss : float + The total loss value that contains NaN """ self.step = step - self.loss_dict = loss_dict - super().__init__(self._format_message()) - - def _format_message(self) -> str: - """Format the error message.""" - nan_losses = [] - for key, value in self.loss_dict.items(): - if self._is_nan(value): - nan_losses.append(f"{key}={value}") - + self.total_loss = total_loss message = ( - f"NaN detected in loss at training step {self.step}. " + f"NaN detected in total loss at training step {step}: {total_loss}. " f"Training stopped to prevent wasting time with corrupted parameters. " - f"NaN values found in: {', '.join(nan_losses)}. " f"This typically indicates unstable training conditions such as " f"learning rate too high, poor data quality, or numerical instability." ) - return message - - @staticmethod - def _is_nan(value: Any) -> bool: - """Check if a value is NaN.""" - if value is None: - return False - try: - # Handle various tensor types and Python scalars - if hasattr(value, "item"): - # PyTorch/TensorFlow/PaddlePaddle tensor - return math.isnan(value.item()) - elif isinstance(value, (int, float)): - # Python scalar - return math.isnan(value) - elif isinstance(value, np.ndarray): - # NumPy array - return np.isnan(value).any() - else: - # Try to convert to float and check - return math.isnan(float(value)) - except (TypeError, ValueError): - # If we can't convert to float, assume it's not NaN - return False - - -def check_loss_nan(step: int, loss_dict: dict[str, Any]) -> None: - """Check if any loss values contain NaN and raise an exception if found. - - This function is designed to be called during training after loss values - are computed and available on CPU, typically during the logging/display phase. - - Parameters - ---------- - step : int - Current training step - loss_dict : dict[str, Any] - Dictionary containing loss values to check for NaN - - Raises - ------ - LossNaNError - If any loss value contains NaN - """ - nan_found = False - for key, value in loss_dict.items(): - if LossNaNError._is_nan(value): - nan_found = True - log.error(f"NaN detected in {key} at step {step}: {value}") + super().__init__(message) - if nan_found: - raise LossNaNError(step, loss_dict) +def check_total_loss_nan(step: int, total_loss: float) -> None: + """Check if the total loss contains NaN and raise an exception if found. -def check_single_loss_nan(step: int, loss_name: str, loss_value: Any) -> None: - """Check if a single loss value contains NaN and raise an exception if found. + This function is designed to be called during training after the total loss + is computed and converted to a CPU float value. Parameters ---------- step : int Current training step - loss_name : str - Name/identifier of the loss - loss_value : Any - Loss value to check for NaN + total_loss : float + Total loss value to check for NaN Raises ------ LossNaNError - If the loss value contains NaN + If the total loss contains NaN """ - if LossNaNError._is_nan(loss_value): - log.error(f"NaN detected in {loss_name} at step {step}: {loss_value}") - raise LossNaNError(step, {loss_name: loss_value}) + if math.isnan(total_loss): + log.error(f"NaN detected in total loss at step {step}: {total_loss}") + raise LossNaNError(step, total_loss) diff --git a/source/3rdparty/implib/implib-gen.py b/source/3rdparty/implib/implib-gen.py index 3a51be271d..86cfa77378 100755 --- a/source/3rdparty/implib/implib-gen.py +++ b/source/3rdparty/implib/implib-gen.py @@ -22,654 +22,577 @@ me = os.path.basename(__file__) root = os.path.dirname(__file__) - def warn(msg): - """Emits a nicely-decorated warning.""" - sys.stderr.write(f"{me}: warning: {msg}\n") - + """Emits a nicely-decorated warning.""" + sys.stderr.write(f'{me}: warning: {msg}\n') def error(msg): - """Emits a nicely-decorated error and exits.""" - sys.stderr.write(f"{me}: error: {msg}\n") - sys.exit(1) - - -def run(args, stdin=""): - """Runs external program and aborts on error.""" - env = os.environ.copy() - # Force English language - env["LC_ALL"] = "c" - try: - del env["LANG"] - except KeyError: - pass - with subprocess.Popen( - args, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - ) as p: - out, err = p.communicate(input=stdin.encode("utf-8")) - out = out.decode("utf-8") - err = err.decode("utf-8") - if p.returncode != 0 or err: - error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") - return out, err - + """Emits a nicely-decorated error and exits.""" + sys.stderr.write(f'{me}: error: {msg}\n') + sys.exit(1) + +def run(args, stdin=''): + """Runs external program and aborts on error.""" + env = os.environ.copy() + # Force English language + env['LC_ALL'] = 'c' + try: + del env["LANG"] + except KeyError: + pass + with subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, env=env) as p: + out, err = p.communicate(input=stdin.encode('utf-8')) + out = out.decode('utf-8') + err = err.decode('utf-8') + if p.returncode != 0 or err: + error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") + return out, err def make_toc(words, renames=None): - "Make an mapping of words to their indices in list" - renames = renames or {} - toc = {} - for i, n in enumerate(words): - name = renames.get(n, n) - toc[i] = name - return toc - + "Make an mapping of words to their indices in list" + renames = renames or {} + toc = {} + for i, n in enumerate(words): + name = renames.get(n, n) + toc[i] = name + return toc def parse_row(words, toc, hex_keys): - "Make a mapping from column names to values" - vals = {k: (words[i] if i < len(words) else "") for i, k in toc.items()} - for k in hex_keys: - if vals[k]: - vals[k] = int(vals[k], 16) - return vals - + "Make a mapping from column names to values" + vals = {k: (words[i] if i < len(words) else '') for i, k in toc.items()} + for k in hex_keys: + if vals[k]: + vals[k] = int(vals[k], 16) + return vals def collect_syms(f): - """Collect ELF dynamic symtab.""" - - # --dyn-syms does not always work for some reason so dump all symtabs - out, _ = run(["readelf", "-sW", f]) - - toc = None - syms = [] - syms_set = set() - for line in out.splitlines(): - line = line.strip() - if not line: - # Next symtab - toc = None - continue - words = re.split(r" +", line) - if line.startswith("Num"): # Header? - if toc is not None: - error("multiple headers in output of readelf") - # Colons are different across readelf versions so get rid of them. - toc = make_toc(map(lambda n: n.replace(":", ""), words)) - elif toc is not None: - sym = parse_row(words, toc, ["Value"]) - name = sym["Name"] - if not name: - continue - if name in syms_set: - continue - syms_set.add(name) - sym["Size"] = int(sym["Size"], 0) # Readelf is inconistent on Size format - if "@" in name: - sym["Default"] = "@@" in name - name, ver = re.split(r"@+", name) - sym["Name"] = name - sym["Version"] = ver - else: - sym["Default"] = True - sym["Version"] = None - syms.append(sym) - - if toc is None: - error(f"failed to analyze symbols in {f}") - - # Also collected demangled names - if syms: - out, _ = run(["c++filt"], "\n".join((sym["Name"] for sym in syms))) - out = out.rstrip("\n") # Some c++filts append newlines at the end - for i, name in enumerate(out.split("\n")): - syms[i]["Demangled Name"] = name - - return syms - + """Collect ELF dynamic symtab.""" + + # --dyn-syms does not always work for some reason so dump all symtabs + out, _ = run(['readelf', '-sW', f]) + + toc = None + syms = [] + syms_set = set() + for line in out.splitlines(): + line = line.strip() + if not line: + # Next symtab + toc = None + continue + words = re.split(r' +', line) + if line.startswith('Num'): # Header? + if toc is not None: + error("multiple headers in output of readelf") + # Colons are different across readelf versions so get rid of them. + toc = make_toc(map(lambda n: n.replace(':', ''), words)) + elif toc is not None: + sym = parse_row(words, toc, ['Value']) + name = sym['Name'] + if not name: + continue + if name in syms_set: + continue + syms_set.add(name) + sym['Size'] = int(sym['Size'], 0) # Readelf is inconistent on Size format + if '@' in name: + sym['Default'] = '@@' in name + name, ver = re.split(r'@+', name) + sym['Name'] = name + sym['Version'] = ver + else: + sym['Default'] = True + sym['Version'] = None + syms.append(sym) + + if toc is None: + error(f"failed to analyze symbols in {f}") + + # Also collected demangled names + if syms: + out, _ = run(['c++filt'], '\n'.join((sym['Name'] for sym in syms))) + out = out.rstrip("\n") # Some c++filts append newlines at the end + for i, name in enumerate(out.split("\n")): + syms[i]['Demangled Name'] = name + + return syms def collect_relocs(f): - """Collect ELF dynamic relocs.""" - - out, _ = run(["readelf", "-rW", f]) - - toc = None - rels = [] - for line in out.splitlines(): - line = line.strip() - if not line: - toc = None - continue - if line == "There are no relocations in this file.": - return [] - if re.match(r"^\s*Type[0-9]:", line): # Spurious lines for MIPS - continue - if re.match(r"^\s*Offset", line): # Header? - if toc is not None: - error("multiple headers in output of readelf") - words = re.split(r"\s\s+", line) # "Symbol's Name + Addend" - toc = make_toc(words) - elif toc is not None: - line = re.sub(r" \+ ", "+", line) - words = re.split(r"\s+", line) - rel = parse_row(words, toc, ["Offset", "Info"]) - rels.append(rel) - # Split symbolic representation - sym_name = "Symbol's Name + Addend" - if sym_name not in rel and "Symbol's Name" in rel: - # Adapt to different versions of readelf - rel[sym_name] = rel["Symbol's Name"] + "+0" - if rel[sym_name]: - p = rel[sym_name].split("+") - if len(p) == 1: - p = ["", p[0]] - rel[sym_name] = (p[0], int(p[1], 16)) - - if toc is None: - error(f"failed to analyze relocations in {f}") - - return rels - + """Collect ELF dynamic relocs.""" + + out, _ = run(['readelf', '-rW', f]) + + toc = None + rels = [] + for line in out.splitlines(): + line = line.strip() + if not line: + toc = None + continue + if line == 'There are no relocations in this file.': + return [] + if re.match(r'^\s*Type[0-9]:', line): # Spurious lines for MIPS + continue + if re.match(r'^\s*Offset', line): # Header? + if toc is not None: + error("multiple headers in output of readelf") + words = re.split(r'\s\s+', line) # "Symbol's Name + Addend" + toc = make_toc(words) + elif toc is not None: + line = re.sub(r' \+ ', '+', line) + words = re.split(r'\s+', line) + rel = parse_row(words, toc, ['Offset', 'Info']) + rels.append(rel) + # Split symbolic representation + sym_name = 'Symbol\'s Name + Addend' + if sym_name not in rel and 'Symbol\'s Name' in rel: + # Adapt to different versions of readelf + rel[sym_name] = rel['Symbol\'s Name'] + '+0' + if rel[sym_name]: + p = rel[sym_name].split('+') + if len(p) == 1: + p = ['', p[0]] + rel[sym_name] = (p[0], int(p[1], 16)) + + if toc is None: + error(f"failed to analyze relocations in {f}") + + return rels def collect_sections(f): - """Collect section info from ELF.""" - - out, _ = run(["readelf", "-SW", f]) - - toc = None - sections = [] - for line in out.splitlines(): - line = line.strip() - if not line: - continue - line = re.sub(r"\[\s+", "[", line) - words = re.split(r" +", line) - if line.startswith("[Nr]"): # Header? - if toc is not None: - error("multiple headers in output of readelf") - toc = make_toc(words, {"Addr": "Address"}) - elif line.startswith("[") and toc is not None: - sec = parse_row(words, toc, ["Address", "Off", "Size"]) - if "A" in sec["Flg"]: # Allocatable section? - sections.append(sec) - - if toc is None: - error(f"failed to analyze sections in {f}") - - return sections - + """Collect section info from ELF.""" + + out, _ = run(['readelf', '-SW', f]) + + toc = None + sections = [] + for line in out.splitlines(): + line = line.strip() + if not line: + continue + line = re.sub(r'\[\s+', '[', line) + words = re.split(r' +', line) + if line.startswith('[Nr]'): # Header? + if toc is not None: + error("multiple headers in output of readelf") + toc = make_toc(words, {'Addr' : 'Address'}) + elif line.startswith('[') and toc is not None: + sec = parse_row(words, toc, ['Address', 'Off', 'Size']) + if 'A' in sec['Flg']: # Allocatable section? + sections.append(sec) + + if toc is None: + error(f"failed to analyze sections in {f}") + + return sections def read_unrelocated_data(input_name, syms, secs): - """Collect unrelocated data from ELF.""" - data = {} - with open(input_name, "rb") as f: - - def is_symbol_in_section(sym, sec): - sec_end = sec["Address"] + sec["Size"] - is_start_in_section = sec["Address"] <= sym["Value"] < sec_end - is_end_in_section = sym["Value"] + sym["Size"] <= sec_end - return is_start_in_section and is_end_in_section - - for name, s in sorted(syms.items(), key=lambda s: s[1]["Value"]): - # TODO: binary search (bisect) - sec = [sec for sec in secs if is_symbol_in_section(s, sec)] - if len(sec) != 1: - error( - f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})" - ) - sec = sec[0] - f.seek(sec["Off"]) - data[name] = f.read(s["Size"]) - return data - + """Collect unrelocated data from ELF.""" + data = {} + with open(input_name, 'rb') as f: + def is_symbol_in_section(sym, sec): + sec_end = sec['Address'] + sec['Size'] + is_start_in_section = sec['Address'] <= sym['Value'] < sec_end + is_end_in_section = sym['Value'] + sym['Size'] <= sec_end + return is_start_in_section and is_end_in_section + for name, s in sorted(syms.items(), key=lambda s: s[1]['Value']): + # TODO: binary search (bisect) + sec = [sec for sec in secs if is_symbol_in_section(s, sec)] + if len(sec) != 1: + error(f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})") + sec = sec[0] + f.seek(sec['Off']) + data[name] = f.read(s['Size']) + return data def collect_relocated_data(syms, bites, rels, ptr_size, reloc_types): - """Identify relocations for each symbol""" - data = {} - for name, s in sorted(syms.items()): - b = bites.get(name) - assert b is not None - if s["Demangled Name"].startswith("typeinfo name"): - data[name] = [("byte", int(x)) for x in b] - continue - data[name] = [] - for i in range(0, len(b), ptr_size): - val = int.from_bytes( - b[i * ptr_size : (i + 1) * ptr_size], byteorder="little" - ) - data[name].append(("offset", val)) - start = s["Value"] - finish = start + s["Size"] - # TODO: binary search (bisect) - for rel in rels: - if rel["Type"] in reloc_types and start <= rel["Offset"] < finish: - i = (rel["Offset"] - start) // ptr_size - assert i < len(data[name]) - data[name][i] = "reloc", rel - return data - + """Identify relocations for each symbol""" + data = {} + for name, s in sorted(syms.items()): + b = bites.get(name) + assert b is not None + if s['Demangled Name'].startswith('typeinfo name'): + data[name] = [('byte', int(x)) for x in b] + continue + data[name] = [] + for i in range(0, len(b), ptr_size): + val = int.from_bytes(b[i*ptr_size:(i + 1)*ptr_size], byteorder='little') + data[name].append(('offset', val)) + start = s['Value'] + finish = start + s['Size'] + # TODO: binary search (bisect) + for rel in rels: + if rel['Type'] in reloc_types and start <= rel['Offset'] < finish: + i = (rel['Offset'] - start) // ptr_size + assert i < len(data[name]) + data[name][i] = 'reloc', rel + return data def generate_vtables(cls_tables, cls_syms, cls_data): - """Generate code for vtables""" - c_types = {"reloc": "const void *", "byte": "unsigned char", "offset": "size_t"} - - ss = [] - ss.append("""\ + """Generate code for vtables""" + c_types = { + 'reloc' : 'const void *', + 'byte' : 'unsigned char', + 'offset' : 'size_t' + } + + ss = [] + ss.append('''\ #ifdef __cplusplus extern "C" { #endif -""") +''') - # Print externs + # Print externs - printed = set() - for name, data in sorted(cls_data.items()): - for typ, val in data: - if typ != "reloc": - continue - sym_name, addend = val["Symbol's Name + Addend"] - sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? - if sym_name not in cls_syms and sym_name not in printed: - ss.append(f"""\ + printed = set() + for name, data in sorted(cls_data.items()): + for typ, val in data: + if typ != 'reloc': + continue + sym_name, addend = val['Symbol\'s Name + Addend'] + sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? + if sym_name not in cls_syms and sym_name not in printed: + ss.append(f'''\ extern const char {sym_name}[]; -""") +''') - # Collect variable infos + # Collect variable infos - code_info = {} + code_info = {} - for name, s in sorted(cls_syms.items()): - data = cls_data[name] - if s["Demangled Name"].startswith("typeinfo name"): - declarator = "const unsigned char %s[]" - else: - field_types = ( - f"{c_types[typ]} field_{i};" for i, (typ, _) in enumerate(data) - ) - declarator = "const struct { %s } %%s" % " ".join(field_types) # pylint: disable=C0209 # consider-using-f-string - vals = [] - for typ, val in data: - if typ != "reloc": - vals.append(str(val) + "UL") - else: - sym_name, addend = val["Symbol's Name + Addend"] - sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? - vals.append(f"(const char *)&{sym_name} + {addend}") - code_info[name] = (declarator, "{ %s }" % ", ".join(vals)) # pylint: disable= C0209 # consider-using-f-string - - # Print declarations - - for name, (decl, _) in sorted(code_info.items()): - type_name = name + "_type" - type_decl = decl % type_name - ss.append(f"""\ + for name, s in sorted(cls_syms.items()): + data = cls_data[name] + if s['Demangled Name'].startswith('typeinfo name'): + declarator = 'const unsigned char %s[]' + else: + field_types = (f'{c_types[typ]} field_{i};' for i, (typ, _) in enumerate(data)) + declarator = 'const struct { %s } %%s' % ' '.join(field_types) # pylint: disable=C0209 # consider-using-f-string + vals = [] + for typ, val in data: + if typ != 'reloc': + vals.append(str(val) + 'UL') + else: + sym_name, addend = val['Symbol\'s Name + Addend'] + sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? + vals.append(f'(const char *)&{sym_name} + {addend}') + code_info[name] = (declarator, '{ %s }' % ', '.join(vals)) # pylint: disable= C0209 # consider-using-f-string + + # Print declarations + + for name, (decl, _) in sorted(code_info.items()): + type_name = name + '_type' + type_decl = decl % type_name + ss.append(f'''\ typedef {type_decl}; extern __attribute__((weak)) {type_name} {name}; -""") +''') - # Print definitions + # Print definitions - for name, (_, init) in sorted(code_info.items()): - type_name = name + "_type" - ss.append(f"""\ + for name, (_, init) in sorted(code_info.items()): + type_name = name + '_type' + ss.append(f'''\ const {type_name} {name} = {init}; -""") +''') - ss.append("""\ + ss.append('''\ #ifdef __cplusplus } // extern "C" #endif -""") - - return "".join(ss) +''') + return ''.join(ss) def read_soname(f): - """Read ELF's SONAME.""" - - out, _ = run(["readelf", "-d", f]) + """Read ELF's SONAME.""" - for line in out.splitlines(): - line = line.strip() - if not line: - continue - # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] - soname_match = re.search(r"\(SONAME\).*\[(.+)\]", line) - if soname_match is not None: - return soname_match[1] + out, _ = run(['readelf', '-d', f]) - return None + for line in out.splitlines(): + line = line.strip() + if not line: + continue + # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] + soname_match = re.search(r'\(SONAME\).*\[(.+)\]', line) + if soname_match is not None: + return soname_match[1] + return None def main(): - """Driver function""" - parser = argparse.ArgumentParser( - description="Generate wrappers for shared library functions.", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=f"""\ + """Driver function""" + parser = argparse.ArgumentParser(description="Generate wrappers for shared library functions.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f"""\ Examples: $ python3 {me} /usr/lib/x86_64-linux-gnu/libaccountsservice.so.0 Generating libaccountsservice.so.0.tramp.S... Generating libaccountsservice.so.0.init.c... -""", - ) - - parser.add_argument("library", metavar="LIB", help="Library to be wrapped.") - parser.add_argument( - "--verbose", "-v", help="Print diagnostic info", action="count", default=0 - ) - parser.add_argument( - "--dlopen", - help="Emit dlopen call (default)", - dest="dlopen", - action="store_true", - default=True, - ) - parser.add_argument( - "--no-dlopen", - help="Do not emit dlopen call (user must load/unload library himself)", - dest="dlopen", - action="store_false", - ) - parser.add_argument( - "--dlopen-callback", - help="Call user-provided custom callback to load library instead of dlopen", - default="", - ) - parser.add_argument( - "--dlsym-callback", - help="Call user-provided custom callback to resolve a symbol, instead of dlsym", - default="", - ) - parser.add_argument( - "--library-load-name", - help="Use custom name for dlopened library (default is SONAME)", - ) - parser.add_argument( - "--lazy-load", - help="Load library on first call to any of it's functions (default)", - dest="lazy_load", - action="store_true", - default=True, - ) - parser.add_argument( - "--no-lazy-load", - help="Load library at program start", - dest="lazy_load", - action="store_false", - ) - parser.add_argument( - "--vtables", - help="Intercept virtual tables (EXPERIMENTAL)", - dest="vtables", - action="store_true", - default=False, - ) - parser.add_argument( - "--no-vtables", - help="Do not intercept virtual tables (default)", - dest="vtables", - action="store_false", - ) - parser.add_argument( - "--no-weak-symbols", - help="Don't bind weak symbols", - dest="no_weak_symbols", - action="store_true", - default=False, - ) - parser.add_argument( - "--target", - help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " - "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " - "mips/mipsel, mips64/mip64el and e2k are supported)", - default=os.uname()[-1], - ) - parser.add_argument( - "--symbol-list", - help="Path to file with symbols that should be present in wrapper " - "(all by default)", - ) - parser.add_argument( - "--symbol-prefix", - metavar="PFX", - help="Prefix wrapper symbols with PFX", - default="", - ) - parser.add_argument( - "-q", "--quiet", help="Do not print progress info", action="store_true" - ) - parser.add_argument( - "--outdir", "-o", help="Path to create wrapper at", default="./" - ) - - args = parser.parse_args() - - input_name = args.library - verbose = args.verbose - dlopen_callback = args.dlopen_callback - dlsym_callback = args.dlsym_callback - dlopen = args.dlopen - lazy_load = args.lazy_load - if args.target.startswith("arm"): - target = "arm" # Handle armhf-..., armel-... - elif re.match(r"^i[0-9]86", args.target): - target = "i386" - elif args.target.startswith("mips64"): - target = "mips64" # Handle mips64-..., mips64el-..., mips64le-... - elif args.target.startswith("mips"): - target = "mips" # Handle mips-..., mipsel-..., mipsle-... - else: - target = args.target.split("-")[0] - quiet = args.quiet - outdir = args.outdir +""") - if args.symbol_list is None: - funs = None - else: - with open(args.symbol_list, "r") as f: - funs = [] - for line in re.split(r"\r?\n", f.read()): - line = re.sub(r"#.*", "", line) - line = line.strip() - if line: - funs.append(line) - - if args.library_load_name is not None: - load_name = args.library_load_name - else: - load_name = read_soname(input_name) - if load_name is None: - load_name = os.path.basename(input_name) + parser.add_argument('library', + metavar='LIB', + help="Library to be wrapped.") + parser.add_argument('--verbose', '-v', + help="Print diagnostic info", + action='count', + default=0) + parser.add_argument('--dlopen', + help="Emit dlopen call (default)", + dest='dlopen', action='store_true', default=True) + parser.add_argument('--no-dlopen', + help="Do not emit dlopen call (user must load/unload library himself)", + dest='dlopen', action='store_false') + parser.add_argument('--dlopen-callback', + help="Call user-provided custom callback to load library instead of dlopen", + default='') + parser.add_argument('--dlsym-callback', + help="Call user-provided custom callback to resolve a symbol, " + "instead of dlsym", + default='') + parser.add_argument('--library-load-name', + help="Use custom name for dlopened library (default is SONAME)") + parser.add_argument('--lazy-load', + help="Load library on first call to any of it's functions (default)", + dest='lazy_load', action='store_true', default=True) + parser.add_argument('--no-lazy-load', + help="Load library at program start", + dest='lazy_load', action='store_false') + parser.add_argument('--vtables', + help="Intercept virtual tables (EXPERIMENTAL)", + dest='vtables', action='store_true', default=False) + parser.add_argument('--no-vtables', + help="Do not intercept virtual tables (default)", + dest='vtables', action='store_false') + parser.add_argument('--no-weak-symbols', + help="Don't bind weak symbols", dest='no_weak_symbols', + action='store_true', default=False) + parser.add_argument('--target', + help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " + "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " + "mips/mipsel, mips64/mip64el and e2k are supported)", + default=os.uname()[-1]) + parser.add_argument('--symbol-list', + help="Path to file with symbols that should be present in wrapper " + "(all by default)") + parser.add_argument('--symbol-prefix', + metavar='PFX', + help="Prefix wrapper symbols with PFX", + default='') + parser.add_argument('-q', '--quiet', + help="Do not print progress info", + action='store_true') + parser.add_argument('--outdir', '-o', + help="Path to create wrapper at", + default='./') + + args = parser.parse_args() + + input_name = args.library + verbose = args.verbose + dlopen_callback = args.dlopen_callback + dlsym_callback = args.dlsym_callback + dlopen = args.dlopen + lazy_load = args.lazy_load + if args.target.startswith('arm'): + target = 'arm' # Handle armhf-..., armel-... + elif re.match(r'^i[0-9]86', args.target): + target = 'i386' + elif args.target.startswith('mips64'): + target = 'mips64' # Handle mips64-..., mips64el-..., mips64le-... + elif args.target.startswith('mips'): + target = 'mips' # Handle mips-..., mipsel-..., mipsle-... + else: + target = args.target.split('-')[0] + quiet = args.quiet + outdir = args.outdir + + if args.symbol_list is None: + funs = None + else: + with open(args.symbol_list, 'r') as f: + funs = [] + for line in re.split(r'\r?\n', f.read()): + line = re.sub(r'#.*', '', line) + line = line.strip() + if line: + funs.append(line) + + if args.library_load_name is not None: + load_name = args.library_load_name + else: + load_name = read_soname(input_name) + if load_name is None: + load_name = os.path.basename(input_name) - # Collect target info + # Collect target info - target_dir = os.path.join(root, "arch", target) + target_dir = os.path.join(root, 'arch', target) - if not os.path.exists(target_dir): - error(f"unknown architecture '{target}'") + if not os.path.exists(target_dir): + error(f"unknown architecture '{target}'") - cfg = configparser.ConfigParser(inline_comment_prefixes=";") - cfg.read(target_dir + "/config.ini") + cfg = configparser.ConfigParser(inline_comment_prefixes=';') + cfg.read(target_dir + '/config.ini') - ptr_size = int(cfg["Arch"]["PointerSize"]) - symbol_reloc_types = set(re.split(r"\s*,\s*", cfg["Arch"]["SymbolReloc"])) + ptr_size = int(cfg['Arch']['PointerSize']) + symbol_reloc_types = set(re.split(r'\s*,\s*', cfg['Arch']['SymbolReloc'])) - def is_exported(s): - conditions = [ - s["Bind"] != "LOCAL", - s["Type"] != "NOTYPE", - s["Ndx"] != "UND", - s["Name"] not in ["", "_init", "_fini"], - ] - if args.no_weak_symbols: - conditions.append(s["Bind"] != "WEAK") - return all(conditions) + def is_exported(s): + conditions = [ + s['Bind'] != 'LOCAL', + s['Type'] != 'NOTYPE', + s['Ndx'] != 'UND', + s['Name'] not in ['', '_init', '_fini']] + if args.no_weak_symbols: + conditions.append(s['Bind'] != 'WEAK') + return all(conditions) - syms = list(filter(is_exported, collect_syms(input_name))) + syms = list(filter(is_exported, collect_syms(input_name))) - def is_data_symbol(s): - return ( - s["Type"] == "OBJECT" + def is_data_symbol(s): + return (s['Type'] == 'OBJECT' # Allow vtables if --vtables is on - and not (" for " in s["Demangled Name"] and args.vtables) - ) - - exported_data = [s["Name"] for s in syms if is_data_symbol(s)] - if exported_data: - # TODO: we can generate wrappers for const data without relocations (or only code relocations) - warn( - f"library '{input_name}' contains data symbols which won't be intercepted: " - + ", ".join(exported_data) - ) - - # Collect functions - # TODO: warn if user-specified functions are missing - - orig_funs = filter(lambda s: s["Type"] == "FUNC", syms) - - all_funs = set() - warn_versioned = False - for s in orig_funs: - if not s["Default"]: - # TODO: support versions - if not warn_versioned: - warn(f"library {input_name} contains versioned symbols which are NYI") - warn_versioned = True - if verbose: - print(f"Skipping versioned symbol {s['Name']}") - continue - all_funs.add(s["Name"]) - - if funs is None: - funs = sorted(list(all_funs)) - if not funs and not quiet: - warn(f"no public functions were found in {input_name}") - else: - missing_funs = [name for name in funs if name not in all_funs] - if missing_funs: - warn( - "some user-specified functions are not present in library: " - + ", ".join(missing_funs) - ) - funs = [name for name in funs if name in all_funs] + and not (' for ' in s['Demangled Name'] and args.vtables)) + + exported_data = [s['Name'] for s in syms if is_data_symbol(s)] + if exported_data: + # TODO: we can generate wrappers for const data without relocations (or only code relocations) + warn(f"library '{input_name}' contains data symbols which won't be intercepted: " + + ', '.join(exported_data)) + + # Collect functions + # TODO: warn if user-specified functions are missing + + orig_funs = filter(lambda s: s['Type'] == 'FUNC', syms) + + all_funs = set() + warn_versioned = False + for s in orig_funs: + if not s['Default']: + # TODO: support versions + if not warn_versioned: + warn(f"library {input_name} contains versioned symbols which are NYI") + warn_versioned = True + if verbose: + print(f"Skipping versioned symbol {s['Name']}") + continue + all_funs.add(s['Name']) + + if funs is None: + funs = sorted(list(all_funs)) + if not funs and not quiet: + warn(f"no public functions were found in {input_name}") + else: + missing_funs = [name for name in funs if name not in all_funs] + if missing_funs: + warn("some user-specified functions are not present in library: " + ', '.join(missing_funs)) + funs = [name for name in funs if name in all_funs] + + if verbose: + print("Exported functions:") + for i, fun in enumerate(funs): + print(f" {i}: {fun}") + + # Collect vtables + + if args.vtables: + cls_tables = {} + cls_syms = {} + + for s in syms: + m = re.match(r'^(vtable|typeinfo|typeinfo name) for (.*)', s['Demangled Name']) + if m is not None and is_exported(s): + typ, cls = m.groups() + name = s['Name'] + cls_tables.setdefault(cls, {})[typ] = name + cls_syms[name] = s + + if verbose: + print("Exported classes:") + for cls, _ in sorted(cls_tables.items()): + print(f" {cls}") + secs = collect_sections(input_name) if verbose: - print("Exported functions:") - for i, fun in enumerate(funs): - print(f" {i}: {fun}") + print("Sections:") + for sec in secs: + print(f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " + f"at {sec['Off']:x}") - # Collect vtables + bites = read_unrelocated_data(input_name, cls_syms, secs) + rels = collect_relocs(input_name) + if verbose: + print("Relocs:") + for rel in rels: + sym_add = rel['Symbol\'s Name + Addend'] + print(f" {rel['Offset']}: {sym_add}") + + cls_data = collect_relocated_data(cls_syms, bites, rels, ptr_size, symbol_reloc_types) + if verbose: + print("Class data:") + for name, data in sorted(cls_data.items()): + demangled_name = cls_syms[name]['Demangled Name'] + print(f" {name} ({demangled_name}):") + for typ, val in data: + print(" " + str(val if typ != 'reloc' else val['Symbol\'s Name + Addend'])) + + # Generate assembly code + + suffix = os.path.basename(input_name) + lib_suffix = re.sub(r'[^a-zA-Z_0-9]+', '_', suffix) + + tramp_file = f'{suffix}.tramp.S' + with open(os.path.join(outdir, tramp_file), 'w') as f: + if not quiet: + print(f"Generating {tramp_file}...") + with open(target_dir + '/table.S.tpl', 'r') as t: + table_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, + table_size=ptr_size*(len(funs) + 1)) + f.write(table_text) + + with open(target_dir + '/trampoline.S.tpl', 'r') as t: + tramp_tpl = string.Template(t.read()) + + for i, name in enumerate(funs): + tramp_text = tramp_tpl.substitute( + lib_suffix=lib_suffix, + sym=args.symbol_prefix + name, + offset=i*ptr_size, + number=i) + f.write(tramp_text) + + # Generate C code + + init_file = f'{suffix}.init.c' + with open(os.path.join(outdir, init_file), 'w') as f: + if not quiet: + print(f"Generating {init_file}...") + with open(os.path.join(root, 'arch/common/init.c.tpl'), 'r') as t: + if funs: + sym_names = ',\n '.join(f'"{name}"' for name in funs) + ',' + else: + sym_names = '' + init_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, + load_name=load_name, + dlopen_callback=dlopen_callback, + dlsym_callback=dlsym_callback, + has_dlopen_callback=int(bool(dlopen_callback)), + has_dlsym_callback=int(bool(dlsym_callback)), + no_dlopen=int(not dlopen), + lazy_load=int(lazy_load), + sym_names=sym_names) + f.write(init_text) if args.vtables: - cls_tables = {} - cls_syms = {} - - for s in syms: - m = re.match( - r"^(vtable|typeinfo|typeinfo name) for (.*)", s["Demangled Name"] - ) - if m is not None and is_exported(s): - typ, cls = m.groups() - name = s["Name"] - cls_tables.setdefault(cls, {})[typ] = name - cls_syms[name] = s - - if verbose: - print("Exported classes:") - for cls, _ in sorted(cls_tables.items()): - print(f" {cls}") - - secs = collect_sections(input_name) - if verbose: - print("Sections:") - for sec in secs: - print( - f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " - f"at {sec['Off']:x}" - ) - - bites = read_unrelocated_data(input_name, cls_syms, secs) - - rels = collect_relocs(input_name) - if verbose: - print("Relocs:") - for rel in rels: - sym_add = rel["Symbol's Name + Addend"] - print(f" {rel['Offset']}: {sym_add}") - - cls_data = collect_relocated_data( - cls_syms, bites, rels, ptr_size, symbol_reloc_types - ) - if verbose: - print("Class data:") - for name, data in sorted(cls_data.items()): - demangled_name = cls_syms[name]["Demangled Name"] - print(f" {name} ({demangled_name}):") - for typ, val in data: - print( - " " - + str(val if typ != "reloc" else val["Symbol's Name + Addend"]) - ) - - # Generate assembly code - - suffix = os.path.basename(input_name) - lib_suffix = re.sub(r"[^a-zA-Z_0-9]+", "_", suffix) - - tramp_file = f"{suffix}.tramp.S" - with open(os.path.join(outdir, tramp_file), "w") as f: - if not quiet: - print(f"Generating {tramp_file}...") - with open(target_dir + "/table.S.tpl", "r") as t: - table_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, table_size=ptr_size * (len(funs) + 1) - ) - f.write(table_text) - - with open(target_dir + "/trampoline.S.tpl", "r") as t: - tramp_tpl = string.Template(t.read()) - - for i, name in enumerate(funs): - tramp_text = tramp_tpl.substitute( - lib_suffix=lib_suffix, - sym=args.symbol_prefix + name, - offset=i * ptr_size, - number=i, - ) - f.write(tramp_text) - - # Generate C code - - init_file = f"{suffix}.init.c" - with open(os.path.join(outdir, init_file), "w") as f: - if not quiet: - print(f"Generating {init_file}...") - with open(os.path.join(root, "arch/common/init.c.tpl"), "r") as t: - if funs: - sym_names = ",\n ".join(f'"{name}"' for name in funs) + "," - else: - sym_names = "" - init_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, - load_name=load_name, - dlopen_callback=dlopen_callback, - dlsym_callback=dlsym_callback, - has_dlopen_callback=int(bool(dlopen_callback)), - has_dlsym_callback=int(bool(dlsym_callback)), - no_dlopen=int(not dlopen), - lazy_load=int(lazy_load), - sym_names=sym_names, - ) - f.write(init_text) - if args.vtables: - vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) - f.write(vtable_text) - - -if __name__ == "__main__": - main() + vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) + f.write(vtable_text) + +if __name__ == '__main__': + main() diff --git a/source/tests/common/test_nan_detector.py b/source/tests/common/test_nan_detector.py index 1ce719d944..9d9f0bf2f3 100644 --- a/source/tests/common/test_nan_detector.py +++ b/source/tests/common/test_nan_detector.py @@ -8,8 +8,7 @@ from deepmd.utils.nan_detector import ( LossNaNError, - check_loss_nan, - check_single_loss_nan, + check_total_loss_nan, ) @@ -19,52 +18,25 @@ class TestNaNDetector(unittest.TestCase): def test_normal_values_pass(self): """Test that normal loss values don't trigger NaN detection.""" # Test with various normal values - normal_losses = { - "energy": 0.5, - "force": 1.0, - "virial": 0.001, - "zero": 0.0, - "negative": -0.5, - } + normal_losses = [0.5, 1.0, 0.001, 0.0, -0.5] # Should not raise any exception - try: - check_loss_nan(100, normal_losses) - except Exception as e: - self.fail(f"Normal values should not raise exception: {e}") + for i, loss_val in enumerate(normal_losses): + try: + check_total_loss_nan(100 + i, loss_val) + except Exception as e: + self.fail(f"Normal values should not raise exception: {e}") def test_nan_detection_raises_exception(self): """Test that NaN values trigger the proper exception.""" - # Test with NaN values - nan_losses = { - "energy": 0.5, - "force": float("nan"), - "virial": 1.0, - } - + # Test with NaN value with self.assertRaises(LossNaNError) as context: - check_loss_nan(200, nan_losses) + check_total_loss_nan(200, float("nan")) exception = context.exception self.assertEqual(exception.step, 200) - self.assertIn("force", str(exception)) - self.assertIn("NaN detected in loss at training step 200", str(exception)) - - def test_single_loss_nan_detection(self): - """Test single loss NaN detection.""" - # Normal value should pass - try: - check_single_loss_nan(50, "test_loss", 0.5) - except Exception as e: - self.fail(f"Normal single loss should not raise exception: {e}") - - # NaN value should raise - with self.assertRaises(LossNaNError) as context: - check_single_loss_nan(50, "test_loss", float("nan")) - - exception = context.exception - self.assertEqual(exception.step, 50) - self.assertIn("test_loss", str(exception)) + self.assertTrue(math.isnan(exception.total_loss)) + self.assertIn("NaN detected in total loss at training step 200", str(exception)) def test_various_nan_representations(self): """Test detection of various NaN representations.""" @@ -76,91 +48,55 @@ def test_various_nan_representations(self): for i, nan_val in enumerate(nan_values): with self.assertRaises(LossNaNError): - check_single_loss_nan(i, f"loss_{i}", nan_val) - - def test_tensor_like_objects(self): - """Test that tensor-like objects work with NaN detection.""" - - # Mock tensor-like object with item() method - class MockTensor: - def __init__(self, value): - self._value = value - - def item(self): - return self._value - - # Normal tensor should pass - normal_tensor = MockTensor(0.5) - try: - check_single_loss_nan(10, "tensor_loss", normal_tensor) - except Exception as e: - self.fail(f"Normal tensor should not raise exception: {e}") - - # NaN tensor should raise - nan_tensor = MockTensor(float("nan")) - with self.assertRaises(LossNaNError): - check_single_loss_nan(10, "tensor_loss", nan_tensor) + check_total_loss_nan(i, nan_val) def test_error_message_format(self): """Test that error messages contain useful information.""" - nan_losses = { - "energy": 0.5, - "force": float("nan"), - "virial": float("nan"), - } - with self.assertRaises(LossNaNError) as context: - check_loss_nan(123, nan_losses) + check_total_loss_nan(123, float("nan")) error_msg = str(context.exception) # Check key information is in the message self.assertIn("step 123", error_msg) - self.assertIn("force=nan", error_msg) - self.assertIn("virial=nan", error_msg) self.assertIn("Training stopped", error_msg) self.assertIn("learning rate too high", error_msg) - def test_mixed_loss_dict(self): - """Test loss dictionary with mix of normal and NaN values.""" - mixed_losses = { - "energy": 0.5, - "force": float("nan"), - "virial": 1.0, - "dipole": float("nan"), - } - - with self.assertRaises(LossNaNError) as context: - check_loss_nan(99, mixed_losses) - - exception = context.exception - # Should detect both NaN values - error_msg = str(exception) - self.assertIn("force=nan", error_msg) - self.assertIn("dipole=nan", error_msg) - # Should not mention normal values - self.assertNotIn("energy=0.5", error_msg) - self.assertNotIn("virial=1.0", error_msg) - def test_edge_cases(self): """Test edge cases for NaN detection.""" - # Empty dict should pass + # Infinity should not trigger NaN detection (separate issue) try: - check_loss_nan(1, {}) + check_total_loss_nan(1, float("inf")) + check_total_loss_nan(2, float("-inf")) except Exception as e: - self.fail(f"Empty dict should not raise exception: {e}") + self.fail(f"Infinity should not raise NaN exception: {e}") - # None values should not trigger NaN detection - try: - check_loss_nan(1, {"test": None}) - except Exception as e: - self.fail(f"None values should not raise exception: {e}") + def test_numeric_types(self): + """Test that various numeric types work correctly.""" + # Various numeric types that should pass + test_values = [ + 0.5, # float + 1, # int + np.float32(0.3), # NumPy float32 + np.float64(0.7), # NumPy float64 + ] + + for i, val in enumerate(test_values): + try: + check_total_loss_nan(10 + i, float(val)) + except Exception as e: + self.fail(f"Numeric type {type(val)} should not raise exception: {e}") + + def test_inheritance_from_runtime_error(self): + """Test that LossNaNError inherits from RuntimeError.""" + self.assertTrue(issubclass(LossNaNError, RuntimeError)) - # Infinity should not trigger NaN detection (separate issue) try: - check_loss_nan(1, {"test": float("inf")}) - except Exception as e: - self.fail(f"Infinity should not raise NaN exception: {e}") + check_total_loss_nan(999, float("nan")) + except LossNaNError as e: + self.assertIsInstance(e, RuntimeError) + except Exception: + self.fail("Should raise LossNaNError which inherits from RuntimeError") if __name__ == "__main__": diff --git a/source/tests/common/test_nan_integration.py b/source/tests/common/test_nan_integration.py index fcd4624f66..1b3ae1ff61 100644 --- a/source/tests/common/test_nan_integration.py +++ b/source/tests/common/test_nan_integration.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Integration test to verify NaN detection during training. -This test creates a mock training scenario where loss becomes NaN +This test creates a mock training scenario where total loss becomes NaN and verifies that the training stops with appropriate error message. """ @@ -12,7 +12,7 @@ from deepmd.utils.nan_detector import ( LossNaNError, - check_loss_nan, + check_total_loss_nan, ) @@ -20,83 +20,31 @@ class TestNaNDetectionIntegration(unittest.TestCase): """Integration tests for NaN detection during training.""" def test_training_stops_on_nan_loss(self): - """Test that training stops when NaN is detected in loss values.""" - # Simulate a training scenario where loss becomes NaN - train_results = { - "energy_loss": 0.1, - "force_loss": float("nan"), # This should trigger the detection - "virial_loss": 0.05, - } - - valid_results = { - "energy_loss": 0.12, - "force_loss": 0.08, - "virial_loss": 0.06, - } - - # The NaN in train_results should be detected - with self.assertRaises(LossNaNError) as context: - check_loss_nan(100, train_results) - - exception = context.exception - self.assertEqual(exception.step, 100) - self.assertIn("force_loss=nan", str(exception)) - - # Valid results without NaN should pass - try: - check_loss_nan(100, valid_results) - except Exception as e: - self.fail(f"Valid results should not raise exception: {e}") - - def test_multi_task_nan_detection(self): - """Test NaN detection in multi-task training scenario.""" - # Simulate multi-task training results - multi_task_results = { - "task1": { - "energy_loss": 0.1, - "force_loss": 0.05, - }, - "task2": { - "energy_loss": float("nan"), # NaN in task2 - "force_loss": 0.03, - }, - "task3": { - "energy_loss": 0.08, - "force_loss": 0.04, - }, - } - - # Check each task separately (as done in the actual training code) - # task1 and task3 should pass + """Test that training stops when NaN is detected in total loss.""" + # Normal total loss should pass try: - check_loss_nan(50, multi_task_results["task1"]) - check_loss_nan(50, multi_task_results["task3"]) + check_total_loss_nan(100, 0.1) except Exception as e: - self.fail(f"Normal tasks should not raise exception: {e}") + self.fail(f"Normal total loss should not raise exception: {e}") - # task2 should fail due to NaN + # NaN total loss should raise with self.assertRaises(LossNaNError) as context: - check_loss_nan(50, multi_task_results["task2"]) + check_total_loss_nan(100, float("nan")) exception = context.exception - self.assertEqual(exception.step, 50) - self.assertIn("energy_loss=nan", str(exception)) + self.assertEqual(exception.step, 100) + self.assertIn("NaN detected in total loss", str(exception)) @patch("deepmd.utils.nan_detector.log") def test_logging_on_nan_detection(self, mock_log): """Test that NaN detection logs appropriate error messages.""" - nan_losses = { - "energy": 0.5, - "force": float("nan"), - } - with self.assertRaises(LossNaNError): - check_loss_nan(200, nan_losses) + check_total_loss_nan(200, float("nan")) # Verify that error was logged mock_log.error.assert_called_once() logged_message = mock_log.error.call_args[0][0] - self.assertIn("NaN detected in force at step 200", logged_message) + self.assertIn("NaN detected in total loss at step 200", logged_message) def test_training_simulation_with_checkpoint_prevention(self): """Simulate the training checkpoint scenario to ensure NaN prevents saving.""" @@ -105,27 +53,50 @@ def mock_save_checkpoint(): """Mock function that should not be called when NaN is detected.""" raise AssertionError("Checkpoint should not be saved when NaN is detected!") - # Simulate the training flow: check loss, then save checkpoint + # Simulate the training flow: check total loss, then save checkpoint step_id = 1000 - loss_results = { - "total_loss": float("nan"), - "energy_loss": 0.1, - "force_loss": 0.05, - } + total_loss = float("nan") # This should raise LossNaNError before checkpoint saving with self.assertRaises(LossNaNError): - check_loss_nan(step_id, loss_results) + check_total_loss_nan(step_id, total_loss) # This line should never be reached mock_save_checkpoint() # Verify the error contains expected information try: - check_loss_nan(step_id, loss_results) + check_total_loss_nan(step_id, total_loss) except LossNaNError as e: self.assertIn("Training stopped to prevent wasting time", str(e)) self.assertIn("corrupted parameters", str(e)) + def test_realistic_training_scenario(self): + """Test a more realistic training scenario with decreasing then NaN loss.""" + # Simulate normal training progression + normal_steps = [ + (1, 1.0), # Initial high loss + (10, 0.5), # Loss decreasing + (20, 0.25), # Loss continuing to decrease + (50, 0.1), # Good progress + ] + + # All normal steps should pass + for step, loss_val in normal_steps: + try: + check_total_loss_nan(step, loss_val) + except Exception as e: + self.fail( + f"Normal training step {step} should not raise exception: {e}" + ) + + # But when loss becomes NaN, training should stop + with self.assertRaises(LossNaNError) as context: + check_total_loss_nan(100, float("nan")) + + exception = context.exception + self.assertEqual(exception.step, 100) + self.assertIn("Training stopped", str(exception)) + if __name__ == "__main__": unittest.main() From 0852b7c460cff028e560d3764f2ac034ffe3efa2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 20 Sep 2025 17:27:21 +0000 Subject: [PATCH 4/7] fix(training): optimize NaN detection based on feedback - use lcurve CPU values and fixed loss keys Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/pd/train/training.py | 10 +- deepmd/pt/train/training.py | 14 +- deepmd/tf/train/trainer.py | 14 +- source/3rdparty/implib/implib-gen.py | 1093 ++++++++++++++------------ 4 files changed, 603 insertions(+), 528 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index d6ae30a94a..7720671df8 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -781,8 +781,6 @@ def step(_step_id, task_key="Default") -> None: label=label_dict, task_key=task_key, ) - # Check for NaN in total loss before backward pass to prevent corrupted training - check_total_loss_nan(_step_id + 1, loss.item()) with nvprof_context(enable_profiling, "Backward pass"): loss.backward() @@ -864,6 +862,9 @@ def log_loss_valid(_task_key="Default"): if not self.multi_task: train_results = log_loss_train(loss, more_loss) + # Check for NaN in total loss using CPU values from lcurve computation + if self.rank == 0 and "rmse_e" in train_results: + check_total_loss_nan(display_step_id, train_results["rmse_e"]) valid_results = log_loss_valid() if self.rank == 0: log.info( @@ -905,6 +906,11 @@ def log_loss_valid(_task_key="Default"): loss, more_loss, _task_key=_key ) valid_results[_key] = log_loss_valid(_task_key=_key) + # Check for NaN in total loss using CPU values from lcurve computation + if self.rank == 0 and "rmse_e" in train_results[_key]: + check_total_loss_nan( + display_step_id, train_results[_key]["rmse_e"] + ) if self.rank == 0: log.info( format_training_message_per_task( diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index cc1858b4b7..5713582a94 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -764,8 +764,6 @@ def step(_step_id: int, task_key: str = "Default") -> None: model_pred, loss, more_loss = self.wrapper( **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key ) - # Check for NaN in total loss before backward pass to prevent corrupted training - check_total_loss_nan(_step_id + 1, loss.item()) loss.backward() if self.gradient_max_norm > 0.0: torch.nn.utils.clip_grad_norm_( @@ -817,8 +815,6 @@ def fake_model() -> dict: int(input_dict["atype"].shape[-1]), learning_rate=pref_lr, ) - # Check for NaN in total loss before continuing training - check_total_loss_nan(_step_id + 1, loss.item()) elif isinstance(self.loss, DenoiseLoss): KFOptWrapper = KFOptimizerWrapper( self.wrapper, @@ -845,8 +841,6 @@ def fake_model() -> dict: input_dict["natoms"], learning_rate=pref_lr, ) - # Check for NaN in total loss before continuing training - check_total_loss_nan(_step_id + 1, loss.item()) else: raise ValueError(f"Not supported optimizer type '{self.opt_type}'") @@ -958,6 +952,9 @@ def log_loss_valid(_task_key: str = "Default") -> dict: if not self.multi_task: train_results = log_loss_train(loss, more_loss) + # Check for NaN in total loss using CPU values from lcurve computation + if self.rank == 0 and "rmse_e" in train_results: + check_total_loss_nan(display_step_id, train_results["rmse_e"]) valid_results = log_loss_valid() if self.rank == 0: log.info( @@ -1006,6 +1003,11 @@ def log_loss_valid(_task_key: str = "Default") -> dict: loss, more_loss, _task_key=_key ) valid_results[_key] = log_loss_valid(_task_key=_key) + # Check for NaN in total loss using CPU values from lcurve computation + if self.rank == 0 and "rmse_e" in train_results[_key]: + check_total_loss_nan( + display_step_id, train_results[_key]["rmse_e"] + ) if self.rank == 0: log.info( format_training_message_per_task( diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index 957b77a274..9fa3b4e323 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -689,18 +689,8 @@ def valid_on_the_fly( current_lr = run_sess(self.sess, self.learning_rate) # Check for NaN in total loss before writing to file and saving checkpoint - # We check the main loss component that represents total training loss - if train_results: - # Look for the main loss key (typically the first loss component) - main_loss_key = next(iter(train_results.keys())) if train_results else None - if main_loss_key and main_loss_key in train_results: - check_total_loss_nan(cur_batch, train_results[main_loss_key]) - - if valid_results: - # Check validation loss as well for consistency - main_loss_key = next(iter(valid_results.keys())) if valid_results else None - if main_loss_key and main_loss_key in valid_results: - check_total_loss_nan(cur_batch, valid_results[main_loss_key]) + # We check the main energy loss component that represents total training loss + check_total_loss_nan(cur_batch, train_results["rmse_e"]) if print_header: self.print_header(fp, train_results, valid_results) diff --git a/source/3rdparty/implib/implib-gen.py b/source/3rdparty/implib/implib-gen.py index 86cfa77378..3a51be271d 100755 --- a/source/3rdparty/implib/implib-gen.py +++ b/source/3rdparty/implib/implib-gen.py @@ -22,577 +22,654 @@ me = os.path.basename(__file__) root = os.path.dirname(__file__) + def warn(msg): - """Emits a nicely-decorated warning.""" - sys.stderr.write(f'{me}: warning: {msg}\n') + """Emits a nicely-decorated warning.""" + sys.stderr.write(f"{me}: warning: {msg}\n") + def error(msg): - """Emits a nicely-decorated error and exits.""" - sys.stderr.write(f'{me}: error: {msg}\n') - sys.exit(1) - -def run(args, stdin=''): - """Runs external program and aborts on error.""" - env = os.environ.copy() - # Force English language - env['LC_ALL'] = 'c' - try: - del env["LANG"] - except KeyError: - pass - with subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, env=env) as p: - out, err = p.communicate(input=stdin.encode('utf-8')) - out = out.decode('utf-8') - err = err.decode('utf-8') - if p.returncode != 0 or err: - error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") - return out, err + """Emits a nicely-decorated error and exits.""" + sys.stderr.write(f"{me}: error: {msg}\n") + sys.exit(1) + + +def run(args, stdin=""): + """Runs external program and aborts on error.""" + env = os.environ.copy() + # Force English language + env["LC_ALL"] = "c" + try: + del env["LANG"] + except KeyError: + pass + with subprocess.Popen( + args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) as p: + out, err = p.communicate(input=stdin.encode("utf-8")) + out = out.decode("utf-8") + err = err.decode("utf-8") + if p.returncode != 0 or err: + error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") + return out, err + def make_toc(words, renames=None): - "Make an mapping of words to their indices in list" - renames = renames or {} - toc = {} - for i, n in enumerate(words): - name = renames.get(n, n) - toc[i] = name - return toc + "Make an mapping of words to their indices in list" + renames = renames or {} + toc = {} + for i, n in enumerate(words): + name = renames.get(n, n) + toc[i] = name + return toc + def parse_row(words, toc, hex_keys): - "Make a mapping from column names to values" - vals = {k: (words[i] if i < len(words) else '') for i, k in toc.items()} - for k in hex_keys: - if vals[k]: - vals[k] = int(vals[k], 16) - return vals + "Make a mapping from column names to values" + vals = {k: (words[i] if i < len(words) else "") for i, k in toc.items()} + for k in hex_keys: + if vals[k]: + vals[k] = int(vals[k], 16) + return vals + def collect_syms(f): - """Collect ELF dynamic symtab.""" - - # --dyn-syms does not always work for some reason so dump all symtabs - out, _ = run(['readelf', '-sW', f]) - - toc = None - syms = [] - syms_set = set() - for line in out.splitlines(): - line = line.strip() - if not line: - # Next symtab - toc = None - continue - words = re.split(r' +', line) - if line.startswith('Num'): # Header? - if toc is not None: - error("multiple headers in output of readelf") - # Colons are different across readelf versions so get rid of them. - toc = make_toc(map(lambda n: n.replace(':', ''), words)) - elif toc is not None: - sym = parse_row(words, toc, ['Value']) - name = sym['Name'] - if not name: - continue - if name in syms_set: - continue - syms_set.add(name) - sym['Size'] = int(sym['Size'], 0) # Readelf is inconistent on Size format - if '@' in name: - sym['Default'] = '@@' in name - name, ver = re.split(r'@+', name) - sym['Name'] = name - sym['Version'] = ver - else: - sym['Default'] = True - sym['Version'] = None - syms.append(sym) - - if toc is None: - error(f"failed to analyze symbols in {f}") - - # Also collected demangled names - if syms: - out, _ = run(['c++filt'], '\n'.join((sym['Name'] for sym in syms))) - out = out.rstrip("\n") # Some c++filts append newlines at the end - for i, name in enumerate(out.split("\n")): - syms[i]['Demangled Name'] = name - - return syms + """Collect ELF dynamic symtab.""" + + # --dyn-syms does not always work for some reason so dump all symtabs + out, _ = run(["readelf", "-sW", f]) + + toc = None + syms = [] + syms_set = set() + for line in out.splitlines(): + line = line.strip() + if not line: + # Next symtab + toc = None + continue + words = re.split(r" +", line) + if line.startswith("Num"): # Header? + if toc is not None: + error("multiple headers in output of readelf") + # Colons are different across readelf versions so get rid of them. + toc = make_toc(map(lambda n: n.replace(":", ""), words)) + elif toc is not None: + sym = parse_row(words, toc, ["Value"]) + name = sym["Name"] + if not name: + continue + if name in syms_set: + continue + syms_set.add(name) + sym["Size"] = int(sym["Size"], 0) # Readelf is inconistent on Size format + if "@" in name: + sym["Default"] = "@@" in name + name, ver = re.split(r"@+", name) + sym["Name"] = name + sym["Version"] = ver + else: + sym["Default"] = True + sym["Version"] = None + syms.append(sym) + + if toc is None: + error(f"failed to analyze symbols in {f}") + + # Also collected demangled names + if syms: + out, _ = run(["c++filt"], "\n".join((sym["Name"] for sym in syms))) + out = out.rstrip("\n") # Some c++filts append newlines at the end + for i, name in enumerate(out.split("\n")): + syms[i]["Demangled Name"] = name + + return syms + def collect_relocs(f): - """Collect ELF dynamic relocs.""" - - out, _ = run(['readelf', '-rW', f]) - - toc = None - rels = [] - for line in out.splitlines(): - line = line.strip() - if not line: - toc = None - continue - if line == 'There are no relocations in this file.': - return [] - if re.match(r'^\s*Type[0-9]:', line): # Spurious lines for MIPS - continue - if re.match(r'^\s*Offset', line): # Header? - if toc is not None: - error("multiple headers in output of readelf") - words = re.split(r'\s\s+', line) # "Symbol's Name + Addend" - toc = make_toc(words) - elif toc is not None: - line = re.sub(r' \+ ', '+', line) - words = re.split(r'\s+', line) - rel = parse_row(words, toc, ['Offset', 'Info']) - rels.append(rel) - # Split symbolic representation - sym_name = 'Symbol\'s Name + Addend' - if sym_name not in rel and 'Symbol\'s Name' in rel: - # Adapt to different versions of readelf - rel[sym_name] = rel['Symbol\'s Name'] + '+0' - if rel[sym_name]: - p = rel[sym_name].split('+') - if len(p) == 1: - p = ['', p[0]] - rel[sym_name] = (p[0], int(p[1], 16)) - - if toc is None: - error(f"failed to analyze relocations in {f}") - - return rels + """Collect ELF dynamic relocs.""" + + out, _ = run(["readelf", "-rW", f]) + + toc = None + rels = [] + for line in out.splitlines(): + line = line.strip() + if not line: + toc = None + continue + if line == "There are no relocations in this file.": + return [] + if re.match(r"^\s*Type[0-9]:", line): # Spurious lines for MIPS + continue + if re.match(r"^\s*Offset", line): # Header? + if toc is not None: + error("multiple headers in output of readelf") + words = re.split(r"\s\s+", line) # "Symbol's Name + Addend" + toc = make_toc(words) + elif toc is not None: + line = re.sub(r" \+ ", "+", line) + words = re.split(r"\s+", line) + rel = parse_row(words, toc, ["Offset", "Info"]) + rels.append(rel) + # Split symbolic representation + sym_name = "Symbol's Name + Addend" + if sym_name not in rel and "Symbol's Name" in rel: + # Adapt to different versions of readelf + rel[sym_name] = rel["Symbol's Name"] + "+0" + if rel[sym_name]: + p = rel[sym_name].split("+") + if len(p) == 1: + p = ["", p[0]] + rel[sym_name] = (p[0], int(p[1], 16)) + + if toc is None: + error(f"failed to analyze relocations in {f}") + + return rels + def collect_sections(f): - """Collect section info from ELF.""" - - out, _ = run(['readelf', '-SW', f]) - - toc = None - sections = [] - for line in out.splitlines(): - line = line.strip() - if not line: - continue - line = re.sub(r'\[\s+', '[', line) - words = re.split(r' +', line) - if line.startswith('[Nr]'): # Header? - if toc is not None: - error("multiple headers in output of readelf") - toc = make_toc(words, {'Addr' : 'Address'}) - elif line.startswith('[') and toc is not None: - sec = parse_row(words, toc, ['Address', 'Off', 'Size']) - if 'A' in sec['Flg']: # Allocatable section? - sections.append(sec) - - if toc is None: - error(f"failed to analyze sections in {f}") - - return sections + """Collect section info from ELF.""" + + out, _ = run(["readelf", "-SW", f]) + + toc = None + sections = [] + for line in out.splitlines(): + line = line.strip() + if not line: + continue + line = re.sub(r"\[\s+", "[", line) + words = re.split(r" +", line) + if line.startswith("[Nr]"): # Header? + if toc is not None: + error("multiple headers in output of readelf") + toc = make_toc(words, {"Addr": "Address"}) + elif line.startswith("[") and toc is not None: + sec = parse_row(words, toc, ["Address", "Off", "Size"]) + if "A" in sec["Flg"]: # Allocatable section? + sections.append(sec) + + if toc is None: + error(f"failed to analyze sections in {f}") + + return sections + def read_unrelocated_data(input_name, syms, secs): - """Collect unrelocated data from ELF.""" - data = {} - with open(input_name, 'rb') as f: - def is_symbol_in_section(sym, sec): - sec_end = sec['Address'] + sec['Size'] - is_start_in_section = sec['Address'] <= sym['Value'] < sec_end - is_end_in_section = sym['Value'] + sym['Size'] <= sec_end - return is_start_in_section and is_end_in_section - for name, s in sorted(syms.items(), key=lambda s: s[1]['Value']): - # TODO: binary search (bisect) - sec = [sec for sec in secs if is_symbol_in_section(s, sec)] - if len(sec) != 1: - error(f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})") - sec = sec[0] - f.seek(sec['Off']) - data[name] = f.read(s['Size']) - return data + """Collect unrelocated data from ELF.""" + data = {} + with open(input_name, "rb") as f: + + def is_symbol_in_section(sym, sec): + sec_end = sec["Address"] + sec["Size"] + is_start_in_section = sec["Address"] <= sym["Value"] < sec_end + is_end_in_section = sym["Value"] + sym["Size"] <= sec_end + return is_start_in_section and is_end_in_section + + for name, s in sorted(syms.items(), key=lambda s: s[1]["Value"]): + # TODO: binary search (bisect) + sec = [sec for sec in secs if is_symbol_in_section(s, sec)] + if len(sec) != 1: + error( + f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})" + ) + sec = sec[0] + f.seek(sec["Off"]) + data[name] = f.read(s["Size"]) + return data + def collect_relocated_data(syms, bites, rels, ptr_size, reloc_types): - """Identify relocations for each symbol""" - data = {} - for name, s in sorted(syms.items()): - b = bites.get(name) - assert b is not None - if s['Demangled Name'].startswith('typeinfo name'): - data[name] = [('byte', int(x)) for x in b] - continue - data[name] = [] - for i in range(0, len(b), ptr_size): - val = int.from_bytes(b[i*ptr_size:(i + 1)*ptr_size], byteorder='little') - data[name].append(('offset', val)) - start = s['Value'] - finish = start + s['Size'] - # TODO: binary search (bisect) - for rel in rels: - if rel['Type'] in reloc_types and start <= rel['Offset'] < finish: - i = (rel['Offset'] - start) // ptr_size - assert i < len(data[name]) - data[name][i] = 'reloc', rel - return data + """Identify relocations for each symbol""" + data = {} + for name, s in sorted(syms.items()): + b = bites.get(name) + assert b is not None + if s["Demangled Name"].startswith("typeinfo name"): + data[name] = [("byte", int(x)) for x in b] + continue + data[name] = [] + for i in range(0, len(b), ptr_size): + val = int.from_bytes( + b[i * ptr_size : (i + 1) * ptr_size], byteorder="little" + ) + data[name].append(("offset", val)) + start = s["Value"] + finish = start + s["Size"] + # TODO: binary search (bisect) + for rel in rels: + if rel["Type"] in reloc_types and start <= rel["Offset"] < finish: + i = (rel["Offset"] - start) // ptr_size + assert i < len(data[name]) + data[name][i] = "reloc", rel + return data + def generate_vtables(cls_tables, cls_syms, cls_data): - """Generate code for vtables""" - c_types = { - 'reloc' : 'const void *', - 'byte' : 'unsigned char', - 'offset' : 'size_t' - } - - ss = [] - ss.append('''\ + """Generate code for vtables""" + c_types = {"reloc": "const void *", "byte": "unsigned char", "offset": "size_t"} + + ss = [] + ss.append("""\ #ifdef __cplusplus extern "C" { #endif -''') +""") - # Print externs + # Print externs - printed = set() - for name, data in sorted(cls_data.items()): - for typ, val in data: - if typ != 'reloc': - continue - sym_name, addend = val['Symbol\'s Name + Addend'] - sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? - if sym_name not in cls_syms and sym_name not in printed: - ss.append(f'''\ + printed = set() + for name, data in sorted(cls_data.items()): + for typ, val in data: + if typ != "reloc": + continue + sym_name, addend = val["Symbol's Name + Addend"] + sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? + if sym_name not in cls_syms and sym_name not in printed: + ss.append(f"""\ extern const char {sym_name}[]; -''') +""") - # Collect variable infos + # Collect variable infos - code_info = {} + code_info = {} - for name, s in sorted(cls_syms.items()): - data = cls_data[name] - if s['Demangled Name'].startswith('typeinfo name'): - declarator = 'const unsigned char %s[]' - else: - field_types = (f'{c_types[typ]} field_{i};' for i, (typ, _) in enumerate(data)) - declarator = 'const struct { %s } %%s' % ' '.join(field_types) # pylint: disable=C0209 # consider-using-f-string - vals = [] - for typ, val in data: - if typ != 'reloc': - vals.append(str(val) + 'UL') - else: - sym_name, addend = val['Symbol\'s Name + Addend'] - sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? - vals.append(f'(const char *)&{sym_name} + {addend}') - code_info[name] = (declarator, '{ %s }' % ', '.join(vals)) # pylint: disable= C0209 # consider-using-f-string - - # Print declarations - - for name, (decl, _) in sorted(code_info.items()): - type_name = name + '_type' - type_decl = decl % type_name - ss.append(f'''\ + for name, s in sorted(cls_syms.items()): + data = cls_data[name] + if s["Demangled Name"].startswith("typeinfo name"): + declarator = "const unsigned char %s[]" + else: + field_types = ( + f"{c_types[typ]} field_{i};" for i, (typ, _) in enumerate(data) + ) + declarator = "const struct { %s } %%s" % " ".join(field_types) # pylint: disable=C0209 # consider-using-f-string + vals = [] + for typ, val in data: + if typ != "reloc": + vals.append(str(val) + "UL") + else: + sym_name, addend = val["Symbol's Name + Addend"] + sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? + vals.append(f"(const char *)&{sym_name} + {addend}") + code_info[name] = (declarator, "{ %s }" % ", ".join(vals)) # pylint: disable= C0209 # consider-using-f-string + + # Print declarations + + for name, (decl, _) in sorted(code_info.items()): + type_name = name + "_type" + type_decl = decl % type_name + ss.append(f"""\ typedef {type_decl}; extern __attribute__((weak)) {type_name} {name}; -''') +""") - # Print definitions + # Print definitions - for name, (_, init) in sorted(code_info.items()): - type_name = name + '_type' - ss.append(f'''\ + for name, (_, init) in sorted(code_info.items()): + type_name = name + "_type" + ss.append(f"""\ const {type_name} {name} = {init}; -''') +""") - ss.append('''\ + ss.append("""\ #ifdef __cplusplus } // extern "C" #endif -''') +""") + + return "".join(ss) - return ''.join(ss) def read_soname(f): - """Read ELF's SONAME.""" + """Read ELF's SONAME.""" + + out, _ = run(["readelf", "-d", f]) - out, _ = run(['readelf', '-d', f]) + for line in out.splitlines(): + line = line.strip() + if not line: + continue + # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] + soname_match = re.search(r"\(SONAME\).*\[(.+)\]", line) + if soname_match is not None: + return soname_match[1] - for line in out.splitlines(): - line = line.strip() - if not line: - continue - # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] - soname_match = re.search(r'\(SONAME\).*\[(.+)\]', line) - if soname_match is not None: - return soname_match[1] + return None - return None def main(): - """Driver function""" - parser = argparse.ArgumentParser(description="Generate wrappers for shared library functions.", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=f"""\ + """Driver function""" + parser = argparse.ArgumentParser( + description="Generate wrappers for shared library functions.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f"""\ Examples: $ python3 {me} /usr/lib/x86_64-linux-gnu/libaccountsservice.so.0 Generating libaccountsservice.so.0.tramp.S... Generating libaccountsservice.so.0.init.c... -""") - - parser.add_argument('library', - metavar='LIB', - help="Library to be wrapped.") - parser.add_argument('--verbose', '-v', - help="Print diagnostic info", - action='count', - default=0) - parser.add_argument('--dlopen', - help="Emit dlopen call (default)", - dest='dlopen', action='store_true', default=True) - parser.add_argument('--no-dlopen', - help="Do not emit dlopen call (user must load/unload library himself)", - dest='dlopen', action='store_false') - parser.add_argument('--dlopen-callback', - help="Call user-provided custom callback to load library instead of dlopen", - default='') - parser.add_argument('--dlsym-callback', - help="Call user-provided custom callback to resolve a symbol, " - "instead of dlsym", - default='') - parser.add_argument('--library-load-name', - help="Use custom name for dlopened library (default is SONAME)") - parser.add_argument('--lazy-load', - help="Load library on first call to any of it's functions (default)", - dest='lazy_load', action='store_true', default=True) - parser.add_argument('--no-lazy-load', - help="Load library at program start", - dest='lazy_load', action='store_false') - parser.add_argument('--vtables', - help="Intercept virtual tables (EXPERIMENTAL)", - dest='vtables', action='store_true', default=False) - parser.add_argument('--no-vtables', - help="Do not intercept virtual tables (default)", - dest='vtables', action='store_false') - parser.add_argument('--no-weak-symbols', - help="Don't bind weak symbols", dest='no_weak_symbols', - action='store_true', default=False) - parser.add_argument('--target', - help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " - "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " - "mips/mipsel, mips64/mip64el and e2k are supported)", - default=os.uname()[-1]) - parser.add_argument('--symbol-list', - help="Path to file with symbols that should be present in wrapper " - "(all by default)") - parser.add_argument('--symbol-prefix', - metavar='PFX', - help="Prefix wrapper symbols with PFX", - default='') - parser.add_argument('-q', '--quiet', - help="Do not print progress info", - action='store_true') - parser.add_argument('--outdir', '-o', - help="Path to create wrapper at", - default='./') - - args = parser.parse_args() - - input_name = args.library - verbose = args.verbose - dlopen_callback = args.dlopen_callback - dlsym_callback = args.dlsym_callback - dlopen = args.dlopen - lazy_load = args.lazy_load - if args.target.startswith('arm'): - target = 'arm' # Handle armhf-..., armel-... - elif re.match(r'^i[0-9]86', args.target): - target = 'i386' - elif args.target.startswith('mips64'): - target = 'mips64' # Handle mips64-..., mips64el-..., mips64le-... - elif args.target.startswith('mips'): - target = 'mips' # Handle mips-..., mipsel-..., mipsle-... - else: - target = args.target.split('-')[0] - quiet = args.quiet - outdir = args.outdir - - if args.symbol_list is None: - funs = None - else: - with open(args.symbol_list, 'r') as f: - funs = [] - for line in re.split(r'\r?\n', f.read()): - line = re.sub(r'#.*', '', line) - line = line.strip() - if line: - funs.append(line) +""", + ) + + parser.add_argument("library", metavar="LIB", help="Library to be wrapped.") + parser.add_argument( + "--verbose", "-v", help="Print diagnostic info", action="count", default=0 + ) + parser.add_argument( + "--dlopen", + help="Emit dlopen call (default)", + dest="dlopen", + action="store_true", + default=True, + ) + parser.add_argument( + "--no-dlopen", + help="Do not emit dlopen call (user must load/unload library himself)", + dest="dlopen", + action="store_false", + ) + parser.add_argument( + "--dlopen-callback", + help="Call user-provided custom callback to load library instead of dlopen", + default="", + ) + parser.add_argument( + "--dlsym-callback", + help="Call user-provided custom callback to resolve a symbol, instead of dlsym", + default="", + ) + parser.add_argument( + "--library-load-name", + help="Use custom name for dlopened library (default is SONAME)", + ) + parser.add_argument( + "--lazy-load", + help="Load library on first call to any of it's functions (default)", + dest="lazy_load", + action="store_true", + default=True, + ) + parser.add_argument( + "--no-lazy-load", + help="Load library at program start", + dest="lazy_load", + action="store_false", + ) + parser.add_argument( + "--vtables", + help="Intercept virtual tables (EXPERIMENTAL)", + dest="vtables", + action="store_true", + default=False, + ) + parser.add_argument( + "--no-vtables", + help="Do not intercept virtual tables (default)", + dest="vtables", + action="store_false", + ) + parser.add_argument( + "--no-weak-symbols", + help="Don't bind weak symbols", + dest="no_weak_symbols", + action="store_true", + default=False, + ) + parser.add_argument( + "--target", + help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " + "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " + "mips/mipsel, mips64/mip64el and e2k are supported)", + default=os.uname()[-1], + ) + parser.add_argument( + "--symbol-list", + help="Path to file with symbols that should be present in wrapper " + "(all by default)", + ) + parser.add_argument( + "--symbol-prefix", + metavar="PFX", + help="Prefix wrapper symbols with PFX", + default="", + ) + parser.add_argument( + "-q", "--quiet", help="Do not print progress info", action="store_true" + ) + parser.add_argument( + "--outdir", "-o", help="Path to create wrapper at", default="./" + ) + + args = parser.parse_args() + + input_name = args.library + verbose = args.verbose + dlopen_callback = args.dlopen_callback + dlsym_callback = args.dlsym_callback + dlopen = args.dlopen + lazy_load = args.lazy_load + if args.target.startswith("arm"): + target = "arm" # Handle armhf-..., armel-... + elif re.match(r"^i[0-9]86", args.target): + target = "i386" + elif args.target.startswith("mips64"): + target = "mips64" # Handle mips64-..., mips64el-..., mips64le-... + elif args.target.startswith("mips"): + target = "mips" # Handle mips-..., mipsel-..., mipsle-... + else: + target = args.target.split("-")[0] + quiet = args.quiet + outdir = args.outdir - if args.library_load_name is not None: - load_name = args.library_load_name - else: - load_name = read_soname(input_name) - if load_name is None: - load_name = os.path.basename(input_name) + if args.symbol_list is None: + funs = None + else: + with open(args.symbol_list, "r") as f: + funs = [] + for line in re.split(r"\r?\n", f.read()): + line = re.sub(r"#.*", "", line) + line = line.strip() + if line: + funs.append(line) + + if args.library_load_name is not None: + load_name = args.library_load_name + else: + load_name = read_soname(input_name) + if load_name is None: + load_name = os.path.basename(input_name) - # Collect target info + # Collect target info - target_dir = os.path.join(root, 'arch', target) + target_dir = os.path.join(root, "arch", target) - if not os.path.exists(target_dir): - error(f"unknown architecture '{target}'") + if not os.path.exists(target_dir): + error(f"unknown architecture '{target}'") - cfg = configparser.ConfigParser(inline_comment_prefixes=';') - cfg.read(target_dir + '/config.ini') + cfg = configparser.ConfigParser(inline_comment_prefixes=";") + cfg.read(target_dir + "/config.ini") - ptr_size = int(cfg['Arch']['PointerSize']) - symbol_reloc_types = set(re.split(r'\s*,\s*', cfg['Arch']['SymbolReloc'])) + ptr_size = int(cfg["Arch"]["PointerSize"]) + symbol_reloc_types = set(re.split(r"\s*,\s*", cfg["Arch"]["SymbolReloc"])) - def is_exported(s): - conditions = [ - s['Bind'] != 'LOCAL', - s['Type'] != 'NOTYPE', - s['Ndx'] != 'UND', - s['Name'] not in ['', '_init', '_fini']] - if args.no_weak_symbols: - conditions.append(s['Bind'] != 'WEAK') - return all(conditions) + def is_exported(s): + conditions = [ + s["Bind"] != "LOCAL", + s["Type"] != "NOTYPE", + s["Ndx"] != "UND", + s["Name"] not in ["", "_init", "_fini"], + ] + if args.no_weak_symbols: + conditions.append(s["Bind"] != "WEAK") + return all(conditions) - syms = list(filter(is_exported, collect_syms(input_name))) + syms = list(filter(is_exported, collect_syms(input_name))) - def is_data_symbol(s): - return (s['Type'] == 'OBJECT' + def is_data_symbol(s): + return ( + s["Type"] == "OBJECT" # Allow vtables if --vtables is on - and not (' for ' in s['Demangled Name'] and args.vtables)) - - exported_data = [s['Name'] for s in syms if is_data_symbol(s)] - if exported_data: - # TODO: we can generate wrappers for const data without relocations (or only code relocations) - warn(f"library '{input_name}' contains data symbols which won't be intercepted: " - + ', '.join(exported_data)) - - # Collect functions - # TODO: warn if user-specified functions are missing - - orig_funs = filter(lambda s: s['Type'] == 'FUNC', syms) - - all_funs = set() - warn_versioned = False - for s in orig_funs: - if not s['Default']: - # TODO: support versions - if not warn_versioned: - warn(f"library {input_name} contains versioned symbols which are NYI") - warn_versioned = True - if verbose: - print(f"Skipping versioned symbol {s['Name']}") - continue - all_funs.add(s['Name']) - - if funs is None: - funs = sorted(list(all_funs)) - if not funs and not quiet: - warn(f"no public functions were found in {input_name}") - else: - missing_funs = [name for name in funs if name not in all_funs] - if missing_funs: - warn("some user-specified functions are not present in library: " + ', '.join(missing_funs)) - funs = [name for name in funs if name in all_funs] - - if verbose: - print("Exported functions:") - for i, fun in enumerate(funs): - print(f" {i}: {fun}") - - # Collect vtables - - if args.vtables: - cls_tables = {} - cls_syms = {} - - for s in syms: - m = re.match(r'^(vtable|typeinfo|typeinfo name) for (.*)', s['Demangled Name']) - if m is not None and is_exported(s): - typ, cls = m.groups() - name = s['Name'] - cls_tables.setdefault(cls, {})[typ] = name - cls_syms[name] = s - - if verbose: - print("Exported classes:") - for cls, _ in sorted(cls_tables.items()): - print(f" {cls}") + and not (" for " in s["Demangled Name"] and args.vtables) + ) + + exported_data = [s["Name"] for s in syms if is_data_symbol(s)] + if exported_data: + # TODO: we can generate wrappers for const data without relocations (or only code relocations) + warn( + f"library '{input_name}' contains data symbols which won't be intercepted: " + + ", ".join(exported_data) + ) + + # Collect functions + # TODO: warn if user-specified functions are missing + + orig_funs = filter(lambda s: s["Type"] == "FUNC", syms) + + all_funs = set() + warn_versioned = False + for s in orig_funs: + if not s["Default"]: + # TODO: support versions + if not warn_versioned: + warn(f"library {input_name} contains versioned symbols which are NYI") + warn_versioned = True + if verbose: + print(f"Skipping versioned symbol {s['Name']}") + continue + all_funs.add(s["Name"]) + + if funs is None: + funs = sorted(list(all_funs)) + if not funs and not quiet: + warn(f"no public functions were found in {input_name}") + else: + missing_funs = [name for name in funs if name not in all_funs] + if missing_funs: + warn( + "some user-specified functions are not present in library: " + + ", ".join(missing_funs) + ) + funs = [name for name in funs if name in all_funs] - secs = collect_sections(input_name) if verbose: - print("Sections:") - for sec in secs: - print(f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " - f"at {sec['Off']:x}") + print("Exported functions:") + for i, fun in enumerate(funs): + print(f" {i}: {fun}") - bites = read_unrelocated_data(input_name, cls_syms, secs) + # Collect vtables - rels = collect_relocs(input_name) - if verbose: - print("Relocs:") - for rel in rels: - sym_add = rel['Symbol\'s Name + Addend'] - print(f" {rel['Offset']}: {sym_add}") - - cls_data = collect_relocated_data(cls_syms, bites, rels, ptr_size, symbol_reloc_types) - if verbose: - print("Class data:") - for name, data in sorted(cls_data.items()): - demangled_name = cls_syms[name]['Demangled Name'] - print(f" {name} ({demangled_name}):") - for typ, val in data: - print(" " + str(val if typ != 'reloc' else val['Symbol\'s Name + Addend'])) - - # Generate assembly code - - suffix = os.path.basename(input_name) - lib_suffix = re.sub(r'[^a-zA-Z_0-9]+', '_', suffix) - - tramp_file = f'{suffix}.tramp.S' - with open(os.path.join(outdir, tramp_file), 'w') as f: - if not quiet: - print(f"Generating {tramp_file}...") - with open(target_dir + '/table.S.tpl', 'r') as t: - table_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, - table_size=ptr_size*(len(funs) + 1)) - f.write(table_text) - - with open(target_dir + '/trampoline.S.tpl', 'r') as t: - tramp_tpl = string.Template(t.read()) - - for i, name in enumerate(funs): - tramp_text = tramp_tpl.substitute( - lib_suffix=lib_suffix, - sym=args.symbol_prefix + name, - offset=i*ptr_size, - number=i) - f.write(tramp_text) - - # Generate C code - - init_file = f'{suffix}.init.c' - with open(os.path.join(outdir, init_file), 'w') as f: - if not quiet: - print(f"Generating {init_file}...") - with open(os.path.join(root, 'arch/common/init.c.tpl'), 'r') as t: - if funs: - sym_names = ',\n '.join(f'"{name}"' for name in funs) + ',' - else: - sym_names = '' - init_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, - load_name=load_name, - dlopen_callback=dlopen_callback, - dlsym_callback=dlsym_callback, - has_dlopen_callback=int(bool(dlopen_callback)), - has_dlsym_callback=int(bool(dlsym_callback)), - no_dlopen=int(not dlopen), - lazy_load=int(lazy_load), - sym_names=sym_names) - f.write(init_text) if args.vtables: - vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) - f.write(vtable_text) - -if __name__ == '__main__': - main() + cls_tables = {} + cls_syms = {} + + for s in syms: + m = re.match( + r"^(vtable|typeinfo|typeinfo name) for (.*)", s["Demangled Name"] + ) + if m is not None and is_exported(s): + typ, cls = m.groups() + name = s["Name"] + cls_tables.setdefault(cls, {})[typ] = name + cls_syms[name] = s + + if verbose: + print("Exported classes:") + for cls, _ in sorted(cls_tables.items()): + print(f" {cls}") + + secs = collect_sections(input_name) + if verbose: + print("Sections:") + for sec in secs: + print( + f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " + f"at {sec['Off']:x}" + ) + + bites = read_unrelocated_data(input_name, cls_syms, secs) + + rels = collect_relocs(input_name) + if verbose: + print("Relocs:") + for rel in rels: + sym_add = rel["Symbol's Name + Addend"] + print(f" {rel['Offset']}: {sym_add}") + + cls_data = collect_relocated_data( + cls_syms, bites, rels, ptr_size, symbol_reloc_types + ) + if verbose: + print("Class data:") + for name, data in sorted(cls_data.items()): + demangled_name = cls_syms[name]["Demangled Name"] + print(f" {name} ({demangled_name}):") + for typ, val in data: + print( + " " + + str(val if typ != "reloc" else val["Symbol's Name + Addend"]) + ) + + # Generate assembly code + + suffix = os.path.basename(input_name) + lib_suffix = re.sub(r"[^a-zA-Z_0-9]+", "_", suffix) + + tramp_file = f"{suffix}.tramp.S" + with open(os.path.join(outdir, tramp_file), "w") as f: + if not quiet: + print(f"Generating {tramp_file}...") + with open(target_dir + "/table.S.tpl", "r") as t: + table_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, table_size=ptr_size * (len(funs) + 1) + ) + f.write(table_text) + + with open(target_dir + "/trampoline.S.tpl", "r") as t: + tramp_tpl = string.Template(t.read()) + + for i, name in enumerate(funs): + tramp_text = tramp_tpl.substitute( + lib_suffix=lib_suffix, + sym=args.symbol_prefix + name, + offset=i * ptr_size, + number=i, + ) + f.write(tramp_text) + + # Generate C code + + init_file = f"{suffix}.init.c" + with open(os.path.join(outdir, init_file), "w") as f: + if not quiet: + print(f"Generating {init_file}...") + with open(os.path.join(root, "arch/common/init.c.tpl"), "r") as t: + if funs: + sym_names = ",\n ".join(f'"{name}"' for name in funs) + "," + else: + sym_names = "" + init_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, + load_name=load_name, + dlopen_callback=dlopen_callback, + dlsym_callback=dlsym_callback, + has_dlopen_callback=int(bool(dlopen_callback)), + has_dlsym_callback=int(bool(dlsym_callback)), + no_dlopen=int(not dlopen), + lazy_load=int(lazy_load), + sym_names=sym_names, + ) + f.write(init_text) + if args.vtables: + vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) + f.write(vtable_text) + + +if __name__ == "__main__": + main() From 7a2b41edd3c21b71e4ee649ec2b7f37afcf5c8cf Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 20 Sep 2025 18:07:16 +0000 Subject: [PATCH 5/7] fix(training): use 'rmse' key for total loss instead of 'rmse_e' for energy loss Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- deepmd/pd/train/training.py | 8 ++++---- deepmd/pt/train/training.py | 8 ++++---- deepmd/tf/train/trainer.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 7720671df8..348f231575 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -863,8 +863,8 @@ def log_loss_valid(_task_key="Default"): if not self.multi_task: train_results = log_loss_train(loss, more_loss) # Check for NaN in total loss using CPU values from lcurve computation - if self.rank == 0 and "rmse_e" in train_results: - check_total_loss_nan(display_step_id, train_results["rmse_e"]) + if self.rank == 0 and "rmse" in train_results: + check_total_loss_nan(display_step_id, train_results["rmse"]) valid_results = log_loss_valid() if self.rank == 0: log.info( @@ -907,9 +907,9 @@ def log_loss_valid(_task_key="Default"): ) valid_results[_key] = log_loss_valid(_task_key=_key) # Check for NaN in total loss using CPU values from lcurve computation - if self.rank == 0 and "rmse_e" in train_results[_key]: + if self.rank == 0 and "rmse" in train_results[_key]: check_total_loss_nan( - display_step_id, train_results[_key]["rmse_e"] + display_step_id, train_results[_key]["rmse"] ) if self.rank == 0: log.info( diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 5713582a94..a63be6c57d 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -953,8 +953,8 @@ def log_loss_valid(_task_key: str = "Default") -> dict: if not self.multi_task: train_results = log_loss_train(loss, more_loss) # Check for NaN in total loss using CPU values from lcurve computation - if self.rank == 0 and "rmse_e" in train_results: - check_total_loss_nan(display_step_id, train_results["rmse_e"]) + if self.rank == 0 and "rmse" in train_results: + check_total_loss_nan(display_step_id, train_results["rmse"]) valid_results = log_loss_valid() if self.rank == 0: log.info( @@ -1004,9 +1004,9 @@ def log_loss_valid(_task_key: str = "Default") -> dict: ) valid_results[_key] = log_loss_valid(_task_key=_key) # Check for NaN in total loss using CPU values from lcurve computation - if self.rank == 0 and "rmse_e" in train_results[_key]: + if self.rank == 0 and "rmse" in train_results[_key]: check_total_loss_nan( - display_step_id, train_results[_key]["rmse_e"] + display_step_id, train_results[_key]["rmse"] ) if self.rank == 0: log.info( diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index 9fa3b4e323..0f26c00171 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -689,8 +689,8 @@ def valid_on_the_fly( current_lr = run_sess(self.sess, self.learning_rate) # Check for NaN in total loss before writing to file and saving checkpoint - # We check the main energy loss component that represents total training loss - check_total_loss_nan(cur_batch, train_results["rmse_e"]) + # We check the main total loss component that represents training loss + check_total_loss_nan(cur_batch, train_results["rmse"]) if print_header: self.print_header(fp, train_results, valid_results) From 22cb9ef4f4de5337bf0986298c63b4daa613c5d6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 21 Sep 2025 07:51:44 +0000 Subject: [PATCH 6/7] fix: revert implib file and clean up redundant test code Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- source/tests/common/test_nan_detector.py | 5 +---- source/tests/common/test_nan_integration.py | 19 +++++-------------- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/source/tests/common/test_nan_detector.py b/source/tests/common/test_nan_detector.py index 9d9f0bf2f3..250f9205fb 100644 --- a/source/tests/common/test_nan_detector.py +++ b/source/tests/common/test_nan_detector.py @@ -22,10 +22,7 @@ def test_normal_values_pass(self): # Should not raise any exception for i, loss_val in enumerate(normal_losses): - try: - check_total_loss_nan(100 + i, loss_val) - except Exception as e: - self.fail(f"Normal values should not raise exception: {e}") + check_total_loss_nan(100 + i, loss_val) def test_nan_detection_raises_exception(self): """Test that NaN values trigger the proper exception.""" diff --git a/source/tests/common/test_nan_integration.py b/source/tests/common/test_nan_integration.py index 1b3ae1ff61..6a754d93f4 100644 --- a/source/tests/common/test_nan_integration.py +++ b/source/tests/common/test_nan_integration.py @@ -48,27 +48,18 @@ def test_logging_on_nan_detection(self, mock_log): def test_training_simulation_with_checkpoint_prevention(self): """Simulate the training checkpoint scenario to ensure NaN prevents saving.""" - - def mock_save_checkpoint(): - """Mock function that should not be called when NaN is detected.""" - raise AssertionError("Checkpoint should not be saved when NaN is detected!") - # Simulate the training flow: check total loss, then save checkpoint step_id = 1000 total_loss = float("nan") - # This should raise LossNaNError before checkpoint saving - with self.assertRaises(LossNaNError): + # This should raise LossNaNError, preventing any subsequent checkpoint saving + with self.assertRaises(LossNaNError) as context: check_total_loss_nan(step_id, total_loss) - # This line should never be reached - mock_save_checkpoint() # Verify the error contains expected information - try: - check_total_loss_nan(step_id, total_loss) - except LossNaNError as e: - self.assertIn("Training stopped to prevent wasting time", str(e)) - self.assertIn("corrupted parameters", str(e)) + exception = context.exception + self.assertIn("Training stopped to prevent wasting time", str(exception)) + self.assertIn("corrupted parameters", str(exception)) def test_realistic_training_scenario(self): """Test a more realistic training scenario with decreasing then NaN loss.""" From 0bebb06d374a378906be6df78277a96776f4d7ff Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sun, 21 Sep 2025 08:13:38 +0000 Subject: [PATCH 7/7] fix: properly revert implib file to exact original state Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> --- source/3rdparty/implib/implib-gen.py | 1093 ++++++++++++-------------- 1 file changed, 508 insertions(+), 585 deletions(-) diff --git a/source/3rdparty/implib/implib-gen.py b/source/3rdparty/implib/implib-gen.py index 3a51be271d..86cfa77378 100755 --- a/source/3rdparty/implib/implib-gen.py +++ b/source/3rdparty/implib/implib-gen.py @@ -22,654 +22,577 @@ me = os.path.basename(__file__) root = os.path.dirname(__file__) - def warn(msg): - """Emits a nicely-decorated warning.""" - sys.stderr.write(f"{me}: warning: {msg}\n") - + """Emits a nicely-decorated warning.""" + sys.stderr.write(f'{me}: warning: {msg}\n') def error(msg): - """Emits a nicely-decorated error and exits.""" - sys.stderr.write(f"{me}: error: {msg}\n") - sys.exit(1) - - -def run(args, stdin=""): - """Runs external program and aborts on error.""" - env = os.environ.copy() - # Force English language - env["LC_ALL"] = "c" - try: - del env["LANG"] - except KeyError: - pass - with subprocess.Popen( - args, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - env=env, - ) as p: - out, err = p.communicate(input=stdin.encode("utf-8")) - out = out.decode("utf-8") - err = err.decode("utf-8") - if p.returncode != 0 or err: - error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") - return out, err - + """Emits a nicely-decorated error and exits.""" + sys.stderr.write(f'{me}: error: {msg}\n') + sys.exit(1) + +def run(args, stdin=''): + """Runs external program and aborts on error.""" + env = os.environ.copy() + # Force English language + env['LC_ALL'] = 'c' + try: + del env["LANG"] + except KeyError: + pass + with subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, env=env) as p: + out, err = p.communicate(input=stdin.encode('utf-8')) + out = out.decode('utf-8') + err = err.decode('utf-8') + if p.returncode != 0 or err: + error(f"{args[0]} failed with retcode {p.returncode}:\n{err}") + return out, err def make_toc(words, renames=None): - "Make an mapping of words to their indices in list" - renames = renames or {} - toc = {} - for i, n in enumerate(words): - name = renames.get(n, n) - toc[i] = name - return toc - + "Make an mapping of words to their indices in list" + renames = renames or {} + toc = {} + for i, n in enumerate(words): + name = renames.get(n, n) + toc[i] = name + return toc def parse_row(words, toc, hex_keys): - "Make a mapping from column names to values" - vals = {k: (words[i] if i < len(words) else "") for i, k in toc.items()} - for k in hex_keys: - if vals[k]: - vals[k] = int(vals[k], 16) - return vals - + "Make a mapping from column names to values" + vals = {k: (words[i] if i < len(words) else '') for i, k in toc.items()} + for k in hex_keys: + if vals[k]: + vals[k] = int(vals[k], 16) + return vals def collect_syms(f): - """Collect ELF dynamic symtab.""" - - # --dyn-syms does not always work for some reason so dump all symtabs - out, _ = run(["readelf", "-sW", f]) - - toc = None - syms = [] - syms_set = set() - for line in out.splitlines(): - line = line.strip() - if not line: - # Next symtab - toc = None - continue - words = re.split(r" +", line) - if line.startswith("Num"): # Header? - if toc is not None: - error("multiple headers in output of readelf") - # Colons are different across readelf versions so get rid of them. - toc = make_toc(map(lambda n: n.replace(":", ""), words)) - elif toc is not None: - sym = parse_row(words, toc, ["Value"]) - name = sym["Name"] - if not name: - continue - if name in syms_set: - continue - syms_set.add(name) - sym["Size"] = int(sym["Size"], 0) # Readelf is inconistent on Size format - if "@" in name: - sym["Default"] = "@@" in name - name, ver = re.split(r"@+", name) - sym["Name"] = name - sym["Version"] = ver - else: - sym["Default"] = True - sym["Version"] = None - syms.append(sym) - - if toc is None: - error(f"failed to analyze symbols in {f}") - - # Also collected demangled names - if syms: - out, _ = run(["c++filt"], "\n".join((sym["Name"] for sym in syms))) - out = out.rstrip("\n") # Some c++filts append newlines at the end - for i, name in enumerate(out.split("\n")): - syms[i]["Demangled Name"] = name - - return syms - + """Collect ELF dynamic symtab.""" + + # --dyn-syms does not always work for some reason so dump all symtabs + out, _ = run(['readelf', '-sW', f]) + + toc = None + syms = [] + syms_set = set() + for line in out.splitlines(): + line = line.strip() + if not line: + # Next symtab + toc = None + continue + words = re.split(r' +', line) + if line.startswith('Num'): # Header? + if toc is not None: + error("multiple headers in output of readelf") + # Colons are different across readelf versions so get rid of them. + toc = make_toc(map(lambda n: n.replace(':', ''), words)) + elif toc is not None: + sym = parse_row(words, toc, ['Value']) + name = sym['Name'] + if not name: + continue + if name in syms_set: + continue + syms_set.add(name) + sym['Size'] = int(sym['Size'], 0) # Readelf is inconistent on Size format + if '@' in name: + sym['Default'] = '@@' in name + name, ver = re.split(r'@+', name) + sym['Name'] = name + sym['Version'] = ver + else: + sym['Default'] = True + sym['Version'] = None + syms.append(sym) + + if toc is None: + error(f"failed to analyze symbols in {f}") + + # Also collected demangled names + if syms: + out, _ = run(['c++filt'], '\n'.join((sym['Name'] for sym in syms))) + out = out.rstrip("\n") # Some c++filts append newlines at the end + for i, name in enumerate(out.split("\n")): + syms[i]['Demangled Name'] = name + + return syms def collect_relocs(f): - """Collect ELF dynamic relocs.""" - - out, _ = run(["readelf", "-rW", f]) - - toc = None - rels = [] - for line in out.splitlines(): - line = line.strip() - if not line: - toc = None - continue - if line == "There are no relocations in this file.": - return [] - if re.match(r"^\s*Type[0-9]:", line): # Spurious lines for MIPS - continue - if re.match(r"^\s*Offset", line): # Header? - if toc is not None: - error("multiple headers in output of readelf") - words = re.split(r"\s\s+", line) # "Symbol's Name + Addend" - toc = make_toc(words) - elif toc is not None: - line = re.sub(r" \+ ", "+", line) - words = re.split(r"\s+", line) - rel = parse_row(words, toc, ["Offset", "Info"]) - rels.append(rel) - # Split symbolic representation - sym_name = "Symbol's Name + Addend" - if sym_name not in rel and "Symbol's Name" in rel: - # Adapt to different versions of readelf - rel[sym_name] = rel["Symbol's Name"] + "+0" - if rel[sym_name]: - p = rel[sym_name].split("+") - if len(p) == 1: - p = ["", p[0]] - rel[sym_name] = (p[0], int(p[1], 16)) - - if toc is None: - error(f"failed to analyze relocations in {f}") - - return rels - + """Collect ELF dynamic relocs.""" + + out, _ = run(['readelf', '-rW', f]) + + toc = None + rels = [] + for line in out.splitlines(): + line = line.strip() + if not line: + toc = None + continue + if line == 'There are no relocations in this file.': + return [] + if re.match(r'^\s*Type[0-9]:', line): # Spurious lines for MIPS + continue + if re.match(r'^\s*Offset', line): # Header? + if toc is not None: + error("multiple headers in output of readelf") + words = re.split(r'\s\s+', line) # "Symbol's Name + Addend" + toc = make_toc(words) + elif toc is not None: + line = re.sub(r' \+ ', '+', line) + words = re.split(r'\s+', line) + rel = parse_row(words, toc, ['Offset', 'Info']) + rels.append(rel) + # Split symbolic representation + sym_name = 'Symbol\'s Name + Addend' + if sym_name not in rel and 'Symbol\'s Name' in rel: + # Adapt to different versions of readelf + rel[sym_name] = rel['Symbol\'s Name'] + '+0' + if rel[sym_name]: + p = rel[sym_name].split('+') + if len(p) == 1: + p = ['', p[0]] + rel[sym_name] = (p[0], int(p[1], 16)) + + if toc is None: + error(f"failed to analyze relocations in {f}") + + return rels def collect_sections(f): - """Collect section info from ELF.""" - - out, _ = run(["readelf", "-SW", f]) - - toc = None - sections = [] - for line in out.splitlines(): - line = line.strip() - if not line: - continue - line = re.sub(r"\[\s+", "[", line) - words = re.split(r" +", line) - if line.startswith("[Nr]"): # Header? - if toc is not None: - error("multiple headers in output of readelf") - toc = make_toc(words, {"Addr": "Address"}) - elif line.startswith("[") and toc is not None: - sec = parse_row(words, toc, ["Address", "Off", "Size"]) - if "A" in sec["Flg"]: # Allocatable section? - sections.append(sec) - - if toc is None: - error(f"failed to analyze sections in {f}") - - return sections - + """Collect section info from ELF.""" + + out, _ = run(['readelf', '-SW', f]) + + toc = None + sections = [] + for line in out.splitlines(): + line = line.strip() + if not line: + continue + line = re.sub(r'\[\s+', '[', line) + words = re.split(r' +', line) + if line.startswith('[Nr]'): # Header? + if toc is not None: + error("multiple headers in output of readelf") + toc = make_toc(words, {'Addr' : 'Address'}) + elif line.startswith('[') and toc is not None: + sec = parse_row(words, toc, ['Address', 'Off', 'Size']) + if 'A' in sec['Flg']: # Allocatable section? + sections.append(sec) + + if toc is None: + error(f"failed to analyze sections in {f}") + + return sections def read_unrelocated_data(input_name, syms, secs): - """Collect unrelocated data from ELF.""" - data = {} - with open(input_name, "rb") as f: - - def is_symbol_in_section(sym, sec): - sec_end = sec["Address"] + sec["Size"] - is_start_in_section = sec["Address"] <= sym["Value"] < sec_end - is_end_in_section = sym["Value"] + sym["Size"] <= sec_end - return is_start_in_section and is_end_in_section - - for name, s in sorted(syms.items(), key=lambda s: s[1]["Value"]): - # TODO: binary search (bisect) - sec = [sec for sec in secs if is_symbol_in_section(s, sec)] - if len(sec) != 1: - error( - f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})" - ) - sec = sec[0] - f.seek(sec["Off"]) - data[name] = f.read(s["Size"]) - return data - + """Collect unrelocated data from ELF.""" + data = {} + with open(input_name, 'rb') as f: + def is_symbol_in_section(sym, sec): + sec_end = sec['Address'] + sec['Size'] + is_start_in_section = sec['Address'] <= sym['Value'] < sec_end + is_end_in_section = sym['Value'] + sym['Size'] <= sec_end + return is_start_in_section and is_end_in_section + for name, s in sorted(syms.items(), key=lambda s: s[1]['Value']): + # TODO: binary search (bisect) + sec = [sec for sec in secs if is_symbol_in_section(s, sec)] + if len(sec) != 1: + error(f"failed to locate section for interval [{s['Value']:x}, {s['Value'] + s['Size']:x})") + sec = sec[0] + f.seek(sec['Off']) + data[name] = f.read(s['Size']) + return data def collect_relocated_data(syms, bites, rels, ptr_size, reloc_types): - """Identify relocations for each symbol""" - data = {} - for name, s in sorted(syms.items()): - b = bites.get(name) - assert b is not None - if s["Demangled Name"].startswith("typeinfo name"): - data[name] = [("byte", int(x)) for x in b] - continue - data[name] = [] - for i in range(0, len(b), ptr_size): - val = int.from_bytes( - b[i * ptr_size : (i + 1) * ptr_size], byteorder="little" - ) - data[name].append(("offset", val)) - start = s["Value"] - finish = start + s["Size"] - # TODO: binary search (bisect) - for rel in rels: - if rel["Type"] in reloc_types and start <= rel["Offset"] < finish: - i = (rel["Offset"] - start) // ptr_size - assert i < len(data[name]) - data[name][i] = "reloc", rel - return data - + """Identify relocations for each symbol""" + data = {} + for name, s in sorted(syms.items()): + b = bites.get(name) + assert b is not None + if s['Demangled Name'].startswith('typeinfo name'): + data[name] = [('byte', int(x)) for x in b] + continue + data[name] = [] + for i in range(0, len(b), ptr_size): + val = int.from_bytes(b[i*ptr_size:(i + 1)*ptr_size], byteorder='little') + data[name].append(('offset', val)) + start = s['Value'] + finish = start + s['Size'] + # TODO: binary search (bisect) + for rel in rels: + if rel['Type'] in reloc_types and start <= rel['Offset'] < finish: + i = (rel['Offset'] - start) // ptr_size + assert i < len(data[name]) + data[name][i] = 'reloc', rel + return data def generate_vtables(cls_tables, cls_syms, cls_data): - """Generate code for vtables""" - c_types = {"reloc": "const void *", "byte": "unsigned char", "offset": "size_t"} - - ss = [] - ss.append("""\ + """Generate code for vtables""" + c_types = { + 'reloc' : 'const void *', + 'byte' : 'unsigned char', + 'offset' : 'size_t' + } + + ss = [] + ss.append('''\ #ifdef __cplusplus extern "C" { #endif -""") +''') - # Print externs + # Print externs - printed = set() - for name, data in sorted(cls_data.items()): - for typ, val in data: - if typ != "reloc": - continue - sym_name, addend = val["Symbol's Name + Addend"] - sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? - if sym_name not in cls_syms and sym_name not in printed: - ss.append(f"""\ + printed = set() + for name, data in sorted(cls_data.items()): + for typ, val in data: + if typ != 'reloc': + continue + sym_name, addend = val['Symbol\'s Name + Addend'] + sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? + if sym_name not in cls_syms and sym_name not in printed: + ss.append(f'''\ extern const char {sym_name}[]; -""") +''') - # Collect variable infos + # Collect variable infos - code_info = {} + code_info = {} - for name, s in sorted(cls_syms.items()): - data = cls_data[name] - if s["Demangled Name"].startswith("typeinfo name"): - declarator = "const unsigned char %s[]" - else: - field_types = ( - f"{c_types[typ]} field_{i};" for i, (typ, _) in enumerate(data) - ) - declarator = "const struct { %s } %%s" % " ".join(field_types) # pylint: disable=C0209 # consider-using-f-string - vals = [] - for typ, val in data: - if typ != "reloc": - vals.append(str(val) + "UL") - else: - sym_name, addend = val["Symbol's Name + Addend"] - sym_name = re.sub(r"@.*", "", sym_name) # Can we pin version in C? - vals.append(f"(const char *)&{sym_name} + {addend}") - code_info[name] = (declarator, "{ %s }" % ", ".join(vals)) # pylint: disable= C0209 # consider-using-f-string - - # Print declarations - - for name, (decl, _) in sorted(code_info.items()): - type_name = name + "_type" - type_decl = decl % type_name - ss.append(f"""\ + for name, s in sorted(cls_syms.items()): + data = cls_data[name] + if s['Demangled Name'].startswith('typeinfo name'): + declarator = 'const unsigned char %s[]' + else: + field_types = (f'{c_types[typ]} field_{i};' for i, (typ, _) in enumerate(data)) + declarator = 'const struct { %s } %%s' % ' '.join(field_types) # pylint: disable=C0209 # consider-using-f-string + vals = [] + for typ, val in data: + if typ != 'reloc': + vals.append(str(val) + 'UL') + else: + sym_name, addend = val['Symbol\'s Name + Addend'] + sym_name = re.sub(r'@.*', '', sym_name) # Can we pin version in C? + vals.append(f'(const char *)&{sym_name} + {addend}') + code_info[name] = (declarator, '{ %s }' % ', '.join(vals)) # pylint: disable= C0209 # consider-using-f-string + + # Print declarations + + for name, (decl, _) in sorted(code_info.items()): + type_name = name + '_type' + type_decl = decl % type_name + ss.append(f'''\ typedef {type_decl}; extern __attribute__((weak)) {type_name} {name}; -""") +''') - # Print definitions + # Print definitions - for name, (_, init) in sorted(code_info.items()): - type_name = name + "_type" - ss.append(f"""\ + for name, (_, init) in sorted(code_info.items()): + type_name = name + '_type' + ss.append(f'''\ const {type_name} {name} = {init}; -""") +''') - ss.append("""\ + ss.append('''\ #ifdef __cplusplus } // extern "C" #endif -""") - - return "".join(ss) +''') + return ''.join(ss) def read_soname(f): - """Read ELF's SONAME.""" - - out, _ = run(["readelf", "-d", f]) + """Read ELF's SONAME.""" - for line in out.splitlines(): - line = line.strip() - if not line: - continue - # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] - soname_match = re.search(r"\(SONAME\).*\[(.+)\]", line) - if soname_match is not None: - return soname_match[1] + out, _ = run(['readelf', '-d', f]) - return None + for line in out.splitlines(): + line = line.strip() + if not line: + continue + # 0x000000000000000e (SONAME) Library soname: [libndp.so.0] + soname_match = re.search(r'\(SONAME\).*\[(.+)\]', line) + if soname_match is not None: + return soname_match[1] + return None def main(): - """Driver function""" - parser = argparse.ArgumentParser( - description="Generate wrappers for shared library functions.", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=f"""\ + """Driver function""" + parser = argparse.ArgumentParser(description="Generate wrappers for shared library functions.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f"""\ Examples: $ python3 {me} /usr/lib/x86_64-linux-gnu/libaccountsservice.so.0 Generating libaccountsservice.so.0.tramp.S... Generating libaccountsservice.so.0.init.c... -""", - ) - - parser.add_argument("library", metavar="LIB", help="Library to be wrapped.") - parser.add_argument( - "--verbose", "-v", help="Print diagnostic info", action="count", default=0 - ) - parser.add_argument( - "--dlopen", - help="Emit dlopen call (default)", - dest="dlopen", - action="store_true", - default=True, - ) - parser.add_argument( - "--no-dlopen", - help="Do not emit dlopen call (user must load/unload library himself)", - dest="dlopen", - action="store_false", - ) - parser.add_argument( - "--dlopen-callback", - help="Call user-provided custom callback to load library instead of dlopen", - default="", - ) - parser.add_argument( - "--dlsym-callback", - help="Call user-provided custom callback to resolve a symbol, instead of dlsym", - default="", - ) - parser.add_argument( - "--library-load-name", - help="Use custom name for dlopened library (default is SONAME)", - ) - parser.add_argument( - "--lazy-load", - help="Load library on first call to any of it's functions (default)", - dest="lazy_load", - action="store_true", - default=True, - ) - parser.add_argument( - "--no-lazy-load", - help="Load library at program start", - dest="lazy_load", - action="store_false", - ) - parser.add_argument( - "--vtables", - help="Intercept virtual tables (EXPERIMENTAL)", - dest="vtables", - action="store_true", - default=False, - ) - parser.add_argument( - "--no-vtables", - help="Do not intercept virtual tables (default)", - dest="vtables", - action="store_false", - ) - parser.add_argument( - "--no-weak-symbols", - help="Don't bind weak symbols", - dest="no_weak_symbols", - action="store_true", - default=False, - ) - parser.add_argument( - "--target", - help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " - "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " - "mips/mipsel, mips64/mip64el and e2k are supported)", - default=os.uname()[-1], - ) - parser.add_argument( - "--symbol-list", - help="Path to file with symbols that should be present in wrapper " - "(all by default)", - ) - parser.add_argument( - "--symbol-prefix", - metavar="PFX", - help="Prefix wrapper symbols with PFX", - default="", - ) - parser.add_argument( - "-q", "--quiet", help="Do not print progress info", action="store_true" - ) - parser.add_argument( - "--outdir", "-o", help="Path to create wrapper at", default="./" - ) - - args = parser.parse_args() - - input_name = args.library - verbose = args.verbose - dlopen_callback = args.dlopen_callback - dlsym_callback = args.dlsym_callback - dlopen = args.dlopen - lazy_load = args.lazy_load - if args.target.startswith("arm"): - target = "arm" # Handle armhf-..., armel-... - elif re.match(r"^i[0-9]86", args.target): - target = "i386" - elif args.target.startswith("mips64"): - target = "mips64" # Handle mips64-..., mips64el-..., mips64le-... - elif args.target.startswith("mips"): - target = "mips" # Handle mips-..., mipsel-..., mipsle-... - else: - target = args.target.split("-")[0] - quiet = args.quiet - outdir = args.outdir +""") - if args.symbol_list is None: - funs = None - else: - with open(args.symbol_list, "r") as f: - funs = [] - for line in re.split(r"\r?\n", f.read()): - line = re.sub(r"#.*", "", line) - line = line.strip() - if line: - funs.append(line) - - if args.library_load_name is not None: - load_name = args.library_load_name - else: - load_name = read_soname(input_name) - if load_name is None: - load_name = os.path.basename(input_name) + parser.add_argument('library', + metavar='LIB', + help="Library to be wrapped.") + parser.add_argument('--verbose', '-v', + help="Print diagnostic info", + action='count', + default=0) + parser.add_argument('--dlopen', + help="Emit dlopen call (default)", + dest='dlopen', action='store_true', default=True) + parser.add_argument('--no-dlopen', + help="Do not emit dlopen call (user must load/unload library himself)", + dest='dlopen', action='store_false') + parser.add_argument('--dlopen-callback', + help="Call user-provided custom callback to load library instead of dlopen", + default='') + parser.add_argument('--dlsym-callback', + help="Call user-provided custom callback to resolve a symbol, " + "instead of dlsym", + default='') + parser.add_argument('--library-load-name', + help="Use custom name for dlopened library (default is SONAME)") + parser.add_argument('--lazy-load', + help="Load library on first call to any of it's functions (default)", + dest='lazy_load', action='store_true', default=True) + parser.add_argument('--no-lazy-load', + help="Load library at program start", + dest='lazy_load', action='store_false') + parser.add_argument('--vtables', + help="Intercept virtual tables (EXPERIMENTAL)", + dest='vtables', action='store_true', default=False) + parser.add_argument('--no-vtables', + help="Do not intercept virtual tables (default)", + dest='vtables', action='store_false') + parser.add_argument('--no-weak-symbols', + help="Don't bind weak symbols", dest='no_weak_symbols', + action='store_true', default=False) + parser.add_argument('--target', + help="Target platform triple e.g. x86_64-unknown-linux-gnu or arm-none-eabi " + "(atm x86_64, i[0-9]86, arm/armhf/armeabi, aarch64/armv8, " + "mips/mipsel, mips64/mip64el and e2k are supported)", + default=os.uname()[-1]) + parser.add_argument('--symbol-list', + help="Path to file with symbols that should be present in wrapper " + "(all by default)") + parser.add_argument('--symbol-prefix', + metavar='PFX', + help="Prefix wrapper symbols with PFX", + default='') + parser.add_argument('-q', '--quiet', + help="Do not print progress info", + action='store_true') + parser.add_argument('--outdir', '-o', + help="Path to create wrapper at", + default='./') + + args = parser.parse_args() + + input_name = args.library + verbose = args.verbose + dlopen_callback = args.dlopen_callback + dlsym_callback = args.dlsym_callback + dlopen = args.dlopen + lazy_load = args.lazy_load + if args.target.startswith('arm'): + target = 'arm' # Handle armhf-..., armel-... + elif re.match(r'^i[0-9]86', args.target): + target = 'i386' + elif args.target.startswith('mips64'): + target = 'mips64' # Handle mips64-..., mips64el-..., mips64le-... + elif args.target.startswith('mips'): + target = 'mips' # Handle mips-..., mipsel-..., mipsle-... + else: + target = args.target.split('-')[0] + quiet = args.quiet + outdir = args.outdir + + if args.symbol_list is None: + funs = None + else: + with open(args.symbol_list, 'r') as f: + funs = [] + for line in re.split(r'\r?\n', f.read()): + line = re.sub(r'#.*', '', line) + line = line.strip() + if line: + funs.append(line) + + if args.library_load_name is not None: + load_name = args.library_load_name + else: + load_name = read_soname(input_name) + if load_name is None: + load_name = os.path.basename(input_name) - # Collect target info + # Collect target info - target_dir = os.path.join(root, "arch", target) + target_dir = os.path.join(root, 'arch', target) - if not os.path.exists(target_dir): - error(f"unknown architecture '{target}'") + if not os.path.exists(target_dir): + error(f"unknown architecture '{target}'") - cfg = configparser.ConfigParser(inline_comment_prefixes=";") - cfg.read(target_dir + "/config.ini") + cfg = configparser.ConfigParser(inline_comment_prefixes=';') + cfg.read(target_dir + '/config.ini') - ptr_size = int(cfg["Arch"]["PointerSize"]) - symbol_reloc_types = set(re.split(r"\s*,\s*", cfg["Arch"]["SymbolReloc"])) + ptr_size = int(cfg['Arch']['PointerSize']) + symbol_reloc_types = set(re.split(r'\s*,\s*', cfg['Arch']['SymbolReloc'])) - def is_exported(s): - conditions = [ - s["Bind"] != "LOCAL", - s["Type"] != "NOTYPE", - s["Ndx"] != "UND", - s["Name"] not in ["", "_init", "_fini"], - ] - if args.no_weak_symbols: - conditions.append(s["Bind"] != "WEAK") - return all(conditions) + def is_exported(s): + conditions = [ + s['Bind'] != 'LOCAL', + s['Type'] != 'NOTYPE', + s['Ndx'] != 'UND', + s['Name'] not in ['', '_init', '_fini']] + if args.no_weak_symbols: + conditions.append(s['Bind'] != 'WEAK') + return all(conditions) - syms = list(filter(is_exported, collect_syms(input_name))) + syms = list(filter(is_exported, collect_syms(input_name))) - def is_data_symbol(s): - return ( - s["Type"] == "OBJECT" + def is_data_symbol(s): + return (s['Type'] == 'OBJECT' # Allow vtables if --vtables is on - and not (" for " in s["Demangled Name"] and args.vtables) - ) - - exported_data = [s["Name"] for s in syms if is_data_symbol(s)] - if exported_data: - # TODO: we can generate wrappers for const data without relocations (or only code relocations) - warn( - f"library '{input_name}' contains data symbols which won't be intercepted: " - + ", ".join(exported_data) - ) - - # Collect functions - # TODO: warn if user-specified functions are missing - - orig_funs = filter(lambda s: s["Type"] == "FUNC", syms) - - all_funs = set() - warn_versioned = False - for s in orig_funs: - if not s["Default"]: - # TODO: support versions - if not warn_versioned: - warn(f"library {input_name} contains versioned symbols which are NYI") - warn_versioned = True - if verbose: - print(f"Skipping versioned symbol {s['Name']}") - continue - all_funs.add(s["Name"]) - - if funs is None: - funs = sorted(list(all_funs)) - if not funs and not quiet: - warn(f"no public functions were found in {input_name}") - else: - missing_funs = [name for name in funs if name not in all_funs] - if missing_funs: - warn( - "some user-specified functions are not present in library: " - + ", ".join(missing_funs) - ) - funs = [name for name in funs if name in all_funs] + and not (' for ' in s['Demangled Name'] and args.vtables)) + + exported_data = [s['Name'] for s in syms if is_data_symbol(s)] + if exported_data: + # TODO: we can generate wrappers for const data without relocations (or only code relocations) + warn(f"library '{input_name}' contains data symbols which won't be intercepted: " + + ', '.join(exported_data)) + + # Collect functions + # TODO: warn if user-specified functions are missing + + orig_funs = filter(lambda s: s['Type'] == 'FUNC', syms) + + all_funs = set() + warn_versioned = False + for s in orig_funs: + if not s['Default']: + # TODO: support versions + if not warn_versioned: + warn(f"library {input_name} contains versioned symbols which are NYI") + warn_versioned = True + if verbose: + print(f"Skipping versioned symbol {s['Name']}") + continue + all_funs.add(s['Name']) + + if funs is None: + funs = sorted(list(all_funs)) + if not funs and not quiet: + warn(f"no public functions were found in {input_name}") + else: + missing_funs = [name for name in funs if name not in all_funs] + if missing_funs: + warn("some user-specified functions are not present in library: " + ', '.join(missing_funs)) + funs = [name for name in funs if name in all_funs] + + if verbose: + print("Exported functions:") + for i, fun in enumerate(funs): + print(f" {i}: {fun}") + + # Collect vtables + + if args.vtables: + cls_tables = {} + cls_syms = {} + + for s in syms: + m = re.match(r'^(vtable|typeinfo|typeinfo name) for (.*)', s['Demangled Name']) + if m is not None and is_exported(s): + typ, cls = m.groups() + name = s['Name'] + cls_tables.setdefault(cls, {})[typ] = name + cls_syms[name] = s + + if verbose: + print("Exported classes:") + for cls, _ in sorted(cls_tables.items()): + print(f" {cls}") + secs = collect_sections(input_name) if verbose: - print("Exported functions:") - for i, fun in enumerate(funs): - print(f" {i}: {fun}") + print("Sections:") + for sec in secs: + print(f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " + f"at {sec['Off']:x}") - # Collect vtables + bites = read_unrelocated_data(input_name, cls_syms, secs) + rels = collect_relocs(input_name) + if verbose: + print("Relocs:") + for rel in rels: + sym_add = rel['Symbol\'s Name + Addend'] + print(f" {rel['Offset']}: {sym_add}") + + cls_data = collect_relocated_data(cls_syms, bites, rels, ptr_size, symbol_reloc_types) + if verbose: + print("Class data:") + for name, data in sorted(cls_data.items()): + demangled_name = cls_syms[name]['Demangled Name'] + print(f" {name} ({demangled_name}):") + for typ, val in data: + print(" " + str(val if typ != 'reloc' else val['Symbol\'s Name + Addend'])) + + # Generate assembly code + + suffix = os.path.basename(input_name) + lib_suffix = re.sub(r'[^a-zA-Z_0-9]+', '_', suffix) + + tramp_file = f'{suffix}.tramp.S' + with open(os.path.join(outdir, tramp_file), 'w') as f: + if not quiet: + print(f"Generating {tramp_file}...") + with open(target_dir + '/table.S.tpl', 'r') as t: + table_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, + table_size=ptr_size*(len(funs) + 1)) + f.write(table_text) + + with open(target_dir + '/trampoline.S.tpl', 'r') as t: + tramp_tpl = string.Template(t.read()) + + for i, name in enumerate(funs): + tramp_text = tramp_tpl.substitute( + lib_suffix=lib_suffix, + sym=args.symbol_prefix + name, + offset=i*ptr_size, + number=i) + f.write(tramp_text) + + # Generate C code + + init_file = f'{suffix}.init.c' + with open(os.path.join(outdir, init_file), 'w') as f: + if not quiet: + print(f"Generating {init_file}...") + with open(os.path.join(root, 'arch/common/init.c.tpl'), 'r') as t: + if funs: + sym_names = ',\n '.join(f'"{name}"' for name in funs) + ',' + else: + sym_names = '' + init_text = string.Template(t.read()).substitute( + lib_suffix=lib_suffix, + load_name=load_name, + dlopen_callback=dlopen_callback, + dlsym_callback=dlsym_callback, + has_dlopen_callback=int(bool(dlopen_callback)), + has_dlsym_callback=int(bool(dlsym_callback)), + no_dlopen=int(not dlopen), + lazy_load=int(lazy_load), + sym_names=sym_names) + f.write(init_text) if args.vtables: - cls_tables = {} - cls_syms = {} - - for s in syms: - m = re.match( - r"^(vtable|typeinfo|typeinfo name) for (.*)", s["Demangled Name"] - ) - if m is not None and is_exported(s): - typ, cls = m.groups() - name = s["Name"] - cls_tables.setdefault(cls, {})[typ] = name - cls_syms[name] = s - - if verbose: - print("Exported classes:") - for cls, _ in sorted(cls_tables.items()): - print(f" {cls}") - - secs = collect_sections(input_name) - if verbose: - print("Sections:") - for sec in secs: - print( - f" {sec['Name']}: [{sec['Address']:x}, {sec['Address'] + sec['Size']:x}), " - f"at {sec['Off']:x}" - ) - - bites = read_unrelocated_data(input_name, cls_syms, secs) - - rels = collect_relocs(input_name) - if verbose: - print("Relocs:") - for rel in rels: - sym_add = rel["Symbol's Name + Addend"] - print(f" {rel['Offset']}: {sym_add}") - - cls_data = collect_relocated_data( - cls_syms, bites, rels, ptr_size, symbol_reloc_types - ) - if verbose: - print("Class data:") - for name, data in sorted(cls_data.items()): - demangled_name = cls_syms[name]["Demangled Name"] - print(f" {name} ({demangled_name}):") - for typ, val in data: - print( - " " - + str(val if typ != "reloc" else val["Symbol's Name + Addend"]) - ) - - # Generate assembly code - - suffix = os.path.basename(input_name) - lib_suffix = re.sub(r"[^a-zA-Z_0-9]+", "_", suffix) - - tramp_file = f"{suffix}.tramp.S" - with open(os.path.join(outdir, tramp_file), "w") as f: - if not quiet: - print(f"Generating {tramp_file}...") - with open(target_dir + "/table.S.tpl", "r") as t: - table_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, table_size=ptr_size * (len(funs) + 1) - ) - f.write(table_text) - - with open(target_dir + "/trampoline.S.tpl", "r") as t: - tramp_tpl = string.Template(t.read()) - - for i, name in enumerate(funs): - tramp_text = tramp_tpl.substitute( - lib_suffix=lib_suffix, - sym=args.symbol_prefix + name, - offset=i * ptr_size, - number=i, - ) - f.write(tramp_text) - - # Generate C code - - init_file = f"{suffix}.init.c" - with open(os.path.join(outdir, init_file), "w") as f: - if not quiet: - print(f"Generating {init_file}...") - with open(os.path.join(root, "arch/common/init.c.tpl"), "r") as t: - if funs: - sym_names = ",\n ".join(f'"{name}"' for name in funs) + "," - else: - sym_names = "" - init_text = string.Template(t.read()).substitute( - lib_suffix=lib_suffix, - load_name=load_name, - dlopen_callback=dlopen_callback, - dlsym_callback=dlsym_callback, - has_dlopen_callback=int(bool(dlopen_callback)), - has_dlsym_callback=int(bool(dlsym_callback)), - no_dlopen=int(not dlopen), - lazy_load=int(lazy_load), - sym_names=sym_names, - ) - f.write(init_text) - if args.vtables: - vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) - f.write(vtable_text) - - -if __name__ == "__main__": - main() + vtable_text = generate_vtables(cls_tables, cls_syms, cls_data) + f.write(vtable_text) + +if __name__ == '__main__': + main()