Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 132 additions & 47 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from typing import Callable, ClassVar
from typing import Callable, ClassVar, cast

from mypy.nodes import (
ARG_POS,
Expand All @@ -20,6 +20,7 @@
Lvalue,
MemberExpr,
NameExpr,
OpExpr,
RefExpr,
SetExpr,
StarExpr,
Expand Down Expand Up @@ -56,6 +57,7 @@
is_dict_rprimitive,
is_fixed_width_rtype,
is_immutable_rprimitive,
is_int_rprimitive,
is_list_rprimitive,
is_sequence_rprimitive,
is_short_int_rprimitive,
Expand Down Expand Up @@ -454,39 +456,45 @@ def make_for_loop_generator(
return for_dict

if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr):
if (
is_range_ref(expr.callee)
and (
len(expr.args) <= 2
or (len(expr.args) == 3 and builder.extract_int(expr.args[2]) is not None)
)
and set(expr.arg_kinds) == {ARG_POS}
):
# Special case "for x in range(...)".
# We support the 3 arg form but only for int literals, since it doesn't
# seem worth the hassle of supporting dynamically determining which
# direction of comparison to do.
if len(expr.args) == 1:
start_reg: Value = Integer(0)
end_reg = builder.accept(expr.args[0])
else:
start_reg = builder.accept(expr.args[0])
end_reg = builder.accept(expr.args[1])
if len(expr.args) == 3:
step = builder.extract_int(expr.args[2])
assert step is not None
if step == 0:
builder.error("range() step can't be zero", expr.args[2].line)
else:
step = 1
num_args = len(expr.args)

for_range = ForRange(builder, index, body_block, loop_exit, line, nested)
for_range.init(start_reg, end_reg, step)
return for_range
if is_range_ref(expr.callee) and set(expr.arg_kinds) == {ARG_POS}:
# Special case "for x in range(...)".
# NOTE We support the 3 arg form but only when `step` is constant-
# foldable, since it doesn't seem worth the hassle of supporting
# dynamically determining which direction of comparison to do.
# If we cannot constant fold `step`, we just fallback to stdlib range.
if num_args <= 2 or (
num_args == 3
and (
builder.extract_int(expr.args[2])
or expr_value_is_inspectable(builder, expr.args[2])
)
):
if num_args == 1:
start_reg: Value = Integer(0)
end_reg = builder.accept(expr.args[0])
step = 1
else:
start_reg = builder.accept(expr.args[0])
end_reg = builder.accept(expr.args[1])
step = (
1
if num_args == 2
else (
expr.args[2]
if expr_value_is_inspectable(builder, expr.args[2])
else cast(int, builder.extract_int(expr.args[2]))
)
)

for_range = ForRange(builder, index, body_block, loop_exit, line, nested)
for_range.init(start_reg, end_reg, step)
return for_range

elif (
expr.callee.fullname == "builtins.enumerate"
and len(expr.args) == 1
and num_args == 1
and expr.arg_kinds == [ARG_POS]
and isinstance(index, TupleExpr)
and len(index.items) == 2
Expand All @@ -500,10 +508,10 @@ def make_for_loop_generator(

elif (
expr.callee.fullname == "builtins.zip"
and len(expr.args) >= 2
and num_args >= 2
and set(expr.arg_kinds) == {ARG_POS}
and isinstance(index, TupleExpr)
and len(index.items) == len(expr.args)
and len(index.items) == num_args
):
# Special case "for x, y in zip(a, b)".
for_zip = ForZip(builder, index, body_block, loop_exit, line, nested)
Expand All @@ -512,7 +520,7 @@ def make_for_loop_generator(

if (
expr.callee.fullname == "builtins.reversed"
and len(expr.args) == 1
and num_args == 1
and expr.arg_kinds == [ARG_POS]
and is_sequence_rprimitive(builder.node_type(expr.args[0]))
):
Expand Down Expand Up @@ -1045,14 +1053,68 @@ def begin_body(self) -> None:
builder.assign(target, rvalue, line)


def expr_value_is_inspectable(builder: IRBuilder, expr: Expression) -> bool:
"""
Return True if we can call buidler.accept(expr) with NO side effects.

This is useful so we can call `builder.accept` twice on the same Expression
in order to perform some sort of runtime check on its value BEFORE the place
in the code where we actually intend to accept it.
"""
rtype = builder.node_type(expr)
if not (
# we know we can do operations on these without side effects
is_immutable_rprimitive(rtype)
or is_int_rprimitive(rtype)
or is_short_int_rprimitive(rtype)
):
return False
if isinstance(expr, RefExpr):
return True
if isinstance(expr, OpExpr):
return expr_value_is_inspectable(builder, expr.left) and expr_value_is_inspectable(
builder, expr.right
)
# TODO: extend me with some more Expression subtypes
return False


class ForRange(ForGenerator):
"""Generate optimized IR for a for loop over an integer range."""

def init(self, start_reg: Value, end_reg: Value, step: int) -> None:
def init(self, start_reg: Value, end_reg: Value, step: int | Expression) -> None:
builder = self.builder
self.start_reg = start_reg
self.end_reg = end_reg

self.step = step
if isinstance(step, int):
self.step_value = Integer(step)
else:
# we've already used `expr_value_is_inspectable` to make sure
# we can accept this Expression twice without side-effects
self.step_value = builder.accept(step)

# Emit a runtime check for step == 0 and raise ValueError if so
zero_block = BasicBlock()
continue_block = BasicBlock()
is_zero = builder.binary_op(step, Integer(0), "==", self.line)
builder.add_bool_branch(is_zero, zero_block, continue_block, rare=True)

# If step == 0, raise ValueError
builder.activate_block(zero_block)
builder.add(
RaiseStandardError(
RaiseStandardError.VALUE_ERROR, "range() arg 3 must not be zero", self.line
)
)

# Continue with initialization if step != 0
builder.activate_block(continue_block)

# Check this one for use in gen_condition
self.step_is_positive = builder.binary_op(step, Integer(0), ">", self.line)

self.end_target = builder.maybe_spill(end_reg)
if is_short_int_rprimitive(start_reg.type) and is_short_int_rprimitive(end_reg.type):
index_type: RType = short_int_rprimitive
Expand All @@ -1071,33 +1133,56 @@ def gen_condition(self) -> None:
builder = self.builder
line = self.line
# Add loop condition check.
cmp = "<" if self.step > 0 else ">"
comparison = builder.binary_op(
builder.read(self.index_reg, line), builder.read(self.end_target, line), cmp, line
)
builder.add_bool_branch(comparison, self.body_block, self.loop_exit)
if isinstance(self.step, int):
cmp = "<" if self.step > 0 else ">"
comparison = builder.binary_op(
builder.read(self.index_reg, line), builder.read(self.end_target, line), cmp, line
)
builder.add_bool_branch(comparison, self.body_block, self.loop_exit)
elif isinstance(self.step, Value):
index_val = builder.read(self.index_reg, line)
end_target = builder.read(self.end_target, line)

# Dynamic step: determine sign at runtime and branch accordingly
# NOTE step can't be zero here: we have already checked that step != 0 before constructing ForRange
# so at runtime, step is either positive or negative
positive_step_block = BasicBlock()
negative_step_block = BasicBlock()

# Check if step > 0
builder.add_bool_branch(
self.step_is_positive, positive_step_block, negative_step_block
)

# Positive step: index < end
builder.activate_block(positive_step_block)
cmp_pos = builder.binary_op(index_val, end_target, "<", line)
builder.add_bool_branch(cmp_pos, self.body_block, self.loop_exit)

# Negative step: index > end
builder.activate_block(negative_step_block)
cmp_neg = builder.binary_op(index_val, end_target, ">", line)
builder.add_bool_branch(cmp_neg, self.body_block, self.loop_exit)
else:
raise NotImplementedError(type(self.step))

def gen_step(self) -> None:
builder = self.builder
line = self.line

# Increment index register. If the range is known to fit in short ints, use
# short ints.
index_val = builder.read(self.index_reg, line)
if is_short_int_rprimitive(self.start_reg.type) and is_short_int_rprimitive(
self.end_reg.type
):
new_val = builder.int_op(
short_int_rprimitive,
builder.read(self.index_reg, line),
Integer(self.step),
IntOp.ADD,
line,
short_int_rprimitive, index_val, self.step_value, IntOp.ADD, line
)

else:
new_val = builder.binary_op(
builder.read(self.index_reg, line), Integer(self.step), "+", line
)
new_val = builder.binary_op(index_val, self.step_value, "+", line)

builder.assign(self.index_reg, new_val, line)
builder.assign(self.index_target, new_val, line)

Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/commandline.test
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ from typing import Final, List, Any, AsyncIterable
from mypy_extensions import trait, mypyc_attr

def busted(b: bool) -> None:
for i in range(1, 10, 0): # E: range() step can't be zero
for i in range(1, 10):
try:
if i == 5:
break # E: break inside try/finally block is unimplemented
Expand Down
Loading