|
| 1 | +## RFC: OpenXLA PJRT Plugin |
| 2 | + |
| 3 | +| Status | Proposed | |
| 4 | +| :------------ | :------------------------------------------------------ | |
| 5 | +| **RFC #** | [32](https://github.com/openxla/community/pull/32) | |
| 6 | +| **Author(s) ** | Skye Wanderman-Milne ( [email protected]), Jieying Luo ( [email protected]), Jacques Pienaar ( [email protected]) | |
| 7 | +| **Sponsor ** | Stella Laurenzo ( [email protected]), James Rubin ( [email protected]) | |
| 8 | +| **Updated** | 2023-01-23 | |
| 9 | + |
| 10 | +## Objective |
| 11 | + |
| 12 | +* Framework integration of a packaged compiler and runtime solution; |
| 13 | + |
| 14 | +## Proposal |
| 15 | + |
| 16 | +* Adopt PJRT as the supported device plugin mechanism for OpenXLA; |
| 17 | +* Create new repo openxla/openxla-pjrt-plugin for the OpenXLA PJRT plugin; |
| 18 | + |
| 19 | +## Background: PJRT |
| 20 | + |
| 21 | +[PJRT](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h) |
| 22 | +is a uniform Device API that we want to add to the OpenXLA ecosystem. The long |
| 23 | +term vision for PJRT is that: (1) frameworks (TensorFlow, PyTorch, JAX, etc.) |
| 24 | +will call PJRT, which has device-specific implementations that are opaque to the |
| 25 | +frameworks; (2) each device focuses on implementing PJRT APIs, and can be opaque |
| 26 | +to the frameworks. |
| 27 | + |
| 28 | + |
| 29 | +<p align = "center"> PJRT provides a platform independent interfacing for |
| 30 | +compilers and corresponding runtimes. </p> |
| 31 | + |
| 32 | +PJRT API will provide an easy interface with which frameworks can integrate a |
| 33 | +packaged compiler and runtime solution. It will be the supported interface that |
| 34 | +will be used by TensorFlow and JAX for all compiler and runtime integration. And |
| 35 | +as such it will be easy for other compilers and runtimes that implement the PJRT |
| 36 | +interface to integrate with these systems. |
| 37 | + |
| 38 | +## PJRT plugin mechanism goal |
| 39 | + |
| 40 | +The PJRT plugin mechanism should support the following features: |
| 41 | + |
| 42 | +* Different devices (e.g. TPU, GPU) can have different implementations |
| 43 | + (through |
| 44 | + [PJRT C API interface](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h)). |
| 45 | +* Registration of multiple PJRT plugins. |
| 46 | +* Loading multiple PJRT plugins (e.g. both CPU and TPU) in the same process. |
| 47 | +* Passing configuration values from client (e.g. JAX lib) or from json files |
| 48 | + provided by the plugin (default configs). |
| 49 | +* Plugin discovery and choosing which plugins to load. |
| 50 | + |
| 51 | +## High level plugin structure |
| 52 | + |
| 53 | + |
| 54 | + |
| 55 | +## Current Status |
| 56 | + |
| 57 | +As of Dec 14, 2022 |
| 58 | + |
| 59 | +* [LoadPjrtPlugin(plugin_name, library_path)](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/stream_executor/tpu/pjrt_api.cc#L73) |
| 60 | + can be used to load a PJRT plugin. We also provide a Python method |
| 61 | + [load_pjrt_plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla.cc#L329) |
| 62 | + binded to it. |
| 63 | +* [GetCApiClient(plugin_name)](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc#L1150) |
| 64 | + can be used to create a PJRT client. We also provide a Python method |
| 65 | + [get_c_api_client](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla.cc#L391) |
| 66 | + binded to it. |
| 67 | + |
| 68 | +## Design Considerations |
| 69 | + |
| 70 | +### <a name="heading=h.782ksg6rl5bj"></a> Loading plugins |
| 71 | + |
| 72 | +We will provide a low level C++ method to load a PJRT plugin, which takes |
| 73 | +`library_path` and `config_values` as inputs. A python method `load_pjrt_plugin` |
| 74 | +binded to it will be provided as well. |
| 75 | + |
| 76 | +```c++ |
| 77 | +Status LoadPjrtPlugin(string library_path, map<string, string> config_values) { |
| 78 | + library_handle = dlopen(library_path); |
| 79 | + function_prt = dlsym(library_handle, "GetPjrtApi"); |
| 80 | + if (function_prt != nullptr) { |
| 81 | + PJRT_Api* api = function_prt(); |
| 82 | + plugin_name = parse(library_path); |
| 83 | + PluginInfo plugin_info(api, config_values); |
| 84 | + global_pjrt_plugin_map[plugin_name] = plugin_info; |
| 85 | + } |
| 86 | +} |
| 87 | +``` |
| 88 | +
|
| 89 | +* `GetPjrtApi` returns a `PJRT_Api*` pointer that contains the implementation |
| 90 | + of the C APIs. |
| 91 | +* `global_pjrt_plugin_map` is a global `map<plugin_name, PluginInfo>` with the |
| 92 | + same lifetime as the program |
| 93 | +* `plugin_name` has the same meaning as `platform_name` in JAX. |
| 94 | +* To be able to store the config values, and be future-proofed for other |
| 95 | + information such as version, we propose to use a class PluginInfo. This |
| 96 | + class is immutable after constructed, and contains getter method for |
| 97 | + `PJRT_Api*` and config values. |
| 98 | +
|
| 99 | +### Discovering and automatic loading plugins |
| 100 | +
|
| 101 | +To allow framework users to pip-install plugins without requiring further code |
| 102 | +changes, we'll implement a Python discovery mechanism to automatically find and |
| 103 | +load plugins. The plugin discovery will be based on the |
| 104 | +[naming convention of the Python module](https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-naming-convention), |
| 105 | +and full paths set in the environment variable `PJRT_PLUGIN_LIBRARY_PATH` which |
| 106 | +is added to allow users to manually specify directories and/or .so files). For |
| 107 | +modules found, it will be imported. |
| 108 | +
|
| 109 | +* Tentative naming convention for .so/json files: |
| 110 | + "pjrt-plugin-<plugin_name>.so" or "pjrt-plugin-<plugin_name>.json". |
| 111 | +
|
| 112 | +There are two options to automatically load plugins (decision has not been made |
| 113 | +yet): |
| 114 | +
|
| 115 | +Option 1: Plugin module updates PJRT_PLUGIN_LIBRARY_PATH on import. A python |
| 116 | +method `load_pjrt_plugins` will be added to discover .so/json files related to |
| 117 | +PJRT plugins and load them. |
| 118 | +
|
| 119 | +```python |
| 120 | +def load_pjrt_plugins(): |
| 121 | + for directory in env[PJRT_PLUGIN_LIBRARY_PATH]: |
| 122 | + if json file: |
| 123 | + Library_path, config_values = parse json file |
| 124 | + load_pjrt_plugin(library_path, config_values) |
| 125 | + elif .so file: |
| 126 | + load_pjrt_plugin(.so file path) |
| 127 | +``` |
| 128 | + |
| 129 | +Option 2: Plugin module is responsible for calling |
| 130 | +[load_pjrt_plugin](#heading=h.782ksg6rl5bj) with its default options on import. |
| 131 | + |
| 132 | +Open questions: |
| 133 | + |
| 134 | +* Are there requirements about what file should not be loaded? |
| 135 | + |
| 136 | +### <a name="heading=h.396bmv8gkskz"></a> Create PJRT client(s) |
| 137 | + |
| 138 | +Frameworks decide which PJRT clients to create and use. For example, a framework |
| 139 | +can create PJRT clients for all the plugins loaded. It can also choose to only |
| 140 | +create a subset of PJRT clients based on some priorities rules. It can also come |
| 141 | +from the user configuration. |
| 142 | + |
| 143 | +Two python methods binding to C++ methods will be provided to facilitate |
| 144 | +creating PJRT clients: |
| 145 | + |
| 146 | +1. `get_loaded_plugin_names()` which gets all loaded plugin names from |
| 147 | + `global_pjrt_plugin_map`. |
| 148 | +2. `create_pjrt_client(plugin_name)` which creates a PJRT C API client (similar |
| 149 | + to |
| 150 | + [GetCApiClient](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc#L1150)). |
| 151 | + It will retrieve the PluginInfo stored in `global_pjrt_plugin_map` and run |
| 152 | + `PJRT_Client_Create` (see |
| 153 | + [Config values for creating PJRT client(s)](#heading=h.bjuf0soco0sj) |
| 154 | + section). |
| 155 | + |
| 156 | +Open questions: |
| 157 | + |
| 158 | +* What about plugin initialization that are not related to creating a PJRT |
| 159 | + client? For example, these two functions in |
| 160 | + [tpu_initializer_helper.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc#L269-L270) |
| 161 | + should be run when initializing TPU. Currently they are called every time a |
| 162 | + PJRT TPU client is created. Shall we add another method InitializePlugin and |
| 163 | + only run it once? Alternatively, the plugin can implement it in |
| 164 | + `PJRT_Client_Create` and run it once in the first time a client was created. |
| 165 | +* Do we want to create PJRT clients for every plugin that is found? Will that |
| 166 | + involve some initialization for the device which should not run multiple |
| 167 | + times? |
| 168 | + * We may need to store loaded PluginInfo in a nested map `map<device_type, |
| 169 | + <plugin_name, PluginInfo>>` if we want to know what plugins are |
| 170 | + available for a device (e.g. current PJRT GPU and IREE GPU) and only |
| 171 | + create one PJRT client per device. |
| 172 | + |
| 173 | +### <a name="heading=h.bjuf0soco0sj"></a> Config values for creating PJRT client(s) |
| 174 | + |
| 175 | +GetCApiClient will be changed to take a map of config values. This map can be |
| 176 | +passed to `PJRT_Client_Create` through |
| 177 | +[PJRT_NamedValue](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h#L210). |
| 178 | +The default config values can (1) come from the json file which will be stored |
| 179 | +in PluginInfo, or (2) plugin implementation during import. |
| 180 | + |
| 181 | +### Framework integration |
| 182 | + |
| 183 | +All the methods mentioned above can be reused across frameworks. For example, |
| 184 | +JAX lib can integrate as follows: |
| 185 | + |
| 186 | +* Call [load_pjrt_plugins](#heading=h.782ksg6rl5bj) when |
| 187 | + [initializing backends](https://github.com/google/jax/blob/a66b3dcdd378b723e275a19258de826260b4c83e/jax/_src/lib/xla_bridge.py#L381). |
| 188 | +* Call [get_loaded_plugin_names](#heading=h.396bmv8gkskz) to get loaded PJRT |
| 189 | + `plugin_name`, have some framework specific logics to decide whether to call |
| 190 | + [create_pjrt_client](#heading=h.396bmv8gkskz) to create the PJRT client. |
| 191 | + |
| 192 | +```python |
| 193 | +def create_pjrt_clients(): |
| 194 | + loaded_plugin_names = get_loaded_plugin_names() |
| 195 | + for plugin_name in loaded_plugin_names: |
| 196 | + # Framework specific logics to decide whether to create |
| 197 | + if should_create_pjrt_client(plugin_name): |
| 198 | + pjrt_client = create_pjrt_client(plugin_name) |
| 199 | +``` |
| 200 | + |
| 201 | +For TensorFlow, discovery will be added to |
| 202 | +[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/api_template.__init__.py#L142), |
| 203 | +and loading PJRT plugins will be added to |
| 204 | +[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc#L124). |
| 205 | +Depending on the plugin, the client can be created implicitly in the first time |
| 206 | +it is used, or created in a plugin custom op kernel. Created PJRT clients will |
| 207 | +be saved in the |
| 208 | +[global TensorFlow ResourceManager](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/tfrt/common/pjrt_util.h#L26). |
| 209 | + |
| 210 | +For more information about PJRT, please consult the PJRT |
| 211 | +[C header](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h), |
| 212 | +[C++ header](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_client.h), |
| 213 | +[plugin mechanism design doc](https://docs.google.com/document/d/1FieBwiflD6DT9Z38wFIrnNxrUto1AOP3UuibQ4vpvEQ/edit) |
| 214 | +and |
| 215 | +[integration guide](https://docs.google.com/document/d/1EL62DkASMFZbk89K6MIPNGUFBTbkC21v5-uL5fB8VsM/edit). |
| 216 | + |
| 217 | +## Initial contribution |
| 218 | + |
| 219 | +We will use the plugin in |
| 220 | +[IREE samples](https://github.com/iree-org/iree-samples/tree/main/pjrt-plugin) |
| 221 | +to seed the repo and build full PJRT support as part of OpenXLA project. |
0 commit comments