From 7aff219acdffa12f65bfde66f3f577a20c76587b Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Sat, 12 Apr 2025 22:57:23 +0800 Subject: [PATCH 01/14] refine_func_args_parse --- ci_scripts/check_api_parameters.py | 73 +++++++++++++------------- ci_scripts/ci_start.sh | 16 ++++-- docs/api/gen_doc.py | 83 ++++++++++++++++++++---------- 3 files changed, 104 insertions(+), 68 deletions(-) diff --git a/ci_scripts/check_api_parameters.py b/ci_scripts/check_api_parameters.py index 2e6d8b18e2e..1839e7252cd 100644 --- a/ci_scripts/check_api_parameters.py +++ b/ci_scripts/check_api_parameters.py @@ -29,7 +29,10 @@ def add_path(path): this_dir = osp.dirname(__file__) # Add docs/api to PYTHONPATH add_path(osp.abspath(osp.join(this_dir, "..", "docs", "api"))) -from extract_api_from_docs import extract_params_desc_from_rst_file +from extract_api_from_docs import ( + extract_params_desc_from_rst_file, + gen_functions_args_str, +) arguments = [ # flags, dest, type, default, help @@ -60,37 +63,11 @@ def _check_params_in_description(rstfilename, paramstr): params_in_title = [] if paramstr: fake_func = ast.parse(f"def fake_func({paramstr}): pass") - # Iterate over all in_title parameters - num_defaults = len(fake_func.body[0].args.defaults) - num_args = len(fake_func.body[0].args.args) - # args & defaults - for i, arg in enumerate(fake_func.body[0].args.args): - if i >= num_args - num_defaults: - default_value = fake_func.body[0].args.defaults[ - i - (num_args - num_defaults) - ] - params_in_title.append(f"{arg.arg}={default_value}") - else: - params_in_title.append(arg.arg) - # posonlyargs - for arg in fake_func.body[0].args.posonlyargs: - params_in_title.append(arg.arg) - # vararg(*args) - if fake_func.body[0].args.vararg: - params_in_title.append(fake_func.body[0].args.vararg.arg) - # kwonlyargs & kw_defaults - for i, arg in enumerate(fake_func.body[0].args.kwonlyargs): - if ( - i < len(fake_func.body[0].args.kw_defaults) - and fake_func.body[0].args.kw_defaults[i] is not None - ): - default_value = fake_func.body[0].args.kw_defaults[i] - params_in_title.append(f"{arg.arg}={default_value}") - else: - params_in_title.append(arg.arg) - # **kwargs - if fake_func.body[0].args.kwarg: - params_in_title.append(fake_func.body[0].args.kwarg.arg) + func_node = fake_func.body[0] + func_args_str = gen_functions_args_str(func_node) + params_in_title = func_args_str.split(", ") + params_in_title.remove("/") + params_in_title.remove("*") funcdescnode = extract_params_desc_from_rst_file(rstfilename) if funcdescnode: @@ -141,11 +118,35 @@ def _check_params_in_description(rstfilename, paramstr): def _check_params_in_description_with_fullargspec(rstfilename, funcname): flag = True info = "" - funcspec = inspect.getfullargspec(eval(funcname)) + try: + func = eval(funcname) + except NameError: + func = eval(funcname) + source = inspect.getsource(func) + + class FunctionDefExtractor(ast.NodeTransformer): + target_name = func.__name__ + + def visit_FunctionDef(self, node): + if node.name == self.target_name: + node.decorator_list = [] + node.body = [ast.Pass()] + return node + return None + + tree = ast.parse(source) + modified_tree = FunctionDefExtractor().visit(tree) + modified_tree.body = [ + node for node in modified_tree.body if node is not None + ] + + func_node = modified_tree.body[0] + params_inspec = gen_functions_args_str(func_node).split(", ") + params_inspec.remove("/") + params_inspec.remove("*") funcdescnode = extract_params_desc_from_rst_file(rstfilename) if funcdescnode: items = funcdescnode.children[1].children[0].children - params_inspec = funcspec.args if len(items) != len(params_inspec): flag = False info = f"check_with_fullargspec failed (parammeters description): {rstfilename}" @@ -171,10 +172,10 @@ def _check_params_in_description_with_fullargspec(rstfilename, funcname): f"check failed (parammeters description): {rstfilename}, param name not found in {i} paragraph." ) else: - if funcspec.args: + if params_inspec: info = "params section not found in description, check it please." print( - f"check failed (parameters description not found): {rstfilename}, {funcspec.args}." + f"check failed (parameters description not found): {rstfilename}, {params_inspec}." ) flag = False return flag, info diff --git a/ci_scripts/ci_start.sh b/ci_scripts/ci_start.sh index 839efc0009a..2f369191cf3 100644 --- a/ci_scripts/ci_start.sh +++ b/ci_scripts/ci_start.sh @@ -83,6 +83,17 @@ if [ "${BUILD_DOC}" = "true" ] && [ -x /usr/local/bin/sphinx-build ] ; then fi fi +git merge --no-edit upstream/${BRANCH} +need_check_cn_doc_files=$(find_all_cn_api_files_modified_by_pr) +echo $need_check_cn_doc_files + +# Check for existing stock issues. +find_cn_rst_files() { + local search_dir="$SCRIPT_DIR/../docs/api/paddle" + find "$search_dir" -type f -name "*_cn.rst" +} +all_api_cn_files=$(find_cn_rst_files) + check_parameters=ON if [ "${check_parameters}" = "OFF" ] ; then #echo "chinese api doc fileslist is empty, skip check." @@ -91,7 +102,7 @@ else jsonfn=${OUTPUTDIR}/en/${VERSIONSTR}/gen_doc_output/api_info_all.json if [ -f $jsonfn ] ; then echo "$jsonfn exists." - /bin/bash ${DIR_PATH}/check_api_parameters.sh "${need_check_cn_doc_files}" ${jsonfn} + /bin/bash ${DIR_PATH}/check_api_parameters.sh "${all_api_cn_files}" ${jsonfn} if [ $? -ne 0 ];then exit 1 fi @@ -109,9 +120,6 @@ if [ $? -ne 0 ];then EXIT_CODE=1 fi -git merge --no-edit upstream/${BRANCH} -need_check_cn_doc_files=$(find_all_cn_api_files_modified_by_pr) -echo $need_check_cn_doc_files # 4 Chinese api docs check if [ "${need_check_cn_doc_files}" = "" ] ; then echo "chinese api doc fileslist is empty, skip check." diff --git a/docs/api/gen_doc.py b/docs/api/gen_doc.py index 566cde3bcd9..85c53b39173 100755 --- a/docs/api/gen_doc.py +++ b/docs/api/gen_doc.py @@ -362,36 +362,63 @@ def parse_module_file(mod): def gen_functions_args_str(node): + def _process_positional_args(args, params): + positional_args = args.posonlyargs + args.args + num_defaults = len(args.defaults) + + total_positional = len(positional_args) + first_default_pos = total_positional - num_defaults + if args.posonlyargs: + for idx, arg in enumerate(args.posonlyargs): + if arg.arg == "self": + continue + param = _format_arg_with_default( + arg, idx, first_default_pos, args.defaults + ) + params.append(param) + params.append("/") + + for idx, arg in enumerate(args.args): + if arg.arg == "self": + continue + global_idx = idx + len(args.posonlyargs) + param = _format_arg_with_default( + arg, global_idx, first_default_pos, args.defaults + ) + params.append(param) + + def _format_arg_with_default(arg, index, first_default_pos, defaults): + if first_default_pos is not None and index >= first_default_pos: + default_index = index - first_default_pos + defarg_value = ast.unparse(defaults[default_index]).strip() + return f"{arg.arg}={defarg_value}" + return arg.arg + + def _process_var_args(args, params): + if args.vararg: + params.append(f"*{args.vararg.arg}") + elif args.kwonlyargs: + params.append("*") + + def _process_kwonly_args(args, params): + for idx, arg in enumerate(args.kwonlyargs): + if default := args.kw_defaults[idx]: + default_str = ast.unparse(default).strip() + params.append(f"{arg.arg}={default_str}") + else: + params.append(arg.arg) + + def _process_kwargs(args, params): + if args.kwarg: + params.append(f"**{args.kwarg.arg}") + str_args_list = [] if isinstance(node, ast.FunctionDef): - # 'args', 'defaults', 'kw_defaults', 'kwarg', 'kwonlyargs', 'posonlyargs', 'vararg' - for arg in node.args.args: - if not arg.arg == "self": - str_args_list.append(arg.arg) - - defarg_ind_start = len(str_args_list) - len(node.args.defaults) - for defarg_ind in range(len(node.args.defaults)): - if isinstance(node.args.defaults[defarg_ind], ast.Name): - str_args_list[defarg_ind_start + defarg_ind] += "=" + str( - node.args.defaults[defarg_ind].id - ) - elif isinstance(node.args.defaults[defarg_ind], ast.Constant): - defarg_val = str(node.args.defaults[defarg_ind].value) - if isinstance(node.args.defaults[defarg_ind].value, str): - defarg_val = f"'{defarg_val}'" - str_args_list[defarg_ind_start + defarg_ind] += "=" + defarg_val - if node.args.vararg is not None: - str_args_list.append("*" + node.args.vararg.arg) - if len(node.args.kwonlyargs) > 0: - if node.args.vararg is None: - str_args_list.append("*") - for kwoarg, d in zip(node.args.kwonlyargs, node.args.kw_defaults): - if isinstance(d, ast.Constant): - str_args_list.append(f"{kwoarg.arg}={d.value}") - elif isinstance(d, ast.Name): - str_args_list.append(f"{kwoarg.arg}={d.id}") - if node.args.kwarg is not None: - str_args_list.append("**" + node.args.kwarg.arg) + func_args = node.args + _process_positional_args(func_args, str_args_list) + _process_var_args(func_args, str_args_list) + _process_kwonly_args(func_args, str_args_list) + _process_kwargs(func_args, str_args_list) return ", ".join(str_args_list) From 757b5409e4234aed88e00c45a284ac3838218886 Mon Sep 17 00:00:00 2001 From: ooo oo <106524776+ooooo-create@users.noreply.github.com> Date: Sat, 12 Apr 2025 23:00:18 +0800 Subject: [PATCH 02/14] Update ci_scripts/check_api_parameters.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ci_scripts/check_api_parameters.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ci_scripts/check_api_parameters.py b/ci_scripts/check_api_parameters.py index 1839e7252cd..9e63c2e2502 100644 --- a/ci_scripts/check_api_parameters.py +++ b/ci_scripts/check_api_parameters.py @@ -66,8 +66,10 @@ def _check_params_in_description(rstfilename, paramstr): func_node = fake_func.body[0] func_args_str = gen_functions_args_str(func_node) params_in_title = func_args_str.split(", ") - params_in_title.remove("/") - params_in_title.remove("*") + if "/" in params_in_title: + params_in_title.remove("/") + if "*" in params_in_title: + params_in_title.remove("*") funcdescnode = extract_params_desc_from_rst_file(rstfilename) if funcdescnode: From 1e942b0431a90324d08f9f7ba8df8ccd65ba9f9b Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Sat, 12 Apr 2025 23:01:47 +0800 Subject: [PATCH 03/14] fix --- ci_scripts/check_api_parameters.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ci_scripts/check_api_parameters.py b/ci_scripts/check_api_parameters.py index 9e63c2e2502..102f398f2ea 100644 --- a/ci_scripts/check_api_parameters.py +++ b/ci_scripts/check_api_parameters.py @@ -123,6 +123,8 @@ def _check_params_in_description_with_fullargspec(rstfilename, funcname): try: func = eval(funcname) except NameError: + import paddle # noqa: F401 + func = eval(funcname) source = inspect.getsource(func) From 21b0be2c0f75a442dfc52895833ca6a07eddc342 Mon Sep 17 00:00:00 2001 From: ooo oo <106524776+ooooo-create@users.noreply.github.com> Date: Sat, 12 Apr 2025 23:03:35 +0800 Subject: [PATCH 04/14] Update ci_scripts/check_api_parameters.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ci_scripts/check_api_parameters.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ci_scripts/check_api_parameters.py b/ci_scripts/check_api_parameters.py index 102f398f2ea..3321a851588 100644 --- a/ci_scripts/check_api_parameters.py +++ b/ci_scripts/check_api_parameters.py @@ -146,8 +146,10 @@ def visit_FunctionDef(self, node): func_node = modified_tree.body[0] params_inspec = gen_functions_args_str(func_node).split(", ") - params_inspec.remove("/") - params_inspec.remove("*") + if "/" in params_inspec: + params_inspec.remove("/") + if "*" in params_inspec: + params_inspec.remove("*") funcdescnode = extract_params_desc_from_rst_file(rstfilename) if funcdescnode: items = funcdescnode.children[1].children[0].children From e7dfab616e64aa7a8d6c136d0ccdb7b87b58796e Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Mon, 14 Apr 2025 11:06:40 +0800 Subject: [PATCH 05/14] refine --- ci_scripts/check_api_parameters.py | 52 ++++++++++++++++++------------ docs/api/gen_doc.py | 10 +++--- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/ci_scripts/check_api_parameters.py b/ci_scripts/check_api_parameters.py index 3321a851588..0cef64559a7 100644 --- a/ci_scripts/check_api_parameters.py +++ b/ci_scripts/check_api_parameters.py @@ -20,6 +20,8 @@ import re import sys +import paddle # noqa: F401 + def add_path(path): if path not in sys.path: @@ -70,6 +72,7 @@ def _check_params_in_description(rstfilename, paramstr): params_in_title.remove("/") if "*" in params_in_title: params_in_title.remove("*") + params_in_title = ", ".join(params_in_title) funcdescnode = extract_params_desc_from_rst_file(rstfilename) if funcdescnode: @@ -122,34 +125,20 @@ def _check_params_in_description_with_fullargspec(rstfilename, funcname): info = "" try: func = eval(funcname) - except NameError: - import paddle # noqa: F401 - - func = eval(funcname) + except AttributeError: + flag = False + info = f"function {funcname} in rst file {rstfilename} not found in paddle module, please check it." + return flag, info source = inspect.getsource(func) - class FunctionDefExtractor(ast.NodeTransformer): - target_name = func.__name__ - - def visit_FunctionDef(self, node): - if node.name == self.target_name: - node.decorator_list = [] - node.body = [ast.Pass()] - return node - return None - tree = ast.parse(source) - modified_tree = FunctionDefExtractor().visit(tree) - modified_tree.body = [ - node for node in modified_tree.body if node is not None - ] - - func_node = modified_tree.body[0] + func_node = tree.body[0] params_inspec = gen_functions_args_str(func_node).split(", ") if "/" in params_inspec: params_inspec.remove("/") if "*" in params_inspec: params_inspec.remove("*") + params_inspec = ", ".join(params_inspec) funcdescnode = extract_params_desc_from_rst_file(rstfilename) if funcdescnode: items = funcdescnode.children[1].children[0].children @@ -207,16 +196,39 @@ def check_api_parameters(rstfiles, apiinfo): print(f"checking : {rstfile}") with open(rstfilename, "r") as rst_fobj: func_found = False + is_first_line = True + api_label = None for line in rst_fobj: + if is_first_line: + api_label = ( + line.strip() + .removeprefix(".. _cn_api_") + .replace("_", ".") + .removesuffix("__upper") + ) + is_first_line = False mo = pat.match(line) if mo: func_found = True functype = mo.group(1) if functype not in ("function", "method"): + # TODO: check class method check_passed.append(rstfile) continue funcname = mo.group(2) paramstr = mo.group(3) + + # check same as the api_label + if funcname != api_label: + # if funcname is a function, try to back to class + obj = eval(funcname) + if inspect.isfunction(obj): + class_name = ".".join(funcname.split(".")[:-1]) + if class_name != api_label: + flag = False + info = f"funcname in title is not same as the label name: {funcname} != {api_label}." + return flag, info + flag = False func_found_in_json = False for apiobj in apiinfo.values(): diff --git a/docs/api/gen_doc.py b/docs/api/gen_doc.py index 85c53b39173..b1aee45f9ef 100755 --- a/docs/api/gen_doc.py +++ b/docs/api/gen_doc.py @@ -328,7 +328,9 @@ def parse_module_file(mod): and n.name == "__init__" ): api_info_dict[obj_id]["args"] = ( - gen_functions_args_str(n) + gen_functions_args_str( + n, skip_self=True + ) ) break else: @@ -361,7 +363,7 @@ def parse_module_file(mod): logger.debug("%s omitted", obj_full_name) -def gen_functions_args_str(node): +def gen_functions_args_str(node, skip_self=False): def _process_positional_args(args, params): positional_args = args.posonlyargs + args.args num_defaults = len(args.defaults) @@ -370,7 +372,7 @@ def _process_positional_args(args, params): first_default_pos = total_positional - num_defaults if args.posonlyargs: for idx, arg in enumerate(args.posonlyargs): - if arg.arg == "self": + if skip_self and arg.arg == "self": continue param = _format_arg_with_default( arg, idx, first_default_pos, args.defaults @@ -379,7 +381,7 @@ def _process_positional_args(args, params): params.append("/") for idx, arg in enumerate(args.args): - if arg.arg == "self": + if skip_self and arg.arg == "self": continue global_idx = idx + len(args.posonlyargs) param = _format_arg_with_default( From 73401121ffe52e37b71c77c5d31462052de5ea4a Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Mon, 14 Apr 2025 19:31:57 +0800 Subject: [PATCH 06/14] test --- ci_scripts/check_api_parameters.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci_scripts/check_api_parameters.py b/ci_scripts/check_api_parameters.py index 0cef64559a7..542ab62f201 100644 --- a/ci_scripts/check_api_parameters.py +++ b/ci_scripts/check_api_parameters.py @@ -227,7 +227,8 @@ def check_api_parameters(rstfiles, apiinfo): if class_name != api_label: flag = False info = f"funcname in title is not same as the label name: {funcname} != {api_label}." - return flag, info + check_failed[rstfile] = info + continue flag = False func_found_in_json = False From 109765fb63bc4affa6fd6f49e37432192ad6ee6a Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Mon, 14 Apr 2025 19:47:20 +0800 Subject: [PATCH 07/14] fix --- ci_scripts/check_api_parameters.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/ci_scripts/check_api_parameters.py b/ci_scripts/check_api_parameters.py index 542ab62f201..2352545039c 100644 --- a/ci_scripts/check_api_parameters.py +++ b/ci_scripts/check_api_parameters.py @@ -203,7 +203,7 @@ def check_api_parameters(rstfiles, apiinfo): api_label = ( line.strip() .removeprefix(".. _cn_api_") - .replace("_", ".") + .removesuffix(":") .removesuffix("__upper") ) is_first_line = False @@ -217,14 +217,19 @@ def check_api_parameters(rstfiles, apiinfo): continue funcname = mo.group(2) paramstr = mo.group(3) + func_to_label = funcname.replace(".", "_") # check same as the api_label - if funcname != api_label: + if func_to_label != api_label: # if funcname is a function, try to back to class - obj = eval(funcname) - if inspect.isfunction(obj): + try: + obj = eval(funcname) + except AttributeError: + obj = None + if obj is not None and inspect.isfunction(obj): class_name = ".".join(funcname.split(".")[:-1]) - if class_name != api_label: + class_to_label = class_name.replace(".", "_") + if class_to_label != api_label: flag = False info = f"funcname in title is not same as the label name: {funcname} != {api_label}." check_failed[rstfile] = info From 8a5b0ad83abd86e0c43af18ff0f2b8e1be916e54 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Mon, 14 Apr 2025 20:18:20 +0800 Subject: [PATCH 08/14] refine --- ci_scripts/check_api_parameters.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ci_scripts/check_api_parameters.py b/ci_scripts/check_api_parameters.py index 2352545039c..83d0e5e3261 100644 --- a/ci_scripts/check_api_parameters.py +++ b/ci_scripts/check_api_parameters.py @@ -31,10 +31,8 @@ def add_path(path): this_dir = osp.dirname(__file__) # Add docs/api to PYTHONPATH add_path(osp.abspath(osp.join(this_dir, "..", "docs", "api"))) -from extract_api_from_docs import ( - extract_params_desc_from_rst_file, - gen_functions_args_str, -) +from extract_api_from_docs import extract_params_desc_from_rst_file +from gen_doc import gen_functions_args_str arguments = [ # flags, dest, type, default, help From 440c92897abe5966d7ecb44d9919311f95573317 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Mon, 14 Apr 2025 22:19:32 +0800 Subject: [PATCH 09/14] refine --- ci_scripts/check_api_parameters.py | 7 ++++--- docs/api/paddle/nn/functional/max_pool3d_cn.rst | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ci_scripts/check_api_parameters.py b/ci_scripts/check_api_parameters.py index 83d0e5e3261..fb36b54fdb8 100644 --- a/ci_scripts/check_api_parameters.py +++ b/ci_scripts/check_api_parameters.py @@ -70,7 +70,6 @@ def _check_params_in_description(rstfilename, paramstr): params_in_title.remove("/") if "*" in params_in_title: params_in_title.remove("*") - params_in_title = ", ".join(params_in_title) funcdescnode = extract_params_desc_from_rst_file(rstfilename) if funcdescnode: @@ -87,7 +86,9 @@ def _check_params_in_description(rstfilename, paramstr): ) else: info = f"The number of params in title does not match the params in description: {len(params_in_title)} != {len(items)}." - print(f"check failed (parammeters description): {rstfilename}") + print( + f"check failed with different nums (parammeters description): {rstfilename}" + ) else: for i in range(len(items)): pname_in_title = params_in_title[i].split("=")[0].strip() @@ -136,7 +137,7 @@ def _check_params_in_description_with_fullargspec(rstfilename, funcname): params_inspec.remove("/") if "*" in params_inspec: params_inspec.remove("*") - params_inspec = ", ".join(params_inspec) + funcdescnode = extract_params_desc_from_rst_file(rstfilename) if funcdescnode: items = funcdescnode.children[1].children[0].children diff --git a/docs/api/paddle/nn/functional/max_pool3d_cn.rst b/docs/api/paddle/nn/functional/max_pool3d_cn.rst index 4634fad4fae..fa0b0e11e4c 100644 --- a/docs/api/paddle/nn/functional/max_pool3d_cn.rst +++ b/docs/api/paddle/nn/functional/max_pool3d_cn.rst @@ -3,7 +3,8 @@ max_pool3d ------------------------------- -.. py:function:: paddle.nn.functional.max_pool3d(x, kernel_size, stride=None, padding=0, ceil_mode=False, return_mask=False, data_format="NCDHW", name=None)) +.. py:function:: paddle.nn.functional.max_pool3d(x, kernel_size, stride=None, padding=0, ceil_mode=False, return_mask=False, data_format="NCDHW", name=None) + 该函数是一个三维最大池化函数,根据输入参数 `kernel_size`, `stride`, `padding` 等参数对输入 `x` 做最大池化操作。 From 56161d01324420c16555c1d97b869628e067f524 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Tue, 15 Apr 2025 10:55:44 +0800 Subject: [PATCH 10/14] add patch --- ci_scripts/check_api_parameters.sh | 2 +- ci_scripts/check_api_parameters_patch.py | 324 +++++++++++++++++++++++ 2 files changed, 325 insertions(+), 1 deletion(-) create mode 100644 ci_scripts/check_api_parameters_patch.py diff --git a/ci_scripts/check_api_parameters.sh b/ci_scripts/check_api_parameters.sh index 8cb6c75d1bc..f1567526c08 100644 --- a/ci_scripts/check_api_parameters.sh +++ b/ci_scripts/check_api_parameters.sh @@ -35,7 +35,7 @@ if [ "$need_check_files" = "" ] then echo "need check files is empty, skip api parameters check" else - python check_api_parameters.py --rst-files "${need_check_files}" --api-info $2 + python check_api_parameters_patch.py --rst-files "${need_check_files}" --api-info $2 if [ $? -ne 0 ];then set +x echo "************************************************************************************" diff --git a/ci_scripts/check_api_parameters_patch.py b/ci_scripts/check_api_parameters_patch.py new file mode 100644 index 00000000000..4f65c3a4fbe --- /dev/null +++ b/ci_scripts/check_api_parameters_patch.py @@ -0,0 +1,324 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import argparse +import ast +import inspect +import json +import logging +import re +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import paddle # noqa: F401 + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format="%(levelname)s: %(message)s", +) +logger = logging.getLogger(__name__) + + +# Add Project Paths +THIS_DIR = Path(__file__).resolve().parent +API_DOC_TOOLS_PATH = THIS_DIR.parent / "docs" / "api" + + +def add_path(path: str): + if path not in sys.path: + sys.path.insert(0, path) + + +add_path(str(API_DOC_TOOLS_PATH)) + +from extract_api_from_docs import extract_params_desc_from_rst_file +from gen_doc import gen_functions_args_str + +# Custom Exception + + +class APINotFoundError(Exception): + pass + + +class APICheckError(Exception): + pass + + +@dataclass +class CheckResults: + passed: list[str] = field(default_factory=list) + failed: dict[str] = field(default_factory=dict) + not_found: dict[str] = field(default_factory=dict) + + +class ParamChecker: + def __init__(self, api_info: dict[str, Any]): + self.api_info = api_info + self.api_info_by_name = {} + for apiobj in api_info.values(): + if "all_names" in apiobj: + for name in apiobj["all_names"]: + self.api_info_by_name[name] = apiobj + + def check_files(self, rst_files: list[Path]) -> CheckResults: + results = CheckResults() + for rst_file in rst_files: + logger.info(f"Checking: {rst_file}") + try: + self.check_file(rst_file) + results.passed.append(str(rst_file)) + except APINotFoundError as e: + results.not_found[str(rst_file)] = str(e) + logger.warning(f"API not found in {rst_file} - {e}") + except APICheckError as e: + results.failed[str(rst_file)] = str(e) + logger.error(f"API check failed in {rst_file} - {e}") + except Exception as e: + results.failed[str(rst_file)] = str(e) + logger.error(f"Unexpected error in {rst_file} - {e}") + return results + + def check_file(self, rst_file: Path): + pat = re.compile( + r"^\.\.\s+py:(method|function|class)::\s+([^\s(]+)\s*(?:\(\s*(.*)\s*\))?\s*$" + ) + with open(rst_file, "r", encoding="utf-8") as f: + lines = f.readlines() + func_found = False + api_label = None + + for idx, line in enumerate(lines): + line_strip = line.strip() + if idx == 0: + api_label = ( + line_strip.removeprefix(".. _cn_api_") + .removesuffix(":") + .removesuffix("__upper") + ) + mo = pat.match(line_strip) + if mo: + func_found = True + functype, funcname, paramstr = mo.groups() + if functype not in ("function", "method"): + return # + func_to_label = funcname.replace(".", "_") + if api_label and func_to_label != api_label: + class_name = ".".join(funcname.split(".")[:-1]) + class_to_label = class_name.replace(".", "_") + if class_to_label != api_label: + raise APICheckError( + f"Function name in title does not match the api label name: {funcname} != {api_label}" + ) + apiobj = self.api_info_by_name.get(funcname) + if apiobj and "args" in apiobj: + if paramstr == apiobj["args"]: + self._check_params_in_description(rst_file, paramstr) + else: + logger.warning( + f"Parameter string mismatch for {funcname}: RST='{paramstr}', JSON='{apiobj['args']}'" + ) + self._check_params_in_description(rst_file, paramstr) + else: + self._check_params_in_description_with_fullargspec( + rst_file, funcname + ) + return # + if not func_found: + raise APINotFoundError( + "Function name in title not found, please check the format of '.. py:function::func()'" + ) + + def _check_params_in_description( + self, rst_file: Path, paramstr: str | None + ): + params_in_title = [] + if paramstr: + try: + fake_func = ast.parse(f"def fake_func({paramstr}): pass") + func_node = fake_func.body[0] + func_args_str = gen_functions_args_str(func_node) + params_in_title = [ + p for p in func_args_str.split(", ") if p not in ("*", "/") + ] + except Exception as e: + raise APICheckError(f"Failed to parse parameters: {e}") + funcdescnode = extract_params_desc_from_rst_file(rst_file) + if funcdescnode: + try: + items = funcdescnode.children[1].children[0].children + except Exception: + raise APICheckError( + "Params section format error in description." + ) + if not items: + if params_in_title: + raise APICheckError( + "Params section in description is empty, check it please." + ) + elif len(items) != len(params_in_title): + raise APICheckError( + f"The number of params in title does not match the params in description: {len(params_in_title)} != {len(items)}." + ) + else: + for i, item in enumerate(items): + pname_in_title = params_in_title[i] + mo = re.match( + r"\*{0,2}(\w+)\b.*", item.children[0].astext() + ) + if mo: + pname_indesc = mo.group(1) + if pname_indesc != pname_in_title: + raise APICheckError( + f"Param mismatch: {pname_in_title} != {pname_indesc}." + ) + else: + raise APICheckError( + f"Param name '{pname_in_title}' not matched in description line {i + 1}, check it please." + ) + elif params_in_title: + raise APICheckError( + "Params section not found in description, check it please." + ) + + def _check_params_in_description_with_fullargspec( + self, rst_file: Path, funcname: str + ): + try: + obj = self._import_object(funcname) + except Exception: + raise APICheckError( + f"Function {funcname} in rst file {rst_file} not found in paddle module, please check it." + ) + try: + source = inspect.getsource(obj) + tree = ast.parse(source) + func_node = tree.body[0] + params_inspec = [ + p + for p in gen_functions_args_str(func_node).split(", ") + if p not in ("/", "*") + ] + except Exception as e: + raise APICheckError(f"Failed to inspect function {funcname}: {e}") + funcdescnode = extract_params_desc_from_rst_file(str(rst_file)) + if funcdescnode: + try: + items = funcdescnode.children[1].children[0].children + except Exception: + raise APICheckError( + "Params section format error in description." + ) + if len(items) != len(params_inspec): + raise APICheckError( + f"Param count mismatch: {len(params_inspec)} != {len(items)}." + ) + else: + for i, item in enumerate(items): + pname_in_title = params_inspec[i] + mo = re.match( + r"\*{0,2}(\w+)\b.*", item.children[0].astext() + ) + if mo: + pname_indesc = mo.group(1) + if pname_indesc != pname_in_title: + raise APICheckError( + f"Param mismatch: {pname_in_title} != {pname_indesc}." + ) + else: + raise APICheckError( + f"Param name '{pname_in_title}' not matched in description line {i + 1}." + ) + else: + if params_inspec: + raise APICheckError( + "Params section not found in description, check it please." + ) + + def _import_object(self, dotted_path: str): + import importlib + + parts = dotted_path.split(".") + module = importlib.import_module(parts[0]) + obj = module + for part in parts[1:]: + obj = getattr(obj, part) + return obj + + +def parse_args(): + parser = argparse.ArgumentParser(description="check api parameters") + parser.add_argument( + "--rst-files", + dest="rst_files", + required=True, + help="api rst files, separated by space", + type=str, + ) + parser.add_argument( + "--api-info", + dest="api_info_file", + required=True, + help="api_info_all.json filename", + type=str, + ) + parser.add_argument("--debug", dest="debug", action="store_true") + return parser.parse_args() + + +def main(): + args = parse_args() + + if args.debug: + logger.setLevel(logging.DEBUG) + try: + with open(args.api_info_file, "r", encoding="utf-8") as f: + api_info = json.load(f) + except Exception as e: + logger.error(f"Failed to load API info file: {e}") + sys.exit(1) + + rst_files = [fn for fn in args.rst_files.split(" ") if fn] + if not rst_files: + logger.error("No RST files provided.") + sys.exit(1) + + checker = ParamChecker(api_info) + results = checker.check_files(rst_files) + + logging.info( + f"API parameter checking completed. Pass: {len(results.passed)}, Fail: {len(results.failed)}, Not Found: {len(results.not_found)}" + ) + + if results.failed: + logging.error("Following files failed the check:") + for file_path, error in results.failed.items(): + logger.error(f" - {file_path}: {error}") + if results.not_found: + logging.error("Following files had API not found:") + for file_path, error in results.not_found.items(): + logger.error(f" - {file_path}: {error}") + if results.failed or results.not_found: + sys.exit(1) + else: + logging.info("All API parameter checks passed.") + sys.exit(0) + + +if __name__ == "__main__": + main() From 17a0ec785e7838a6ca2dd628cfab6fc93d938f10 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Tue, 15 Apr 2025 11:14:25 +0800 Subject: [PATCH 11/14] fix style --- ci_scripts/check_api_parameters_patch.py | 17 +++++++++-------- docs/api/extract_api_from_docs.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ci_scripts/check_api_parameters_patch.py b/ci_scripts/check_api_parameters_patch.py index 4f65c3a4fbe..8562dbef72c 100644 --- a/ci_scripts/check_api_parameters_patch.py +++ b/ci_scripts/check_api_parameters_patch.py @@ -154,11 +154,13 @@ def _check_params_in_description( func_node = fake_func.body[0] func_args_str = gen_functions_args_str(func_node) params_in_title = [ - p for p in func_args_str.split(", ") if p not in ("*", "/") + p.split("=")[0].strip() + for p in func_args_str.split(", ") + if p not in ("/", "*") ] except Exception as e: raise APICheckError(f"Failed to parse parameters: {e}") - funcdescnode = extract_params_desc_from_rst_file(rst_file) + funcdescnode = extract_params_desc_from_rst_file(str(rst_file)) if funcdescnode: try: items = funcdescnode.children[1].children[0].children @@ -210,7 +212,7 @@ def _check_params_in_description_with_fullargspec( tree = ast.parse(source) func_node = tree.body[0] params_inspec = [ - p + p.split("=")[0].strip() for p in gen_functions_args_str(func_node).split(", ") if p not in ("/", "*") ] @@ -292,7 +294,6 @@ def main(): except Exception as e: logger.error(f"Failed to load API info file: {e}") sys.exit(1) - rst_files = [fn for fn in args.rst_files.split(" ") if fn] if not rst_files: logger.error("No RST files provided.") @@ -301,22 +302,22 @@ def main(): checker = ParamChecker(api_info) results = checker.check_files(rst_files) - logging.info( + logger.info( f"API parameter checking completed. Pass: {len(results.passed)}, Fail: {len(results.failed)}, Not Found: {len(results.not_found)}" ) if results.failed: - logging.error("Following files failed the check:") + logger.info("Following files failed the check:") for file_path, error in results.failed.items(): logger.error(f" - {file_path}: {error}") if results.not_found: - logging.error("Following files had API not found:") + logger.info("Following files had API not found:") for file_path, error in results.not_found.items(): logger.error(f" - {file_path}: {error}") if results.failed or results.not_found: sys.exit(1) else: - logging.info("All API parameter checks passed.") + logger.info("All API parameter checks passed.") sys.exit(0) diff --git a/docs/api/extract_api_from_docs.py b/docs/api/extract_api_from_docs.py index 390ea979fd3..6ca7007583d 100644 --- a/docs/api/extract_api_from_docs.py +++ b/docs/api/extract_api_from_docs.py @@ -242,7 +242,7 @@ def extract_params_desc_from_rst_file(filename, section_title="参数"): "docinfo_xform": 0, "initial_header_level": 2, } - with open(filename, "r") as fileobj: + with open(filename, "r", encoding="utf-8") as fileobj: doctree = docutils.core.publish_doctree( fileobj.read(), settings_overrides=overrides ) From 505cc1e3b282b8aaefd1d0862b82733aa6bd3d43 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Tue, 15 Apr 2025 13:15:52 +0800 Subject: [PATCH 12/14] clean collect --- ci_scripts/check_api_parameters_patch.py | 21 +++++++++++++++------ docs/api/extract_api_from_docs.py | 18 +++++++++++++----- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/ci_scripts/check_api_parameters_patch.py b/ci_scripts/check_api_parameters_patch.py index 8562dbef72c..54fcc1d2ff9 100644 --- a/ci_scripts/check_api_parameters_patch.py +++ b/ci_scripts/check_api_parameters_patch.py @@ -158,9 +158,13 @@ def _check_params_in_description( for p in func_args_str.split(", ") if p not in ("/", "*") ] + params_in_title = [ + p.removeprefix("*").removeprefix("*") + for p in params_in_title + ] except Exception as e: raise APICheckError(f"Failed to parse parameters: {e}") - funcdescnode = extract_params_desc_from_rst_file(str(rst_file)) + funcdescnode = extract_params_desc_from_rst_file(str(rst_file), True) if funcdescnode: try: items = funcdescnode.children[1].children[0].children @@ -205,7 +209,7 @@ def _check_params_in_description_with_fullargspec( obj = self._import_object(funcname) except Exception: raise APICheckError( - f"Function {funcname} in rst file {rst_file} not found in paddle module, please check it." + f"Function {funcname} not found in paddle module, please check it." ) try: source = inspect.getsource(obj) @@ -216,9 +220,14 @@ def _check_params_in_description_with_fullargspec( for p in gen_functions_args_str(func_node).split(", ") if p not in ("/", "*") ] + # for *args and **kwargs, remove * and ** + params_inspec = [ + p.removeprefix("*").removeprefix("*") for p in params_inspec + ] + except Exception as e: raise APICheckError(f"Failed to inspect function {funcname}: {e}") - funcdescnode = extract_params_desc_from_rst_file(str(rst_file)) + funcdescnode = extract_params_desc_from_rst_file(str(rst_file), True) if funcdescnode: try: items = funcdescnode.children[1].children[0].children @@ -302,16 +311,16 @@ def main(): checker = ParamChecker(api_info) results = checker.check_files(rst_files) - logger.info( + logger.warning( f"API parameter checking completed. Pass: {len(results.passed)}, Fail: {len(results.failed)}, Not Found: {len(results.not_found)}" ) if results.failed: - logger.info("Following files failed the check:") + logger.warning("Following files failed the check:") for file_path, error in results.failed.items(): logger.error(f" - {file_path}: {error}") if results.not_found: - logger.info("Following files had API not found:") + logger.warning("Following files had API not found:") for file_path, error in results.not_found.items(): logger.error(f" - {file_path}: {error}") if results.failed or results.not_found: diff --git a/docs/api/extract_api_from_docs.py b/docs/api/extract_api_from_docs.py index 6ca7007583d..94b5c053efe 100644 --- a/docs/api/extract_api_from_docs.py +++ b/docs/api/extract_api_from_docs.py @@ -20,7 +20,7 @@ import logging import os import re -from contextlib import contextmanager +from contextlib import contextmanager, redirect_stderr import docutils import docutils.core @@ -235,7 +235,9 @@ def extract_rst_title(filename): return None -def extract_params_desc_from_rst_file(filename, section_title="参数"): +def extract_params_desc_from_rst_file( + filename, need_redirect_stderr=False, section_title="参数" +): overrides = { # Disable the promotion of a lone top-level section title to document # title (and subsequent section title to document subtitle promotion). @@ -243,9 +245,15 @@ def extract_params_desc_from_rst_file(filename, section_title="参数"): "initial_header_level": 2, } with open(filename, "r", encoding="utf-8") as fileobj: - doctree = docutils.core.publish_doctree( - fileobj.read(), settings_overrides=overrides - ) + if need_redirect_stderr: + with redirect_stderr(): + doctree = docutils.core.publish_doctree( + fileobj.read(), settings_overrides=overrides + ) + else: + doctree = docutils.core.publish_doctree( + fileobj.read(), settings_overrides=overrides + ) found = False for child in doctree.children: if isinstance(child, docutils.nodes.section) and isinstance( From b9f6be7d278744ba82017b950888bfb137ee8cc0 Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Tue, 15 Apr 2025 14:01:23 +0800 Subject: [PATCH 13/14] fix bugs --- docs/api/extract_api_from_docs.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/api/extract_api_from_docs.py b/docs/api/extract_api_from_docs.py index 94b5c053efe..f6d43aa0092 100644 --- a/docs/api/extract_api_from_docs.py +++ b/docs/api/extract_api_from_docs.py @@ -246,10 +246,11 @@ def extract_params_desc_from_rst_file( } with open(filename, "r", encoding="utf-8") as fileobj: if need_redirect_stderr: - with redirect_stderr(): - doctree = docutils.core.publish_doctree( - fileobj.read(), settings_overrides=overrides - ) + with open(os.devnull, "w") as fnull: + with redirect_stderr(fnull): + doctree = docutils.core.publish_doctree( + fileobj.read(), settings_overrides=overrides + ) else: doctree = docutils.core.publish_doctree( fileobj.read(), settings_overrides=overrides From 1186b58be135798e6a274cf1f077f18791368a5c Mon Sep 17 00:00:00 2001 From: ooooo <3164076421@qq.com> Date: Tue, 15 Apr 2025 15:19:04 +0800 Subject: [PATCH 14/14] fix bugs --- ci_scripts/check_api_parameters_patch.py | 8 +++++--- docs/api/gen_doc.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/ci_scripts/check_api_parameters_patch.py b/ci_scripts/check_api_parameters_patch.py index 54fcc1d2ff9..4425ec76d6d 100644 --- a/ci_scripts/check_api_parameters_patch.py +++ b/ci_scripts/check_api_parameters_patch.py @@ -152,10 +152,12 @@ def _check_params_in_description( try: fake_func = ast.parse(f"def fake_func({paramstr}): pass") func_node = fake_func.body[0] - func_args_str = gen_functions_args_str(func_node) + func_args_str: list[str] = gen_functions_args_str( + func_node, False, False + ) params_in_title = [ p.split("=")[0].strip() - for p in func_args_str.split(", ") + for p in func_args_str if p not in ("/", "*") ] params_in_title = [ @@ -217,7 +219,7 @@ def _check_params_in_description_with_fullargspec( func_node = tree.body[0] params_inspec = [ p.split("=")[0].strip() - for p in gen_functions_args_str(func_node).split(", ") + for p in gen_functions_args_str(func_node, True, False) if p not in ("/", "*") ] # for *args and **kwargs, remove * and ** diff --git a/docs/api/gen_doc.py b/docs/api/gen_doc.py index b1aee45f9ef..33b491b7783 100755 --- a/docs/api/gen_doc.py +++ b/docs/api/gen_doc.py @@ -363,7 +363,7 @@ def parse_module_file(mod): logger.debug("%s omitted", obj_full_name) -def gen_functions_args_str(node, skip_self=False): +def gen_functions_args_str(node, skip_self=False, return_str=True): def _process_positional_args(args, params): positional_args = args.posonlyargs + args.args num_defaults = len(args.defaults) @@ -421,8 +421,10 @@ def _process_kwargs(args, params): _process_var_args(func_args, str_args_list) _process_kwonly_args(func_args, str_args_list) _process_kwargs(func_args, str_args_list) - - return ", ".join(str_args_list) + if return_str: + return ", ".join(str_args_list) + else: + return str_args_list # step 2 fill field : `display`