
Commit 2bff13d

Core ML doc updates
1 parent c4bd450 commit 2bff13d

File tree

6 files changed (+420, -389 lines)

docs/source/backends-coreml.md

Lines changed: 0 additions & 389 deletions
This file was deleted.
Lines changed: 5 additions & 0 deletions
# Op support

The Core ML backend supports almost all PyTorch operators.

If an operator in your model is not supported by Core ML, you will see a warning about this during lowering. If you want to guarantee that your model fully delegates to Core ML, you can set [`lower_full_graph=True`](coreml-partitioner.md) in the `CoreMLPartitioner`. When set, lowering will fail if an unsupported operator is encountered.
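
To illustrate, a minimal sketch of enforcing full delegation; the small model here is a stand-in for your own, and it assumes the ExecuTorch Core ML dependencies are installed:

```python
import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower


class SmallModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


# With lower_full_graph=True, lowering raises an error if any operator in
# the exported graph is unsupported by Core ML, instead of leaving that
# operator on the ExecuTorch CPU backend.
et_program = to_edge_transform_and_lower(
    torch.export.export(SmallModel().eval(), (torch.randn(2, 16),)),
    partitioner=[CoreMLPartitioner(lower_full_graph=True)],
).to_executorch()
```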
Lines changed: 112 additions & 0 deletions
# Core ML Backend

The Core ML delegate is the ExecuTorch solution for taking advantage of Apple's [Core ML framework](https://developer.apple.com/documentation/coreml) for on-device ML. With Core ML, a model can run on the CPU, GPU, and Apple Neural Engine (ANE).

## Features

- Dynamic dispatch to the CPU, GPU, and ANE.
- Supports fp32 and fp16 computation.

## Target Requirements

Below are the minimum OS requirements on various hardware for running a Core ML-delegated ExecuTorch model:
- [macOS](https://developer.apple.com/macos) >= 13.0
- [iOS](https://developer.apple.com/ios/) >= 16.0
- [iPadOS](https://developer.apple.com/ipados/) >= 16.0
- [tvOS](https://developer.apple.com/tvos/) >= 16.0

## Development Requirements

To develop, you need:

- [macOS](https://developer.apple.com/macos) >= 13.0
- [Xcode](https://developer.apple.com/documentation/xcode) >= 14.1

Before starting, make sure you install the Xcode Command Line Tools:

```bash
xcode-select --install
```

----
## Using the Core ML Backend

To target the Core ML backend during the export and lowering process, pass an instance of the `CoreMLPartitioner` to `to_edge_transform_and_lower`. The example below demonstrates this process using the MobileNet V2 model from torchvision.

```python
import torch
import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

et_program = to_edge_transform_and_lower(
    torch.export.export(mobilenet_v2, sample_inputs),
    partitioner=[CoreMLPartitioner()],
).to_executorch()

with open("mv2_coreml.pte", "wb") as file:
    et_program.write_to_file(file)
```

See [Partitioner API](coreml-partitioner.md) for a reference on available partitioner options.

----
## Quantization

The Core ML delegate can also be used as a backend to execute quantized models. See [Core ML Quantization](coreml-quantization.md) for more information on available quantization schemes and APIs.

## Backward compatibility

Core ML supports backward compatibility via the [`minimum_deployment_target`](coreml-partitioner.md#coreml-compilespec) option. A model exported with a specific deployment target is guaranteed to work on all deployment targets >= the specified deployment target. For example, a model exported with `coremltools.target.iOS17` will work on iOS 17 or higher.
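
As an illustrative sketch, the deployment target can be pinned through a compile spec; the `CoreMLBackend` import path is assumed from the ExecuTorch source tree, and the [Partitioner API](coreml-partitioner.md#coreml-compilespec) page covers the details:

```python
import coremltools as ct

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

# Pin the deployment target so the lowered model is guaranteed to run on
# iOS 17 and newer, rather than whatever minimum coremltools infers.
compile_specs = CoreMLBackend.generate_compile_specs(
    minimum_deployment_target=ct.target.iOS17,
)
partitioner = CoreMLPartitioner(compile_specs=compile_specs)
```

The resulting partitioner is then passed to `to_edge_transform_and_lower` exactly as in the example above.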
----

## Runtime integration

To run the model on device, use the standard ExecuTorch runtime APIs. See [Running on Device](getting-started.md#running-on-device) for more information, including building the iOS frameworks.

When building from source, pass `-DEXECUTORCH_BUILD_COREML=ON` when configuring the CMake build to compile the Core ML backend.

Due to the use of static initializers for registration, it may be necessary to use whole-archive to link against the `coremldelegate` target. This can typically be done by passing `"$<LINK_LIBRARY:WHOLE_ARCHIVE,coremldelegate>"` to `target_link_libraries`.

```cmake
# CMakeLists.txt
add_subdirectory("executorch")
...
target_link_libraries(
    my_target
    PRIVATE executorch
            extension_module_static
            extension_tensor
            optimized_native_cpu_ops_lib
            $<LINK_LIBRARY:WHOLE_ARCHIVE,coremldelegate>)
```

No additional steps are necessary to use the backend beyond linking the target. A Core ML-delegated .pte file will automatically run on the registered backend.
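
Before integrating on device, a quick sanity check with the ExecuTorch Python runtime bindings can confirm that the delegated .pte loads and runs. This is a sketch that assumes the installed ExecuTorch build includes the Core ML backend (e.g., on macOS) and reuses the `mv2_coreml.pte` file produced above:

```python
import torch
from executorch.runtime import Runtime

runtime = Runtime.get()
program = runtime.load_program("mv2_coreml.pte")
method = program.load_method("forward")

# Run one forward pass; the Core ML delegate executes the delegated
# portion of the graph transparently.
outputs = method.execute([torch.randn(1, 3, 224, 224)])
print(outputs[0].shape)
```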
## Reference

**→{doc}`coreml-troubleshooting` — Debug common issues.**

**→{doc}`coreml-partitioner` — Partitioner options.**

**→{doc}`coreml-quantization` — Supported quantization schemes.**

**→{doc}`coreml-op-support` — Supported operators.**

```{toctree}
:maxdepth: 2
:hidden:
:caption: Core ML Backend

coreml-troubleshooting
coreml-partitioner
coreml-quantization
coreml-op-support
```
Lines changed: 114 additions & 0 deletions
# Partitioner API

The Core ML partitioner API allows for configuration of the model delegation to Core ML. Passing a `CoreMLPartitioner` instance with no additional parameters will run as much of the model as possible on the Core ML backend with default settings. This is the most common use case. For advanced use cases, the partitioner exposes the following options via the [constructor](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/partition/coreml_partitioner.py#L60); a short usage sketch follows the list.

- `skip_ops_for_coreml_delegation`: Allows you to skip ops for delegation by Core ML. By default, all ops that Core ML supports will be delegated. See [here](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/test/test_coreml_partitioner.py#L42) for an example of skipping an op for delegation.
- `compile_specs`: A list of `CompileSpec`s for the Core ML backend. These control low-level details of Core ML delegation, such as the compute unit (CPU, GPU, ANE), the iOS deployment target, and the compute precision (FP16, FP32). These are discussed more below.
- `take_over_mutable_buffer`: A boolean that indicates whether PyTorch mutable buffers in stateful models should be converted to [Core ML `MLState`](https://developer.apple.com/documentation/coreml/mlstate). If set to `False`, mutable buffers in the PyTorch graph are converted to graph inputs and outputs of the Core ML lowered module under the hood. Generally, setting `take_over_mutable_buffer` to `True` will result in better performance, but using `MLState` requires iOS >= 18.0, macOS >= 15.0, and Xcode >= 16.0.
- `take_over_constant_data`: A boolean that indicates whether PyTorch constant data like model weights should be consumed by the Core ML delegate. If set to `False`, constant data is passed to the Core ML delegate as inputs. By default, `take_over_constant_data=True`.
- `lower_full_graph`: A boolean that indicates whether the entire graph must be lowered to Core ML. If set to `True` and Core ML does not support an op, an error is raised during lowering. If set to `False` and Core ML does not support an op, the op is executed on the CPU by ExecuTorch. Although setting `lower_full_graph=False` can allow a model to lower where it would otherwise fail, it can introduce performance overhead when there are unsupported ops. You will see warnings about unsupported ops during lowering if there are any. By default, `lower_full_graph=False`.
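As a construction sketch, several of these options can be combined; the op name passed to `skip_ops_for_coreml_delegation` is only illustrative, so see the test linked above for the exact naming convention used by the partitioner:

```python
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

# Skip one op from Core ML delegation, keep mutable buffers as plain
# graph inputs/outputs (no MLState), and let unsupported ops fall back
# to the ExecuTorch CPU kernels.
partitioner = CoreMLPartitioner(
    skip_ops_for_coreml_delegation=["aten.add.Tensor"],  # illustrative op name
    take_over_mutable_buffer=False,
    lower_full_graph=False,
)
```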
## Core ML CompileSpec

A list of `CompileSpec`s is constructed with [`CoreMLBackend.generate_compile_specs`](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L210). Below are the available options:
- `compute_unit`: This controls the compute units (CPU, GPU, ANE) that are used by Core ML. The default value is `coremltools.ComputeUnit.ALL`. The available options from coremltools are:
    - `coremltools.ComputeUnit.ALL` (uses the CPU, GPU, and ANE)
    - `coremltools.ComputeUnit.CPU_ONLY` (uses the CPU only)
    - `coremltools.ComputeUnit.CPU_AND_GPU` (uses both the CPU and GPU, but not the ANE)
    - `coremltools.ComputeUnit.CPU_AND_NE` (uses both the CPU and ANE, but not the GPU)
- `minimum_deployment_target`: The minimum iOS deployment target (e.g., `coremltools.target.iOS18`). By default, the smallest deployment target needed to deploy the model is selected. During export, you will see a warning about the "Core ML specification version" that was used for the model, which maps onto a deployment target as discussed [here](https://apple.github.io/coremltools/mlmodel/Format/Model.html#model). If you need to control the deployment target, please specify it explicitly.
- `compute_precision`: The compute precision used by Core ML (`coremltools.precision.FLOAT16` or `coremltools.precision.FLOAT32`). The default value is `coremltools.precision.FLOAT16`. Note that the compute precision is applied regardless of the dtype specified in the exported PyTorch model. For example, an FP32 PyTorch model will be converted to FP16 when delegating to the Core ML backend by default. Also note that the ANE only supports FP16 precision.
- `model_type`: Whether the model should be compiled to the Core ML [mlmodelc format](https://developer.apple.com/documentation/coreml/downloading-and-compiling-a-model-on-the-user-s-device) during .pte creation ([`CoreMLBackend.MODEL_TYPE.COMPILED_MODEL`](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L71)), or whether it should be compiled to mlmodelc on device ([`CoreMLBackend.MODEL_TYPE.MODEL`](https://github.com/pytorch/executorch/blob/14ff52ff89a89c074fc6c14d3f01683677783dcd/backends/apple/coreml/compiler/coreml_preprocess.py#L70)). Using `CoreMLBackend.MODEL_TYPE.COMPILED_MODEL` and compiling ahead of time should improve the first on-device model load time.
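For illustration, a minimal sketch that combines these options; the `CoreMLBackend` import path is assumed from the source tree linked above:

```python
import coremltools as ct

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

# Run on CPU/GPU only, keep FP32 precision, target iOS 17+, and compile
# to mlmodelc at .pte creation time to reduce the first on-device load time.
compile_specs = CoreMLBackend.generate_compile_specs(
    compute_unit=ct.ComputeUnit.CPU_AND_GPU,
    minimum_deployment_target=ct.target.iOS17,
    compute_precision=ct.precision.FLOAT32,
    model_type=CoreMLBackend.MODEL_TYPE.COMPILED_MODEL,
)
partitioner = CoreMLPartitioner(compile_specs=compile_specs)
```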
## Dynamic and Enumerated Shapes in Core ML Export

When exporting an `ExportedProgram` to Core ML, **dynamic shapes** are mapped to [`RangeDim`](https://apple.github.io/coremltools/docs-guides/source/flexible-inputs.html#set-the-range-for-each-dimension).
This enables Core ML `.pte` files to accept inputs with varying dimensions at runtime.

⚠️ **Note:** The Apple Neural Engine (ANE) does not support true dynamic shapes. If a model relies on `RangeDim`, Core ML will fall back to scheduling the model on the CPU or GPU instead of the ANE.
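
To illustrate, a minimal sketch of exporting with a dynamic batch dimension, which the Core ML delegate maps to a `RangeDim`; the model and bounds are made up for the example:

```python
import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 2)

    def forward(self, x):
        return self.linear(x)


# Mark dim 0 (batch) as dynamic; after lowering, the Core ML .pte accepts
# batch sizes 1 through 32 at runtime (scheduled on the CPU/GPU, not the ANE).
batch = torch.export.Dim("batch", min=1, max=32)
ep = torch.export.export(
    TinyModel().eval(),
    (torch.randn(4, 8),),
    dynamic_shapes={"x": {0: batch}},
)
et_program = to_edge_transform_and_lower(
    ep, partitioner=[CoreMLPartitioner()]
).to_executorch()
```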
---

### Enumerated Shapes

To enable limited flexibility on the ANE—and often achieve better performance overall—you can export models using **[enumerated shapes](https://apple.github.io/coremltools/docs-guides/source/flexible-inputs.html#select-from-predetermined-shapes)**.

- Enumerated shapes are *not fully dynamic*.
- Instead, they define a **finite set of valid input shapes** that Core ML can select from at runtime.
- This approach allows some adaptability while still preserving ANE compatibility.
---

### Specifying Enumerated Shapes

Unlike `RangeDim`, **enumerated shapes are not part of the `ExportedProgram` itself.**
They must be provided through a compile spec.

For reference on how to do this, see:
- The annotated code snippet below, and
- The [end-to-end test in ExecuTorch](https://github.com/pytorch/executorch/blob/main/backends/apple/coreml/test/test_enumerated_shapes.py), which demonstrates how to specify enumerated shapes during export.
```python
import coremltools as ct
import torch

import executorch.exir
from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 5)
        self.linear2 = torch.nn.Linear(11, 5)

    def forward(self, x, y):
        return self.linear1(x).sum() + self.linear2(y)


model = Model()
example_inputs = (
    torch.randn((4, 6, 10)),
    torch.randn((5, 11)),
)

# Specify the enumerated shapes. Below we specify that:
#
# * x can take shape [1, 5, 10] and y can take shape [3, 11], or
# * x can take shape [4, 6, 10] and y can take shape [5, 11]
#
# Any other input shapes will result in a runtime error.
#
# Note that we must export x and y with dynamic shapes in the ExportedProgram
# because some of their dimensions are dynamic.
enumerated_shapes = {"x": [[1, 5, 10], [4, 6, 10]], "y": [[3, 11], [5, 11]]}
dynamic_shapes = [
    {
        0: torch.export.Dim.AUTO(min=1, max=4),
        1: torch.export.Dim.AUTO(min=5, max=6),
    },
    {0: torch.export.Dim.AUTO(min=3, max=5)},
]
ep = torch.export.export(
    model.eval(), example_inputs, dynamic_shapes=dynamic_shapes
)

# If enumerated shapes are specified for multiple inputs, we must export
# for iOS18+.
compile_specs = CoreMLBackend.generate_compile_specs(
    minimum_deployment_target=ct.target.iOS18
)
compile_specs.append(
    CoreMLBackend.generate_enumerated_shapes_compile_spec(
        ep,
        enumerated_shapes,
    )
)

# When using an enumerated shape compile spec, you must specify
# lower_full_graph=True in the CoreMLPartitioner. We do not support using
# enumerated shapes for partially exported models.
partitioner = CoreMLPartitioner(
    compile_specs=compile_specs, lower_full_graph=True
)
delegated_program = executorch.exir.to_edge_transform_and_lower(
    ep,
    partitioner=[partitioner],
)
et_prog = delegated_program.to_executorch()
```
