Is there a way to attach metadata to a layer in a way that is included in the StableHLO export? #8993

Open · j2kun opened this issue Apr 17, 2025 · 3 comments
Labels: question, stablehlo (StableHLO related work)

Comments

j2kun commented Apr 17, 2025

❓ Questions and Help

I am looking at a use case where metadata about a trained model's layers needs to be attached to the StableHLO export. I am using exported_program_to_stablehlo.

One option I had considered is exporting the data completely separately from exported_program_to_stablehlo (say, by writing some JSON to disk), but then I don't know how to connect the written metadata back to the StableHLO export, because the layer names do not appear to be attached to the generated StableHLO ops.

Another option I tried was to attach the metadata directly to the torch nodes before calling exported_program_to_stablehlo, but I can't figure out how to do so in a way that results in the metadata being exported as MLIR attributes. It would suffice to export the attributes as, e.g., an op attribute with a given string name and value.
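
For concreteness, a rough sketch of what I tried; the "custom" keys are just placeholders I made up:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(28 * 28, 512), nn.ReLU(), nn.Linear(512, 10))
exported = torch.export.export(model, (torch.randn(1, 28 * 28),))

# Tag the FX nodes of the exported graph with extra metadata before lowering.
# node.meta is an ordinary dict, so this part is easy; the problem is that I
# can't find a way to make it come out the other side as MLIR attributes.
for node in exported.graph.nodes:
    if node.op == "call_function":
        node.meta["custom"] = {"min": -23.5, "max": 18.9}

# ...then call exported_program_to_stablehlo(exported) as usual.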

Could someone advise on whether this is possible, or suggest an alternative? (Or add a feature that would support this?)

ysiraichi added the question and stablehlo (StableHLO related work) labels on Apr 17, 2025
ysiraichi (Collaborator) commented

Could you be more specific on why you would want to do that? I don't think StableHLO preserves that kind of metadata for each nn.Module.

@qihqi @lsy323 @tengyifei Any thoughts?

j2kun (Author) commented Apr 17, 2025

My specific use case is to include, in the exported IR, an empirical distribution of individual layer outputs (think range bounds), and then have a downstream compiler use them for optimizations.

A trivial example model might look like:

nn.Sequential(
    nn.Linear(28 * 28, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
)
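
For context, collecting the range bounds themselves is the easy part; something like the following (illustrative only), run over some calibration data, would do:

import torch
import torch.nn as nn

# Same toy model as above.
model = nn.Sequential(nn.Linear(28 * 28, 512), nn.ReLU(), nn.Linear(512, 10))

# Empirical (min, max) of each submodule's output, keyed by module name.
ranges = {}

def make_hook(name):
    def hook(module, inputs, output):
        lo, hi = ranges.get(name, (float("inf"), float("-inf")))
        ranges[name] = (min(lo, output.min().item()),
                        max(hi, output.max().item()))
    return hook

for name, submodule in model.named_modules():
    if name:  # skip the root Sequential itself
        submodule.register_forward_hook(make_hook(name))

with torch.no_grad():
    for _ in range(16):
        model(torch.randn(1, 28 * 28))

# ranges now maps the child names "0", "1", "2" to their observed bounds.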

And the exported StableHLO is something like:

module @IrToHlo.20 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<10x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<512x784xf32>, %arg4: tensor<1x784xf32>) -> tensor<1x10xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<1x512xf32>
    %0 = stablehlo.transpose %arg3, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[784,512]{0,1}"} : (tensor<512x784xf32>) -> tensor<784x512xf32>
    %1 = stablehlo.dot_general %arg4, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x784xf32>, tensor<784x512xf32>) -> tensor<1x512xf32>
    %2 = stablehlo.reshape %arg2 : (tensor<512xf32>) -> tensor<1x512xf32>
    %3 = stablehlo.add %1, %2 : tensor<1x512xf32>
    %4 = stablehlo.maximum %3, %cst : tensor<1x512xf32>
    %5 = stablehlo.transpose %arg1, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[512,10]{0,1}"} : (tensor<10x512xf32>) -> tensor<512x10xf32>
    %6 = stablehlo.dot_general %4, %5, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x512xf32>, tensor<512x10xf32>) -> tensor<1x10xf32>
    %7 = stablehlo.reshape %arg0 : (tensor<10xf32>) -> tensor<1x10xf32>
    %8 = stablehlo.add %6, %7 : tensor<1x10xf32>
    return %8 : tensor<1x10xf32>
  }
}

And I'd want the custom metadata to be attached to the ops that correspond to layer outputs, e.g. the two lines marked with "<<--- attribute" below:

func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<10x512xf32>, %arg2: tensor<512xf32>, %arg3: tensor<512x784xf32>, %arg4: tensor<1x784xf32>) -> tensor<1x10xf32> {
  %cst = stablehlo.constant dense<0.000000e+00> : tensor<1x512xf32>
  %0 = stablehlo.transpose %arg3, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[784,512]{0,1}"} : (tensor<512x784xf32>) -> tensor<784x512xf32>
  %1 = stablehlo.dot_general %arg4, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x784xf32>, tensor<784x512xf32>) -> tensor<1x512xf32>
  %2 = stablehlo.reshape %arg2 : (tensor<512xf32>) -> tensor<1x512xf32>
  %3 = stablehlo.add %1, %2 {custom.min = "-23.5", custom.max = "18.9"} : tensor<1x512xf32>    // <<--- attribute
  %4 = stablehlo.maximum %3, %cst {custom.min = "0", custom.max = "18.9"} : tensor<1x512xf32>  // <<--- attribute
  ... etc
}

I know that MLIR supports adding arbitrary attributes to ops without modifying the op definition, and I can probably ensure those attributes are preserved by downstream passes in stablehlo-opt (presumably I can contribute a patch to fix any places where attributes are dropped; this is standard MLIR stuff). So the main question is: when building the StableHLO bundle during export, is it straightforward to propagate that information to the generated MLIR?

Otherwise, the alternative is side-loading the attributes into the MLIR after the fact, which is difficult for the reason mentioned above: the exported StableHLO does not preserve the node names, as far as I can tell. So if we could also just export node names as attributes, that would probably suffice.
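
For completeness, the export path plus the side file I'd write looks roughly like this; torch_xla API details are from memory, so the exact method names may differ between versions:

import json
import torch
import torch.nn as nn
# torch_xla must be installed and able to see an XLA device.
from torch_xla.stablehlo import exported_program_to_stablehlo

model = nn.Sequential(nn.Linear(28 * 28, 512), nn.ReLU(), nn.Linear(512, 10))
exported = torch.export.export(model, (torch.randn(1, 28 * 28),))
stablehlo_program = exported_program_to_stablehlo(exported)

# The MLIR text itself (method name from the torch_xla StableHLO helpers;
# it may differ slightly between releases).
mlir_text = stablehlo_program.get_stablehlo_text("forward")
with open("model.stablehlo.mlir", "w") as f:
    f.write(mlir_text)

# Side-load the collected range metadata, keyed by the nn.Module layer names.
# The open problem is that nothing in mlir_text refers back to these names,
# so there is no key to reconcile the two files on.
with open("layer_ranges.json", "w") as f:
    json.dump({"0": {"min": -23.5, "max": 18.9}, "1": {"min": 0.0, "max": 18.9}}, f)

If the layer/node names showed up anywhere in the generated MLIR, reconciling those two files would be easy; right now I don't see anything to key on.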

ysiraichi (Collaborator) commented

Right. I don't think there's a way to straightforwardly add this kind of metadata to the StableHLO we generate. But I'm not super familiar with it. Maybe @tengyifei @qihqi @lsy323 could clarify a bit more.
