Add bcast and bcast_until_actx_array functions #307

Closed
wants to merge 1 commit into from

Owner

@inducer inducer commented Mar 27, 2025

I'm mostly seeking early feedback on this. It still needs, at least:

  • Documentation
  • Tests
  • Fixed types

This could help fill the hole that's left by the plan to make implicit broadcasting of data class containers across actx array types illegal. See inducer/grudge#377 for an example use, specifically shortcuts.py.

Closes #280. (Would supersede the cursed Bcast objects that @alexfikl hated---and I have to agree, it is just nicer.)

Owner Author

@inducer inducer left a comment

Some specific questions for the review.

(k, op(left_v, right)) for k, left_v in serialized])


def bcast_left_until_actx_array(
Owner Author

Naming? (May not matter as much, as it's convenient to make local aliases, as in the grudge PR.)

Collaborator

Maybe rec_bcast_left? Mostly to match some of the other traversal functions.

I agree that local aliases are probably needed, so it's fine for this to be nice and verbose, i.e. I'm quite fine with the name.

Collaborator

Agree with adding rec_, and I would suggest maybe adding something else to disambiguate what's happening to the "left". (Picturing myself looking at this again in 6 months and not being sure if it means "broadcast the left across" or "broadcast across the left".) rec_bcast_left_operand_across_actx_arrays or something (probably too verbose, maybe you can think of something better).

(k, op(left_v, right)) for k, left_v in serialized])


def bcast_left_until_actx_array(
Owner Author

Argument order? (Here and above.) I initially had the operator in the middle but went with "first" to allow convenient use of partial.

Collaborator

I like the functional feel of it too 👍
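
A minimal sketch of why operator-first argument order plays well with `partial`. The `bcast_left` below is a hypothetical toy that works on plain dicts; the function in this PR also takes an array context and operates on serialized array containers, so treat this only as an illustration of the calling convention.

```python
import operator
from functools import partial


def bcast_left(op, left, right):
    """Toy stand-in (not the PR's code): apply op(left, v) to each dict value."""
    return {k: op(left, v) for k, v in right.items()}


# With the operator first, partial() binds it and yields a two-argument alias.
bcmul = partial(bcast_left, operator.mul)
bcadd = partial(bcast_left, operator.add)

state = {"mass": 2.0, "energy": 5.0}
scaled = bcmul(3.0, state)
shifted = bcadd(1.0, state)
```

With the operator in the middle, as in a hypothetical `bcast_left(left, op, right)`, the same alias would need a `lambda` or a keyword argument instead of a plain `partial`.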

@@ -987,4 +989,79 @@ def treat_as_scalar(x: Any) -> bool:

# }}}


# {{{
Owner Author

Suggested change:
- # {{{
+ # {{{ bcast

Collaborator

@alexfikl alexfikl left a comment

Took a quick look and it seems fine to me.

The only issue would be if it's too verbose for more complex usage? The one in grudge looks nice enough to me :\

@inducer
Owner Author

inducer commented Mar 28, 2025

Thanks for taking a look!

The one in grudge looks nice enough to me :\

Same.

The only issue would be if it's too verbose for more complex usage?

Good question. Without question, the prior state in grudge is nicer to look at.

Before:

residual = a*residual + h*rhs_val
y = y + b * residual

After:

residual = bcmul(a, residual) + bcmul(h, rhs_val)
y = y + bcmul(b, residual)

On the one hand, it's clearly more verbose. On the other, the verbosity scales linearly with expression size, so maybe not too bad? We could also just defer this issue to when it comes up. One attractive feature of these functions is that they're easy to replace: deprecate, grep, done. So I anticipate this to be low regret, even if it turns out to not be the right answer long term.

If we wanted to preserve the operator-y feel, we would need a way to let the operator know the broadcast rules (given that they depend on the array context), and we are (I am) hoping to no longer have array context handles in array containers. The only approach I can think of there is the Bcast solution (#280). For reference:

residual = bcast(a)*residual + bcast(h)*rhs_val
y = y + bcast(b) * residual

(That's assuming a local alias bcast has been created.) I'm not advocating for it. It's just about the same amount of visual clutter, and it requires special machinery inside container arithmetic, unlike the function-based approach here.

@alexfikl
Collaborator

The only approach I can think of there is the Bcast solution (#280) for reference:

residual = bcast(a)*residual + bcast(h)*rhs_val
y = y + bcast(b) * residual

(That's assuming a local alias bcast has been created.) I'm not advocating for it. It's just about the same amount of visual clutter, and it requires special machinery inside container arithmetic, unlike the function-based approach here.

That's a good point.. I forgot that version was quite verbose-y as well. Then yeah, from my side this looks very nice and very low on the black-box magic compared to the previous version 😁

Comment on lines +1026 to +1031
actx: ArrayContext,
op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer],
ArrayOrArithContainer],
left: ArrayOrArithContainer,
right: ArithArrayContainer,
) -> ArrayOrArithContainer:
Collaborator

Would it make sense to have the interface for these be more like the container traversal functions, e.g.:

def rec_bcast_left(op, left, right, leaf_cls: type | None = None):
    ...

and then

mul = partial(rec_bcast_left, operator.mul, leaf_cls=actx.array_types)

?
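
A rough sketch of how such an interface could behave, assuming plain dicts stand in for array containers. `Block` is a hypothetical leaf type invented for this example, playing the role that `actx.array_types` plays in the suggestion above; none of this is the actual arraycontext implementation.

```python
import operator
from functools import partial


class Block(dict):
    """Hypothetical 'array-like' leaf that can scale itself by a scalar."""

    def __rmul__(self, other):
        return Block({k: other * v for k, v in self.items()})


def rec_bcast_left(op, left, right, leaf_cls=None):
    """Apply op(left, leaf) at each leaf of `right`, recursing through dicts.

    Recursion stops at instances of `leaf_cls` (if given) or at any non-dict.
    """
    if leaf_cls is not None and isinstance(right, leaf_cls):
        return op(left, right)
    if isinstance(right, dict):
        return {k: rec_bcast_left(op, left, v, leaf_cls)
                for k, v in right.items()}
    return op(left, right)


mul = partial(rec_bcast_left, operator.mul, leaf_cls=Block)

state = {"u": 1.5, "v": Block({"a": 1.0, "b": 2.0})}
with_leaf = mul(2.0, state)
without_leaf = rec_bcast_left(operator.mul, 2.0, state)
```

With `leaf_cls=Block`, the operator is handed the whole `Block` and the result keeps that type; without it, the traversal descends into the `Block` and rebuilds it as a plain dict.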

@majosm
Collaborator

majosm commented Mar 28, 2025

TBH from a syntax point of view I think I like the bcast(a) version better, as it avoids the "left"/"right" naming ambiguity that I mentioned above, and it seems a little easier to follow with the original operators still being there. (IIRC my main objection to #280 was the 1/2/3/NLevels stuff.) If there aren't any major complications with doing it that way vs. the mul version, I think I would prefer that.

@majosm
Collaborator

majosm commented Mar 28, 2025

Does the bcast(a) version necessitate doing the _rewrap stuff from #280? Or is there a way to do it that's more like the current implementation here?

Edit: I'm imagining something roughly along the lines of _binary_op in Array?

@inducer
Owner Author

inducer commented Mar 28, 2025

Does the bcast(a) version necessitate doing the _rewrap stuff from #280?

Good question.

If we wrap the thing being broadcast in Bcast (as in #280), then I don't see how one can get by without the rewrap nonsense.

However, if we wrap the container being broadcast over in Bcast, conceivably we can get by with just overloaded operators on the Bcast object itself. I think that would considerably simplify the machinery (to just some overloaded operators). @alexfikl How much would you loathe that?

@majosm
Collaborator

majosm commented Mar 28, 2025

Does the bcast(a) version necessitate doing the _rewrap stuff from #280?

Good question.

If we wrap the thing being broadcast in Bcast (as in #280), then I don't see how one can get by without the rewrap nonsense.

However, if we wrap the container being broadcast over in Bcast, conceivably we can get by with just overloaded operators on the Bcast object itself. I think that would considerably simplify the machinery (to just some overloaded operators). @alexfikl How much would you loathe that?

I can't speak for Alex, but I would object to that. 🙂 I would be confused by that syntax.

I'm probably missing something obvious, but why wouldn't we be able to do something along the lines of:

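import operator
from functools import partialmethod

from arraycontext import (
    NotAnArrayContainerError,
    deserialize_container,
    serialize_container,
)


class Bcast:
    def __init__(self, actx, array):
        # assumption: the array context is supplied at construction
        self.actx = actx
        self.array = array

    def _binary_op(self, op, right):
        # Need a reverse case too, but I'm lazy

        try:
            serialized = serialize_container(right)
        except NotAnArrayContainerError:
            return op(self.array, right)

        # Apply op at actx arrays, recurse into nested containers otherwise
        return deserialize_container(right, [
            (k, op(self.array, right_v)
                if isinstance(right_v, self.actx.array_types) else
                self._binary_op(op, right_v)
            )
            for k, right_v in serialized])

    __mul__ = partialmethod(_binary_op, operator.mul)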

?
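
For illustration, a self-contained toy version of this wrapper idea, including the reverse case. It assumes nested plain dicts in place of array containers and treats every non-dict value as a leaf, so it needs no array context; the names and the `reverse` flag are inventions of this sketch, not part of arraycontext.

```python
import operator
from functools import partialmethod


class Bcast:
    """Toy wrapper: broadcasts `value` across the leaves of nested dicts."""

    def __init__(self, value):
        self.value = value

    def _binary_op(self, op, other, reverse=False):
        if isinstance(other, dict):
            # Still inside a container: recurse into each entry.
            return {k: self._binary_op(op, v, reverse)
                    for k, v in other.items()}
        # Reached a leaf: apply the operator in the requested order.
        return op(other, self.value) if reverse else op(self.value, other)

    __mul__ = partialmethod(_binary_op, operator.mul)
    __rmul__ = partialmethod(_binary_op, operator.mul, reverse=True)
    __sub__ = partialmethod(_binary_op, operator.sub)
    __rsub__ = partialmethod(_binary_op, operator.sub, reverse=True)


state = {"u": 1.0, "grad": {"x": 2.0, "y": 3.0}}
scaled = Bcast(2.0) * state      # Bcast on the left: uses __mul__
shifted = state - Bcast(1.0)     # Bcast on the right: falls back to __rsub__
```

The reflected methods only run when the left operand does not handle the operation itself. Plain dicts define no `-`, so Python falls back to `Bcast.__rsub__`; a real arithmetic container would presumably have to return `NotImplemented` for `Bcast` operands for the same fallback to happen.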

@inducer
Owner Author

inducer commented Mar 28, 2025

I'm probably missing something obvious

I'm not sure you are. That looks possible from my point of view. In #280, I think I was too mentally focused on modifying code in the container arithmetic, leading to loads of unnecessary complexity.

@inducer
Owner Author

inducer commented Mar 28, 2025

To add: I'm not hating that as a possible direction. It shouldn't need any weird "magic". @alexfikl?

@alexfikl
Collaborator

To add: I'm not hating that as a possible direction. It shouldn't need any weird "magic". @alexfikl?

Agreed! That looks like it should be pretty easy to understand with some of the same greppability. I'll mull it over, but can't think of any downsides at the moment :-?

@inducer
Owner Author

inducer commented Mar 31, 2025

See #310 for a stab at this.

@inducer inducer closed this Mar 31, 2025
@inducer inducer mentioned this pull request Mar 31, 2025