diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py
index 715f5432cd13..77387a7afef9 100644
--- a/mypyc/irbuild/for_helpers.py
+++ b/mypyc/irbuild/for_helpers.py
@@ -7,7 +7,7 @@
 
 from __future__ import annotations
 
-from typing import Callable, ClassVar
+from typing import Callable, ClassVar, cast
 
 from mypy.nodes import (
     ARG_POS,
@@ -20,6 +20,7 @@
     Lvalue,
     MemberExpr,
     NameExpr,
+    OpExpr,
     RefExpr,
     SetExpr,
     StarExpr,
@@ -56,6 +57,7 @@
     is_dict_rprimitive,
     is_fixed_width_rtype,
     is_immutable_rprimitive,
+    is_int_rprimitive,
     is_list_rprimitive,
     is_sequence_rprimitive,
     is_short_int_rprimitive,
@@ -454,39 +456,45 @@ def make_for_loop_generator(
         return for_dict
 
     if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr):
-        if (
-            is_range_ref(expr.callee)
-            and (
-                len(expr.args) <= 2
-                or (len(expr.args) == 3 and builder.extract_int(expr.args[2]) is not None)
-            )
-            and set(expr.arg_kinds) == {ARG_POS}
-        ):
-            # Special case "for x in range(...)".
-            # We support the 3 arg form but only for int literals, since it doesn't
-            # seem worth the hassle of supporting dynamically determining which
-            # direction of comparison to do.
-            if len(expr.args) == 1:
-                start_reg: Value = Integer(0)
-                end_reg = builder.accept(expr.args[0])
-            else:
-                start_reg = builder.accept(expr.args[0])
-                end_reg = builder.accept(expr.args[1])
-            if len(expr.args) == 3:
-                step = builder.extract_int(expr.args[2])
-                assert step is not None
-                if step == 0:
-                    builder.error("range() step can't be zero", expr.args[2].line)
-            else:
-                step = 1
+        num_args = len(expr.args)
 
-            for_range = ForRange(builder, index, body_block, loop_exit, line, nested)
-            for_range.init(start_reg, end_reg, step)
-            return for_range
+        if is_range_ref(expr.callee) and set(expr.arg_kinds) == {ARG_POS}:
+            # Special case "for x in range(...)".
+            # NOTE We support the 3 arg form only when `step` is a non-zero constant
+            # or passes `expr_value_is_inspectable`; ForRange then checks the sign of
+            # a non-constant step (and rejects zero) at runtime. In every other case,
+            # e.g. a literal zero step, we just fallback to stdlib range.
+            if num_args <= 2 or (
+                num_args == 3
+                and (
+                    builder.extract_int(expr.args[2]) not in (None, 0)
+                    or expr_value_is_inspectable(builder, expr.args[2])
+                )
+            ):
+                if num_args == 1:
+                    start_reg: Value = Integer(0)
+                    end_reg = builder.accept(expr.args[0])
+                    step: int | Expression = 1
+                else:
+                    start_reg = builder.accept(expr.args[0])
+                    end_reg = builder.accept(expr.args[1])
+                    step = (
+                        1
+                        if num_args == 2
+                        else (
+                            expr.args[2]
+                            if expr_value_is_inspectable(builder, expr.args[2])
+                            else cast(int, builder.extract_int(expr.args[2]))
+                        )
+                    )
+
+                for_range = ForRange(builder, index, body_block, loop_exit, line, nested)
+                for_range.init(start_reg, end_reg, step)
+                return for_range
 
         elif (
             expr.callee.fullname == "builtins.enumerate"
-            and len(expr.args) == 1
+            and num_args == 1
             and expr.arg_kinds == [ARG_POS]
             and isinstance(index, TupleExpr)
             and len(index.items) == 2
@@ -500,10 +508,10 @@ def make_for_loop_generator(
 
         elif (
             expr.callee.fullname == "builtins.zip"
-            and len(expr.args) >= 2
+            and num_args >= 2
             and set(expr.arg_kinds) == {ARG_POS}
             and isinstance(index, TupleExpr)
-            and len(index.items) == len(expr.args)
+            and len(index.items) == num_args
         ):
             # Special case "for x, y in zip(a, b)".
             for_zip = ForZip(builder, index, body_block, loop_exit, line, nested)
@@ -512,7 +520,7 @@ def make_for_loop_generator(
 
         if (
             expr.callee.fullname == "builtins.reversed"
-            and len(expr.args) == 1
+            and num_args == 1
             and expr.arg_kinds == [ARG_POS]
             and is_sequence_rprimitive(builder.node_type(expr.args[0]))
         ):
@@ -1045,14 +1053,74 @@ def begin_body(self) -> None:
         builder.assign(target, rvalue, line)
 
 
+def expr_value_is_inspectable(builder: IRBuilder, expr: Expression) -> bool:
+    """
+    Return True if we can call builder.accept(expr) with NO side effects.
+
+    This is useful so we can call `builder.accept` twice on the same Expression
+    in order to perform some sort of runtime check on its value BEFORE the place
+    in the code where we actually intend to accept it.
+    """
+    rtype = builder.node_type(expr)
+    if not (
+        # we know we can do operations on these without side effects
+        is_immutable_rprimitive(rtype)
+        or is_int_rprimitive(rtype)
+        or is_short_int_rprimitive(rtype)
+    ):
+        return False
+    if isinstance(expr, RefExpr):
+        return True
+    if isinstance(expr, OpExpr):
+        return expr_value_is_inspectable(builder, expr.left) and expr_value_is_inspectable(
+            builder, expr.right
+        )
+    # TODO: extend me with some more Expression subtypes
+    return False
+
+
 class ForRange(ForGenerator):
     """Generate optimized IR for a for loop over an integer range."""
 
-    def init(self, start_reg: Value, end_reg: Value, step: int) -> None:
+    def init(self, start_reg: Value, end_reg: Value, step: int | Expression) -> None:
+        # Local import: only needed for the dynamic-step path below.
+        from mypyc.ir.ops import Unreachable
+
         builder = self.builder
         self.start_reg = start_reg
         self.end_reg = end_reg
         self.step = step
+        if isinstance(step, int):
+            self.step_value: Value = Integer(step)
+        else:
+            # The caller only passes an Expression that `expr_value_is_inspectable`
+            # approved, so evaluating it here (after start/end) has no side effects.
+            self.step_value = builder.accept(step)
+            # NOTE(review): `end_reg` goes through `maybe_spill` below but this value
+            # and `step_is_positive` do not -- confirm that is fine in generators.
+
+            # Emit a runtime check for step == 0 and raise ValueError if so
+            zero_block = BasicBlock()
+            continue_block = BasicBlock()
+            is_zero = builder.binary_op(self.step_value, Integer(0), "==", self.line)
+            builder.add_bool_branch(is_zero, zero_block, continue_block, rare=True)
+
+            # If step == 0, raise ValueError
+            builder.activate_block(zero_block)
+            builder.add(
+                RaiseStandardError(
+                    RaiseStandardError.VALUE_ERROR, "range() arg 3 must not be zero", self.line
+                )
+            )
+            # The raise op does not end the block; terminate it before starting the next.
+            builder.add(Unreachable())
+
+            # Continue with initialization if step != 0
+            builder.activate_block(continue_block)
+
+            # Check this one for use in gen_condition
+            self.step_is_positive = builder.binary_op(self.step_value, Integer(0), ">", self.line)
+
         self.end_target = builder.maybe_spill(end_reg)
         if is_short_int_rprimitive(start_reg.type) and is_short_int_rprimitive(end_reg.type):
             index_type: RType = short_int_rprimitive
@@ -1071,11 +1139,38 @@ def gen_condition(self) -> None:
         builder = self.builder
         line = self.line
         # Add loop condition check.
-        cmp = "<" if self.step > 0 else ">"
-        comparison = builder.binary_op(
-            builder.read(self.index_reg, line), builder.read(self.end_target, line), cmp, line
-        )
-        builder.add_bool_branch(comparison, self.body_block, self.loop_exit)
+        if isinstance(self.step, int):
+            # Constant step: the direction of the comparison is known statically.
+            cmp = "<" if self.step > 0 else ">"
+            comparison = builder.binary_op(
+                builder.read(self.index_reg, line), builder.read(self.end_target, line), cmp, line
+            )
+            builder.add_bool_branch(comparison, self.body_block, self.loop_exit)
+        else:
+            # Dynamic step: `self.step` is an Expression, so branch on the sign that
+            # `init` computed at runtime.
+            # NOTE step can't be zero here: `init` emits a check that raises
+            # ValueError before the loop is entered when step == 0, so at this
+            # point step is either positive or negative.
+            index_val = builder.read(self.index_reg, line)
+            end_target = builder.read(self.end_target, line)
+            positive_step_block = BasicBlock()
+            negative_step_block = BasicBlock()
+
+            # Check if step > 0
+            builder.add_bool_branch(
+                self.step_is_positive, positive_step_block, negative_step_block
+            )
+
+            # Positive step: index < end
+            builder.activate_block(positive_step_block)
+            cmp_pos = builder.binary_op(index_val, end_target, "<", line)
+            builder.add_bool_branch(cmp_pos, self.body_block, self.loop_exit)
+
+            # Negative step: index > end
+            builder.activate_block(negative_step_block)
+            cmp_neg = builder.binary_op(index_val, end_target, ">", line)
+            builder.add_bool_branch(cmp_neg, self.body_block, self.loop_exit)
 
     def gen_step(self) -> None:
         builder = self.builder
@@ -1083,21 +1178,17 @@ def gen_step(self) -> None:
 
         # Increment index register. If the range is known to fit in short ints, use
         # short ints.
+        index_val = builder.read(self.index_reg, line)
         if is_short_int_rprimitive(self.start_reg.type) and is_short_int_rprimitive(
             self.end_reg.type
         ):
             new_val = builder.int_op(
-                short_int_rprimitive,
-                builder.read(self.index_reg, line),
-                Integer(self.step),
-                IntOp.ADD,
-                line,
+                short_int_rprimitive, index_val, self.step_value, IntOp.ADD, line
             )
         else:
-            new_val = builder.binary_op(
-                builder.read(self.index_reg, line), Integer(self.step), "+", line
-            )
+            new_val = builder.binary_op(index_val, self.step_value, "+", line)
+
         builder.assign(self.index_reg, new_val, line)
         builder.assign(self.index_target, new_val, line)
 
 
diff --git a/mypyc/test-data/commandline.test b/mypyc/test-data/commandline.test
index 392ad3620790..bbcec9bd6a0a 100644
--- a/mypyc/test-data/commandline.test
+++ b/mypyc/test-data/commandline.test
@@ -168,7 +168,7 @@ from typing import Final, List, Any, AsyncIterable
 from mypy_extensions import trait, mypyc_attr
 
 def busted(b: bool) -> None:
-    for i in range(1, 10, 0):  # E: range() step can't be zero
+    for i in range(1, 10):
         try:
             if i == 5:
                 break  # E: break inside try/finally block is unimplemented