diff --git a/ci_scripts/check_api_parameters.py b/ci_scripts/check_api_parameters.py index 2e6d8b18e2e..fb36b54fdb8 100644 --- a/ci_scripts/check_api_parameters.py +++ b/ci_scripts/check_api_parameters.py @@ -20,6 +20,8 @@ import re import sys +import paddle # noqa: F401 + def add_path(path): if path not in sys.path: @@ -30,6 +32,7 @@ def add_path(path): # Add docs/api to PYTHONPATH add_path(osp.abspath(osp.join(this_dir, "..", "docs", "api"))) from extract_api_from_docs import extract_params_desc_from_rst_file +from gen_doc import gen_functions_args_str arguments = [ # flags, dest, type, default, help @@ -60,37 +63,13 @@ def _check_params_in_description(rstfilename, paramstr): params_in_title = [] if paramstr: fake_func = ast.parse(f"def fake_func({paramstr}): pass") - # Iterate over all in_title parameters - num_defaults = len(fake_func.body[0].args.defaults) - num_args = len(fake_func.body[0].args.args) - # args & defaults - for i, arg in enumerate(fake_func.body[0].args.args): - if i >= num_args - num_defaults: - default_value = fake_func.body[0].args.defaults[ - i - (num_args - num_defaults) - ] - params_in_title.append(f"{arg.arg}={default_value}") - else: - params_in_title.append(arg.arg) - # posonlyargs - for arg in fake_func.body[0].args.posonlyargs: - params_in_title.append(arg.arg) - # vararg(*args) - if fake_func.body[0].args.vararg: - params_in_title.append(fake_func.body[0].args.vararg.arg) - # kwonlyargs & kw_defaults - for i, arg in enumerate(fake_func.body[0].args.kwonlyargs): - if ( - i < len(fake_func.body[0].args.kw_defaults) - and fake_func.body[0].args.kw_defaults[i] is not None - ): - default_value = fake_func.body[0].args.kw_defaults[i] - params_in_title.append(f"{arg.arg}={default_value}") - else: - params_in_title.append(arg.arg) - # **kwargs - if fake_func.body[0].args.kwarg: - params_in_title.append(fake_func.body[0].args.kwarg.arg) + func_node = fake_func.body[0] + func_args_str = gen_functions_args_str(func_node) + params_in_title = 
func_args_str.split(", ") + if "/" in params_in_title: + params_in_title.remove("/") + if "*" in params_in_title: + params_in_title.remove("*") funcdescnode = extract_params_desc_from_rst_file(rstfilename) if funcdescnode: @@ -107,7 +86,9 @@ def _check_params_in_description(rstfilename, paramstr): ) else: info = f"The number of params in title does not match the params in description: {len(params_in_title)} != {len(items)}." - print(f"check failed (parammeters description): {rstfilename}") + print( + f"check failed with different nums (parammeters description): {rstfilename}" + ) else: for i in range(len(items)): pname_in_title = params_in_title[i].split("=")[0].strip() @@ -141,11 +122,25 @@ def _check_params_in_description(rstfilename, paramstr): def _check_params_in_description_with_fullargspec(rstfilename, funcname): flag = True info = "" - funcspec = inspect.getfullargspec(eval(funcname)) + try: + func = eval(funcname) + except AttributeError: + flag = False + info = f"function {funcname} in rst file {rstfilename} not found in paddle module, please check it." + return flag, info + source = inspect.getsource(func) + + tree = ast.parse(source) + func_node = tree.body[0] + params_inspec = gen_functions_args_str(func_node).split(", ") + if "/" in params_inspec: + params_inspec.remove("/") + if "*" in params_inspec: + params_inspec.remove("*") + funcdescnode = extract_params_desc_from_rst_file(rstfilename) if funcdescnode: items = funcdescnode.children[1].children[0].children - params_inspec = funcspec.args if len(items) != len(params_inspec): flag = False info = f"check_with_fullargspec failed (parammeters description): {rstfilename}" @@ -171,10 +166,10 @@ def _check_params_in_description_with_fullargspec(rstfilename, funcname): f"check failed (parammeters description): {rstfilename}, param name not found in {i} paragraph." ) else: - if funcspec.args: + if params_inspec: info = "params section not found in description, check it please." 
print( - f"check failed (parameters description not found): {rstfilename}, {funcspec.args}." + f"check failed (parameters description not found): {rstfilename}, {params_inspec}." ) flag = False return flag, info @@ -200,16 +195,45 @@ def check_api_parameters(rstfiles, apiinfo): print(f"checking : {rstfile}") with open(rstfilename, "r") as rst_fobj: func_found = False + is_first_line = True + api_label = None for line in rst_fobj: + if is_first_line: + api_label = ( + line.strip() + .removeprefix(".. _cn_api_") + .removesuffix(":") + .removesuffix("__upper") + ) + is_first_line = False mo = pat.match(line) if mo: func_found = True functype = mo.group(1) if functype not in ("function", "method"): + # TODO: check class method check_passed.append(rstfile) continue funcname = mo.group(2) paramstr = mo.group(3) + func_to_label = funcname.replace(".", "_") + + # check same as the api_label + if func_to_label != api_label: + # if funcname is a function, try to back to class + try: + obj = eval(funcname) + except AttributeError: + obj = None + if obj is not None and inspect.isfunction(obj): + class_name = ".".join(funcname.split(".")[:-1]) + class_to_label = class_name.replace(".", "_") + if class_to_label != api_label: + flag = False + info = f"funcname in title is not same as the label name: {funcname} != {api_label}." + check_failed[rstfile] = info + continue + flag = False func_found_in_json = False for apiobj in apiinfo.values(): diff --git a/ci_scripts/check_api_parameters.sh b/ci_scripts/check_api_parameters.sh index 8cb6c75d1bc..f1567526c08 100644 --- a/ci_scripts/check_api_parameters.sh +++ b/ci_scripts/check_api_parameters.sh @@ -35,7 +35,7 @@ if [ "$need_check_files" = "" ] then echo "need check files is empty, skip api parameters check" else - python check_api_parameters.py --rst-files "${need_check_files}" --api-info $2 + python check_api_parameters_patch.py --rst-files "${need_check_files}" --api-info $2 if [ $? 
-ne 0 ];then set +x echo "************************************************************************************" diff --git a/ci_scripts/check_api_parameters_patch.py b/ci_scripts/check_api_parameters_patch.py new file mode 100644 index 00000000000..4425ec76d6d --- /dev/null +++ b/ci_scripts/check_api_parameters_patch.py @@ -0,0 +1,336 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import argparse +import ast +import inspect +import json +import logging +import re +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import paddle # noqa: F401 + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format="%(levelname)s: %(message)s", +) +logger = logging.getLogger(__name__) + + +# Add Project Paths +THIS_DIR = Path(__file__).resolve().parent +API_DOC_TOOLS_PATH = THIS_DIR.parent / "docs" / "api" + + +def add_path(path: str): + if path not in sys.path: + sys.path.insert(0, path) + + +add_path(str(API_DOC_TOOLS_PATH)) + +from extract_api_from_docs import extract_params_desc_from_rst_file +from gen_doc import gen_functions_args_str + +# Custom Exception + + +class APINotFoundError(Exception): + pass + + +class APICheckError(Exception): + pass + + +@dataclass +class CheckResults: + passed: list[str] = field(default_factory=list) + failed: dict[str] = field(default_factory=dict) + not_found: dict[str] = 
field(default_factory=dict) + + +class ParamChecker: + def __init__(self, api_info: dict[str, Any]): + self.api_info = api_info + self.api_info_by_name = {} + for apiobj in api_info.values(): + if "all_names" in apiobj: + for name in apiobj["all_names"]: + self.api_info_by_name[name] = apiobj + + def check_files(self, rst_files: list[Path]) -> CheckResults: + results = CheckResults() + for rst_file in rst_files: + logger.info(f"Checking: {rst_file}") + try: + self.check_file(rst_file) + results.passed.append(str(rst_file)) + except APINotFoundError as e: + results.not_found[str(rst_file)] = str(e) + logger.warning(f"API not found in {rst_file} - {e}") + except APICheckError as e: + results.failed[str(rst_file)] = str(e) + logger.error(f"API check failed in {rst_file} - {e}") + except Exception as e: + results.failed[str(rst_file)] = str(e) + logger.error(f"Unexpected error in {rst_file} - {e}") + return results + + def check_file(self, rst_file: Path): + pat = re.compile( + r"^\.\.\s+py:(method|function|class)::\s+([^\s(]+)\s*(?:\(\s*(.*)\s*\))?\s*$" + ) + with open(rst_file, "r", encoding="utf-8") as f: + lines = f.readlines() + func_found = False + api_label = None + + for idx, line in enumerate(lines): + line_strip = line.strip() + if idx == 0: + api_label = ( + line_strip.removeprefix(".. 
_cn_api_") + .removesuffix(":") + .removesuffix("__upper") + ) + mo = pat.match(line_strip) + if mo: + func_found = True + functype, funcname, paramstr = mo.groups() + if functype not in ("function", "method"): + return # + func_to_label = funcname.replace(".", "_") + if api_label and func_to_label != api_label: + class_name = ".".join(funcname.split(".")[:-1]) + class_to_label = class_name.replace(".", "_") + if class_to_label != api_label: + raise APICheckError( + f"Function name in title does not match the api label name: {funcname} != {api_label}" + ) + apiobj = self.api_info_by_name.get(funcname) + if apiobj and "args" in apiobj: + if paramstr == apiobj["args"]: + self._check_params_in_description(rst_file, paramstr) + else: + logger.warning( + f"Parameter string mismatch for {funcname}: RST='{paramstr}', JSON='{apiobj['args']}'" + ) + self._check_params_in_description(rst_file, paramstr) + else: + self._check_params_in_description_with_fullargspec( + rst_file, funcname + ) + return # + if not func_found: + raise APINotFoundError( + "Function name in title not found, please check the format of '.. py:function::func()'" + ) + + def _check_params_in_description( + self, rst_file: Path, paramstr: str | None + ): + params_in_title = [] + if paramstr: + try: + fake_func = ast.parse(f"def fake_func({paramstr}): pass") + func_node = fake_func.body[0] + func_args_str: list[str] = gen_functions_args_str( + func_node, False, False + ) + params_in_title = [ + p.split("=")[0].strip() + for p in func_args_str + if p not in ("/", "*") + ] + params_in_title = [ + p.removeprefix("*").removeprefix("*") + for p in params_in_title + ] + except Exception as e: + raise APICheckError(f"Failed to parse parameters: {e}") + funcdescnode = extract_params_desc_from_rst_file(str(rst_file), True) + if funcdescnode: + try: + items = funcdescnode.children[1].children[0].children + except Exception: + raise APICheckError( + "Params section format error in description." 
+ ) + if not items: + if params_in_title: + raise APICheckError( + "Params section in description is empty, check it please." + ) + elif len(items) != len(params_in_title): + raise APICheckError( + f"The number of params in title does not match the params in description: {len(params_in_title)} != {len(items)}." + ) + else: + for i, item in enumerate(items): + pname_in_title = params_in_title[i] + mo = re.match( + r"\*{0,2}(\w+)\b.*", item.children[0].astext() + ) + if mo: + pname_indesc = mo.group(1) + if pname_indesc != pname_in_title: + raise APICheckError( + f"Param mismatch: {pname_in_title} != {pname_indesc}." + ) + else: + raise APICheckError( + f"Param name '{pname_in_title}' not matched in description line {i + 1}, check it please." + ) + elif params_in_title: + raise APICheckError( + "Params section not found in description, check it please." + ) + + def _check_params_in_description_with_fullargspec( + self, rst_file: Path, funcname: str + ): + try: + obj = self._import_object(funcname) + except Exception: + raise APICheckError( + f"Function {funcname} not found in paddle module, please check it." + ) + try: + source = inspect.getsource(obj) + tree = ast.parse(source) + func_node = tree.body[0] + params_inspec = [ + p.split("=")[0].strip() + for p in gen_functions_args_str(func_node, True, False) + if p not in ("/", "*") + ] + # for *args and **kwargs, remove * and ** + params_inspec = [ + p.removeprefix("*").removeprefix("*") for p in params_inspec + ] + + except Exception as e: + raise APICheckError(f"Failed to inspect function {funcname}: {e}") + funcdescnode = extract_params_desc_from_rst_file(str(rst_file), True) + if funcdescnode: + try: + items = funcdescnode.children[1].children[0].children + except Exception: + raise APICheckError( + "Params section format error in description." + ) + if len(items) != len(params_inspec): + raise APICheckError( + f"Param count mismatch: {len(params_inspec)} != {len(items)}." 
+ ) + else: + for i, item in enumerate(items): + pname_in_title = params_inspec[i] + mo = re.match( + r"\*{0,2}(\w+)\b.*", item.children[0].astext() + ) + if mo: + pname_indesc = mo.group(1) + if pname_indesc != pname_in_title: + raise APICheckError( + f"Param mismatch: {pname_in_title} != {pname_indesc}." + ) + else: + raise APICheckError( + f"Param name '{pname_in_title}' not matched in description line {i + 1}." + ) + else: + if params_inspec: + raise APICheckError( + "Params section not found in description, check it please." + ) + + def _import_object(self, dotted_path: str): + import importlib + + parts = dotted_path.split(".") + module = importlib.import_module(parts[0]) + obj = module + for part in parts[1:]: + obj = getattr(obj, part) + return obj + + +def parse_args(): + parser = argparse.ArgumentParser(description="check api parameters") + parser.add_argument( + "--rst-files", + dest="rst_files", + required=True, + help="api rst files, separated by space", + type=str, + ) + parser.add_argument( + "--api-info", + dest="api_info_file", + required=True, + help="api_info_all.json filename", + type=str, + ) + parser.add_argument("--debug", dest="debug", action="store_true") + return parser.parse_args() + + +def main(): + args = parse_args() + + if args.debug: + logger.setLevel(logging.DEBUG) + try: + with open(args.api_info_file, "r", encoding="utf-8") as f: + api_info = json.load(f) + except Exception as e: + logger.error(f"Failed to load API info file: {e}") + sys.exit(1) + rst_files = [fn for fn in args.rst_files.split(" ") if fn] + if not rst_files: + logger.error("No RST files provided.") + sys.exit(1) + + checker = ParamChecker(api_info) + results = checker.check_files(rst_files) + + logger.warning( + f"API parameter checking completed. 
Pass: {len(results.passed)}, Fail: {len(results.failed)}, Not Found: {len(results.not_found)}" + ) + + if results.failed: + logger.warning("Following files failed the check:") + for file_path, error in results.failed.items(): + logger.error(f" - {file_path}: {error}") + if results.not_found: + logger.warning("Following files had API not found:") + for file_path, error in results.not_found.items(): + logger.error(f" - {file_path}: {error}") + if results.failed or results.not_found: + sys.exit(1) + else: + logger.info("All API parameter checks passed.") + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/ci_scripts/ci_start.sh b/ci_scripts/ci_start.sh index 839efc0009a..2f369191cf3 100644 --- a/ci_scripts/ci_start.sh +++ b/ci_scripts/ci_start.sh @@ -83,6 +83,17 @@ if [ "${BUILD_DOC}" = "true" ] && [ -x /usr/local/bin/sphinx-build ] ; then fi fi +git merge --no-edit upstream/${BRANCH} +need_check_cn_doc_files=$(find_all_cn_api_files_modified_by_pr) +echo $need_check_cn_doc_files + +# Check for existing stock issues. +find_cn_rst_files() { + local search_dir="$SCRIPT_DIR/../docs/api/paddle" + find "$search_dir" -type f -name "*_cn.rst" +} +all_api_cn_files=$(find_cn_rst_files) + check_parameters=ON if [ "${check_parameters}" = "OFF" ] ; then #echo "chinese api doc fileslist is empty, skip check." @@ -91,7 +102,7 @@ else jsonfn=${OUTPUTDIR}/en/${VERSIONSTR}/gen_doc_output/api_info_all.json if [ -f $jsonfn ] ; then echo "$jsonfn exists." - /bin/bash ${DIR_PATH}/check_api_parameters.sh "${need_check_cn_doc_files}" ${jsonfn} + /bin/bash ${DIR_PATH}/check_api_parameters.sh "${all_api_cn_files}" ${jsonfn} if [ $? -ne 0 ];then exit 1 fi @@ -109,9 +120,6 @@ if [ $? -ne 0 ];then EXIT_CODE=1 fi -git merge --no-edit upstream/${BRANCH} -need_check_cn_doc_files=$(find_all_cn_api_files_modified_by_pr) -echo $need_check_cn_doc_files # 4 Chinese api docs check if [ "${need_check_cn_doc_files}" = "" ] ; then echo "chinese api doc fileslist is empty, skip check." 
def gen_functions_args_str(node, skip_self=False, return_str=True):
    """
    Render the parameter list of a function definition node as source text.

    Parameters are emitted in signature order: positional-only parameters
    followed by ``/``, regular parameters, ``*name`` (or a bare ``*`` when
    there are keyword-only parameters but no var-positional one),
    keyword-only parameters and finally ``**name``. Defaults are rendered
    with ``ast.unparse`` as ``name=default``.

    Args:
        node: an ``ast.FunctionDef`` or ``ast.AsyncFunctionDef``. Any other
            node produces an empty result.
        skip_self: when True, positional parameters named ``self`` are
            omitted.
        return_str: when True the parameters are joined with ``", "``;
            otherwise the list of parameter strings is returned.

    Returns:
        str or list[str], depending on ``return_str``.
    """
    str_args_list = []
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        func_args = node.args
        positional = func_args.posonlyargs + func_args.args
        num_posonly = len(func_args.posonlyargs)
        # ``defaults`` belongs to the LAST len(defaults) positional params.
        first_default_pos = len(positional) - len(func_args.defaults)

        for idx, arg in enumerate(positional):
            # Close the positional-only group, but only if it produced at
            # least one parameter (it may hold nothing but a skipped self).
            if idx == num_posonly and str_args_list:
                str_args_list.append("/")
            if skip_self and arg.arg == "self":
                continue
            if idx >= first_default_pos:
                default = func_args.defaults[idx - first_default_pos]
                default_str = ast.unparse(default).strip()
                str_args_list.append(f"{arg.arg}={default_str}")
            else:
                str_args_list.append(arg.arg)
        # Every positional parameter was positional-only: "/" comes last.
        if positional and num_posonly == len(positional) and str_args_list:
            str_args_list.append("/")

        if func_args.vararg:
            str_args_list.append(f"*{func_args.vararg.arg}")
        elif func_args.kwonlyargs:
            str_args_list.append("*")

        # ``kw_defaults`` is parallel to ``kwonlyargs``; None = no default.
        for arg, default in zip(func_args.kwonlyargs, func_args.kw_defaults):
            if default is not None:
                default_str = ast.unparse(default).strip()
                str_args_list.append(f"{arg.arg}={default_str}")
            else:
                str_args_list.append(arg.arg)

        if func_args.kwarg:
            str_args_list.append(f"**{func_args.kwarg.arg}")

    if return_str:
        return ", ".join(str_args_list)
    else:
        return str_args_list