Commit 89d40f4

Merge pull request #33 from jpienaar/main
[RFC] OpenXLA PJRT plugin
2 parents 77a3099 + e0b163b

File tree

3 files changed: +243 −0 lines changed
rfcs/20230123-pjrt-plugin.md

## RFC: OpenXLA PJRT Plugin

| Status        | Proposed |
| :------------ | :------------------------------------------------------ |
| **RFC #**     | [33](https://github.com/openxla/community/pull/33) |
| **Author(s)** | Skye Wanderman-Milne ([email protected]), Jieying Luo ([email protected]), Jacques Pienaar ([email protected]) |
| **Sponsor**   | Stella Laurenzo ([email protected]), James Rubin ([email protected]) |
| **Updated**   | 2023-01-23 |

## Objective

* Framework integration of a packaged compiler and runtime solution.

## Proposal

* Adopt PJRT as the supported device plugin mechanism for OpenXLA.
* Create a new repo, `openxla/openxla-pjrt-plugin`, for the OpenXLA PJRT plugin.

## Background: PJRT

[PJRT](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h)
is a uniform Device API that we want to add to the OpenXLA ecosystem. The
long-term vision for PJRT is that: (1) frameworks (TensorFlow, PyTorch, JAX,
etc.) will call PJRT, which has device-specific implementations that are opaque
to the frameworks; (2) each device focuses on implementing the PJRT APIs, and
its implementation can remain opaque to the frameworks.

![PJRT plugin integrating OpenXLA into ML frameworks](20230123-pjrt-plugin/frameworks.png)
<p align = "center"> PJRT provides a platform-independent interface for
compilers and their corresponding runtimes. </p>

The PJRT API will provide an easy interface with which frameworks can integrate
a packaged compiler and runtime solution. It will be the supported interface
used by TensorFlow and JAX for all compiler and runtime integration, and as
such it will be easy for other compilers and runtimes that implement the PJRT
interface to integrate with these systems.

## PJRT plugin mechanism goal

The PJRT plugin mechanism should support the following features:

* Different devices (e.g. TPU, GPU) can have different implementations
  (through the
  [PJRT C API interface](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h)).
* Registration of multiple PJRT plugins.
* Loading multiple PJRT plugins (e.g. both CPU and TPU) in the same process.
* Passing configuration values from the client (e.g. JAX lib) or from JSON
  files provided by the plugin (default configs).
* Plugin discovery and choosing which plugins to load.

## High level plugin structure

![High-level plugin structure](20230123-pjrt-plugin/plugin-structure.png)

## Current Status

As of Dec 14, 2022:

* [LoadPjrtPlugin(plugin_name, library_path)](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/stream_executor/tpu/pjrt_api.cc#L73)
  can be used to load a PJRT plugin. We also provide a Python method
  [load_pjrt_plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla.cc#L329)
  bound to it.
* [GetCApiClient(plugin_name)](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc#L1150)
  can be used to create a PJRT client. We also provide a Python method
  [get_c_api_client](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla.cc#L391)
  bound to it.

## Design Considerations

### <a name="heading=h.782ksg6rl5bj"></a> Loading plugins

We will provide a low-level C++ method to load a PJRT plugin, which takes
`library_path` and `config_values` as inputs. A Python method
`load_pjrt_plugin` bound to it will be provided as well.

```c++
Status LoadPjrtPlugin(const string& library_path,
                      const map<string, string>& config_values) {
  void* library_handle = dlopen(library_path.c_str(), RTLD_LAZY);
  if (library_handle == nullptr) {
    return NotFoundError("Failed to open plugin library: " + library_path);
  }
  auto function_ptr = reinterpret_cast<const PJRT_Api* (*)()>(
      dlsym(library_handle, "GetPjrtApi"));
  if (function_ptr != nullptr) {
    const PJRT_Api* api = function_ptr();
    string plugin_name = parse(library_path);
    if (auto it = global_pjrt_plugin_map.find(plugin_name);
        it != global_pjrt_plugin_map.end()) {
      // Loading the same plugin multiple times is a no-op.
      return OkStatus();
    }
    PluginInfo plugin_info(api, config_values);
    global_pjrt_plugin_map[plugin_name] = plugin_info;
  }
  return OkStatus();
}
```

* `GetPjrtApi` returns a `PJRT_Api*` pointer that contains the implementation
  of the C APIs.
* `global_pjrt_plugin_map` is a global `map<plugin_name, PluginInfo>` with the
  same lifetime as the program.
* `plugin_name` has the same meaning as `platform_name` in JAX.
* To be able to store the config values, and to be future-proof for other
  information such as versions, we propose to use a class `PluginInfo`. This
  class is immutable after construction, and contains getter methods for the
  `PJRT_Api*` and the config values.

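To illustrate the intended shape of `PluginInfo` (the RFC proposes it as a C++ class; this Python analogue is only a sketch of the immutable, getter-based design, and the field names here are illustrative):

```python
class PluginInfo:
    """Immutable after construction; exposes only getters."""

    def __init__(self, api, config_values):
        self._api = api                            # stands in for PJRT_Api*
        self._config_values = dict(config_values)  # defensive copy

    def api(self):
        return self._api

    def config_values(self):
        # Return a copy so callers cannot mutate the stored config.
        return dict(self._config_values)
```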
### Discovering and automatic loading plugins

To allow framework users to pip-install plugins without requiring further code
changes, we'll implement a Python discovery mechanism to automatically find and
load plugins. Plugin discovery will be based on the
[naming convention of the Python module](https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-naming-convention),
and on full paths set in the environment variable `PJRT_PLUGIN_LIBRARY_PATH`
(added to allow users to manually specify directories and/or .so files). Any
modules found will be imported.

* Tentative naming convention for .so/json files:
  "pjrt-plugin-<plugin_name>.so" or "pjrt-plugin-<plugin_name>.json".

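Under this tentative convention, extracting the plugin name from a file found on `PJRT_PLUGIN_LIBRARY_PATH` might look like the following sketch (the helper name and regex are illustrative, not part of the proposal):

```python
import re

# Illustrative only: match "pjrt-plugin-<plugin_name>.so" / ".json" file
# names produced by the tentative naming convention above.
_PLUGIN_FILE_RE = re.compile(r"^pjrt-plugin-(?P<name>[^.]+)\.(so|json)$")

def parse_plugin_file_name(file_name):
    """Returns (plugin_name, extension), or None if the name does not match."""
    match = _PLUGIN_FILE_RE.match(file_name)
    if match is None:
        return None
    return match.group("name"), file_name.rsplit(".", 1)[1]
```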
There are two options for automatically loading plugins (a decision has not
been made yet):

Option 1: The plugin module updates `PJRT_PLUGIN_LIBRARY_PATH` on import. A
Python method `load_pjrt_plugins` will be added to discover the .so/json files
related to PJRT plugins and load them.

```python
def load_pjrt_plugins():
  for path in os.environ["PJRT_PLUGIN_LIBRARY_PATH"].split(os.pathsep):
    if path.endswith(".json"):
      library_path, config_values = parse_json_file(path)
      load_pjrt_plugin(library_path, config_values)
    elif path.endswith(".so"):
      load_pjrt_plugin(path)
```

Option 2: The plugin module is responsible for calling
[load_pjrt_plugin](#heading=h.782ksg6rl5bj) with its default options on import.

Open questions:

* Are there requirements about which files should not be loaded?

### <a name="heading=h.396bmv8gkskz"></a> Create PJRT client(s)

Frameworks decide which PJRT clients to create and use. For example, a
framework can create PJRT clients for all the plugins loaded. It can also
choose to create only a subset of PJRT clients based on some priority rules,
or based on user configuration.

Two Python methods bound to C++ methods will be provided to facilitate
creating PJRT clients:

1. `get_loaded_plugin_names()`, which gets all loaded plugin names from
   `global_pjrt_plugin_map`.
2. `create_pjrt_client(plugin_name)`, which creates a PJRT C API client
   (similar to
   [GetCApiClient](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc#L1150)).
   It will retrieve the `PluginInfo` stored in `global_pjrt_plugin_map` and run
   `PJRT_Client_Create` (see the
   [Config values for creating PJRT client(s)](#heading=h.bjuf0soco0sj)
   section).

Open questions:

* What about plugin initialization steps that are not related to creating a
  PJRT client? For example, these two functions in
  [tpu_initializer_helper.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc#L269-L270)
  should be run when initializing TPU. Currently they are called every time a
  PJRT TPU client is created. One solution is to place this initialization in
  the method `GetPjrtApi`, so that it runs the first time a plugin is loaded.
* It is not desirable to create PJRT clients for every plugin that is found,
  as that would increase resource utilization. Frameworks decide what behavior
  they prefer.
* We may need to store loaded `PluginInfo` in a nested map `map<device_type,
  map<plugin_name, PluginInfo>>` if the framework wants to know which plugins
  are available for a device (e.g. the current PJRT GPU and IREE GPU) and only
  create one PJRT client per device.

### <a name="heading=h.bjuf0soco0sj"></a> Config values for creating PJRT client(s)

`GetCApiClient` will be changed to take a map of config values. This map can be
passed to `PJRT_Client_Create` through
[PJRT_NamedValue](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h#L210).
The default config values can come (1) from the JSON file, which will be stored
in `PluginInfo`, or (2) from the plugin implementation during import.

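A sketch of how the final config map might be assembled before being converted to `PJRT_NamedValue`s (the function name and keys here are illustrative, not part of the proposal): client-provided values take precedence over the plugin's defaults.

```python
def merge_config_values(default_config, client_config):
    """Client-provided values override the plugin's defaults."""
    merged = dict(default_config)  # e.g. parsed from the plugin's JSON file
    merged.update(client_config)   # e.g. passed in from JAX lib
    return merged
```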
### Framework integration

All the methods mentioned above can be reused across frameworks. For example,
JAX lib can integrate as follows:

* Call [load_pjrt_plugins](#heading=h.782ksg6rl5bj) when
  [initializing backends](https://github.com/google/jax/blob/a66b3dcdd378b723e275a19258de826260b4c83e/jax/_src/lib/xla_bridge.py#L381).
* Call [get_loaded_plugin_names](#heading=h.396bmv8gkskz) to get the loaded
  PJRT `plugin_name`s, then use some framework-specific logic to decide whether
  to call [create_pjrt_client](#heading=h.396bmv8gkskz) to create each PJRT
  client.

```python
def create_pjrt_clients():
  loaded_plugin_names = get_loaded_plugin_names()
  for plugin_name in loaded_plugin_names:
    # Framework-specific logic to decide whether to create a client.
    if should_create_pjrt_client(plugin_name):
      pjrt_client = create_pjrt_client(plugin_name)
```

For TensorFlow, discovery will be added
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/api_template.__init__.py#L142),
and loading PJRT plugins will be added
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc#L124).
Depending on the plugin, the client can be created implicitly the first time
it is used, or created in a plugin custom op kernel. Created PJRT clients will
be saved in the
[global TensorFlow ResourceManager](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/tfrt/common/pjrt_util.h#L26).

Whether multiple frameworks can create a client for the same device is up to
the specific hardware and plugin implementation. If the hardware only allows
exclusive access, then the software will abide by that constraint. Otherwise,
some plugins may still use the hardware in an exclusive way (i.e. allocate all
memory). The OpenXLA plugin that we envision now will default to allowing
multiple clients and supporting on-demand memory allocation (with a possible
option for greedier access as an optimization for cases that need it).

For more information about PJRT, please consult the PJRT
[C header](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h),
[C++ header](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_client.h),
[plugin mechanism design doc](https://docs.google.com/document/d/1FieBwiflD6DT9Z38wFIrnNxrUto1AOP3UuibQ4vpvEQ/edit)
and
[integration guide](https://docs.google.com/document/d/1EL62DkASMFZbk89K6MIPNGUFBTbkC21v5-uL5fB8VsM/edit).

## Open questions and future improvements from the comments

* Windows support.
* Test suites for PJRT implementations.
* How to handle versioning or variants (e.g. x86, x86-64 and ARM CPUs,
  different CUDA compute capability GPUs, DPUs, etc.).
* How does a device broadcast its capabilities or supported API calls?
* Support for unloading and reloading a PJRT plugin (hot-swap/plug).
* Single- vs. multi-process access/ownership (maybe covered by IFRT).
* Command queue/buffer management API access, simple commands vs. transactional.

## Initial contribution

We will use the plugin in
[IREE samples](https://github.com/iree-org/iree-samples/tree/main/pjrt-plugin)
to seed the repo and build full PJRT support as part of the OpenXLA project.