Commit e95a27c

[RFC] OpenXLA PJRT plugin
Present proposal for OpenXLA PJRT plugin along with creation of a new repo for its development.
1 parent 3fe1df4 commit e95a27c

File tree

3 files changed: +221 −0 lines changed


rfcs/20230123-pjrt-plugin.md

## RFC: OpenXLA PJRT Plugin

| Status        | Proposed |
| :------------ | :------------------------------------------------------ |
| **RFC #**     | [33](https://github.com/openxla/community/pull/33) |
| **Author(s)** | Skye Wanderman-Milne ([email protected]), Jieying Luo ([email protected]), Jacques Pienaar ([email protected]) |
| **Sponsor**   | Stella Laurenzo ([email protected]), James Rubin ([email protected]) |
| **Updated**   | 2023-01-23 |

## Objective

* Framework integration of a packaged compiler and runtime solution.

## Proposal

* Adopt PJRT as the supported device plugin mechanism for OpenXLA.
* Create a new repo, openxla/openxla-pjrt-plugin, for the OpenXLA PJRT plugin.

## Background: PJRT

[PJRT](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h)
is a uniform device API that we want to add to the OpenXLA ecosystem. The
long-term vision for PJRT is that: (1) frameworks (TensorFlow, PyTorch, JAX,
etc.) will call PJRT, which has device-specific implementations that are opaque
to the frameworks; (2) each device focuses on implementing the PJRT APIs and
can remain opaque to the frameworks.

![PJRT plugin integrating OpenXLA into ML frameworks](20230123-pjrt-plugin/frameworks.png)
<p align = "center"> PJRT provides a platform-independent interface between
compilers and their corresponding runtimes. </p>

The PJRT API will provide a simple interface with which frameworks can
integrate a packaged compiler and runtime solution. It will be the supported
interface used by TensorFlow and JAX for all compiler and runtime integration,
and as such it will be easy for other compilers and runtimes that implement the
PJRT interface to integrate with these systems.
## PJRT plugin mechanism goal

The PJRT plugin mechanism should support the following features:

* Different devices (e.g. TPU, GPU) can have different implementations
  (through the
  [PJRT C API interface](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h)).
* Registration of multiple PJRT plugins.
* Loading multiple PJRT plugins (e.g. both CPU and TPU) in the same process.
* Passing configuration values from the client (e.g. JAX lib) or from JSON
  files provided by the plugin (default configs).
* Plugin discovery and choosing which plugins to load.

## High level plugin structure

![High-level plugin structure](20230123-pjrt-plugin/plugin-structure.png)

## Current Status

As of Dec 14, 2022:

* [LoadPjrtPlugin(plugin_name, library_path)](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/stream_executor/tpu/pjrt_api.cc#L73)
  can be used to load a PJRT plugin. We also provide a Python method,
  [load_pjrt_plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla.cc#L329),
  bound to it.
* [GetCApiClient(plugin_name)](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc#L1150)
  can be used to create a PJRT client. We also provide a Python method,
  [get_c_api_client](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla.cc#L391),
  bound to it.

## Design Considerations

### <a name="heading=h.782ksg6rl5bj"></a> Loading plugins

We will provide a low-level C++ method to load a PJRT plugin, which takes
`library_path` and `config_values` as inputs. A Python method,
`load_pjrt_plugin`, bound to it will be provided as well.

```c++
// Sketch; error handling elided.
Status LoadPjrtPlugin(string library_path, map<string, string> config_values) {
  void* library_handle = dlopen(library_path);
  auto function_ptr = dlsym(library_handle, "GetPjrtApi");
  if (function_ptr != nullptr) {
    PJRT_Api* api = function_ptr();
    string plugin_name = parse(library_path);
    PluginInfo plugin_info(api, config_values);
    global_pjrt_plugin_map[plugin_name] = plugin_info;
  }
  return OkStatus();
}
```

* `GetPjrtApi` returns a `PJRT_Api*` pointer that contains the implementation
  of the C APIs.
* `global_pjrt_plugin_map` is a global `map<plugin_name, PluginInfo>` with the
  same lifetime as the program.
* `plugin_name` has the same meaning as `platform_name` in JAX.
* To store the config values, and to be future-proofed for other information
  such as version, we propose to use a `PluginInfo` class. This class is
  immutable after construction, and contains getter methods for the
  `PJRT_Api*` and the config values.
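
The registry semantics described above can be sketched in Python. The class and helper below are hypothetical illustrations (the names `register_plugin` and the `api` stand-in are assumptions, not the actual XLA implementation); the point is that `PluginInfo` is immutable after construction and exposes getters for the API pointer and config values.

```python
class PluginInfo:
    """Immutable record of a loaded plugin (sketch, not the real class)."""

    def __init__(self, api, config_values):
        self._api = api
        # Copy so later mutation of the caller's dict cannot leak in.
        self._config_values = dict(config_values)

    @property
    def api(self):
        return self._api

    @property
    def config_values(self):
        # Return a copy so callers cannot mutate the stored values.
        return dict(self._config_values)


# Process-wide registry keyed by plugin_name, mirroring
# global_pjrt_plugin_map in the C++ sketch above.
global_pjrt_plugin_map = {}


def register_plugin(plugin_name, api, config_values):
    global_pjrt_plugin_map[plugin_name] = PluginInfo(api, config_values)
```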
### Discovering and automatically loading plugins

To allow framework users to pip-install plugins without requiring further code
changes, we'll implement a Python discovery mechanism to automatically find and
load plugins. Plugin discovery will be based on the
[naming convention of the Python module](https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-naming-convention),
and on full paths set in the environment variable `PJRT_PLUGIN_LIBRARY_PATH`,
which is added to allow users to manually specify directories and/or .so files.
Each module found will be imported.

* Tentative naming convention for .so/json files:
  "pjrt-plugin-<plugin_name>.so" or "pjrt-plugin-<plugin_name>.json".
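
A minimal sketch of the file-discovery half of this mechanism, assuming the tentative naming convention above (the function name `discover_plugin_files` is an assumption for illustration, and the real mechanism would additionally import modules found by the Python naming convention):

```python
import os


def discover_plugin_files(env=os.environ):
    """Return candidate plugin files from PJRT_PLUGIN_LIBRARY_PATH.

    Entries in the variable may be directories (scanned for files matching
    the tentative "pjrt-plugin-<plugin_name>.so"/".json" convention) or
    direct paths to .so files.
    """
    found = []
    for entry in env.get("PJRT_PLUGIN_LIBRARY_PATH", "").split(os.pathsep):
        if os.path.isfile(entry):
            # The env var may point directly at a plugin file.
            found.append(entry)
        elif os.path.isdir(entry):
            for name in sorted(os.listdir(entry)):
                if name.startswith("pjrt-plugin-") and name.endswith((".so", ".json")):
                    found.append(os.path.join(entry, name))
    return found
```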

There are two options for automatically loading plugins (a decision has not
been made yet):

Option 1: The plugin module updates `PJRT_PLUGIN_LIBRARY_PATH` on import. A
Python method, `load_pjrt_plugins`, will be added to discover the .so/json
files related to PJRT plugins and load them:

```python
import json
import os

def load_pjrt_plugins():
  for path in os.environ["PJRT_PLUGIN_LIBRARY_PATH"].split(os.pathsep):
    if path.endswith(".json"):
      # The json file provides the library path plus default config values
      # (exact schema illustrative).
      with open(path) as f:
        plugin_json = json.load(f)
      load_pjrt_plugin(plugin_json["library_path"],
                       plugin_json.get("config_values", {}))
    elif path.endswith(".so"):
      load_pjrt_plugin(path, {})
```

Option 2: The plugin module is responsible for calling
[load_pjrt_plugin](#heading=h.782ksg6rl5bj) with its default options on import.

Open questions:

* Are there requirements about which files should not be loaded?

### <a name="heading=h.396bmv8gkskz"></a> Create PJRT client(s)

Frameworks decide which PJRT clients to create and use. For example, a
framework can create PJRT clients for all the loaded plugins, it can create
only a subset of PJRT clients based on some priority rules, or the choice can
come from user configuration.

Two Python methods bound to C++ methods will be provided to facilitate
creating PJRT clients:

1. `get_loaded_plugin_names()`, which gets all loaded plugin names from
   `global_pjrt_plugin_map`.
2. `create_pjrt_client(plugin_name)`, which creates a PJRT C API client
   (similar to
   [GetCApiClient](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.cc#L1150)).
   It will retrieve the `PluginInfo` stored in `global_pjrt_plugin_map` and run
   `PJRT_Client_Create` (see the
   [Config values for creating PJRT client(s)](#heading=h.bjuf0soco0sj)
   section).

Open questions:

* What about plugin initialization that is not related to creating a PJRT
  client? For example, these two functions in
  [tpu_initializer_helper.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/stream_executor/tpu/tpu_initializer_helper.cc#L269-L270)
  should be run when initializing TPU. Currently they are called every time a
  PJRT TPU client is created. Shall we add another method, `InitializePlugin`,
  and only run it once? Alternatively, the plugin can implement this in
  `PJRT_Client_Create` and run it only the first time a client is created.
* Do we want to create PJRT clients for every plugin that is found? Will that
  involve some initialization of the device which should not run multiple
  times?
* We may need to store the loaded `PluginInfo` in a nested `map<device_type,
  map<plugin_name, PluginInfo>>` if we want to know which plugins are
  available for a device (e.g. the current PJRT GPU and IREE GPU) and only
  create one PJRT client per device.

### <a name="heading=h.bjuf0soco0sj"></a> Config values for creating PJRT client(s)

`GetCApiClient` will be changed to take a map of config values. This map can be
passed to `PJRT_Client_Create` through
[PJRT_NamedValue](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h#L210).
The default config values can come either (1) from the JSON file, which will be
stored in `PluginInfo`, or (2) from the plugin implementation during import.
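
For illustration, a hedged Python sketch of flattening such a config map into `PJRT_NamedValue`-style entries. `PJRT_NamedValue` in `pjrt_c_api.h` carries a name plus a typed value; the function name and the type-tag spellings (`kString`, `kInt64`, `kInt64List`, `kFloat`) here are assumptions for this sketch rather than the shipped API.

```python
def to_named_values(config_values):
    """Flatten a config dict into (name, type_tag, value) tuples.

    Sketch only: the real code would populate PJRT_NamedValue C structs.
    """
    named = []
    for name, value in sorted(config_values.items()):
        if isinstance(value, str):
            kind = "kString"
        elif isinstance(value, bool):
            # bool is a subclass of int in Python; reject it explicitly.
            raise TypeError(f"unsupported config value type for {name!r}")
        elif isinstance(value, int):
            kind = "kInt64"
        elif isinstance(value, float):
            kind = "kFloat"
        elif isinstance(value, list) and all(isinstance(v, int) for v in value):
            kind = "kInt64List"
        else:
            raise TypeError(f"unsupported config value type for {name!r}")
        named.append((name, kind, value))
    return named
```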

### Framework integration

All the methods mentioned above can be reused across frameworks. For example,
JAX lib can integrate as follows:

* Call [load_pjrt_plugins](#heading=h.782ksg6rl5bj) when
  [initializing backends](https://github.com/google/jax/blob/a66b3dcdd378b723e275a19258de826260b4c83e/jax/_src/lib/xla_bridge.py#L381).
* Call [get_loaded_plugin_names](#heading=h.396bmv8gkskz) to get the loaded
  PJRT `plugin_name`s, then apply framework-specific logic to decide whether
  to call [create_pjrt_client](#heading=h.396bmv8gkskz) for each PJRT client.

```python
def create_pjrt_clients():
  loaded_plugin_names = get_loaded_plugin_names()
  for plugin_name in loaded_plugin_names:
    # Framework-specific logic decides whether to create a client
    # for this plugin.
    if should_create_pjrt_client(plugin_name):
      pjrt_client = create_pjrt_client(plugin_name)
```

For TensorFlow, discovery will be added
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/api_template.__init__.py#L142),
and loading of PJRT plugins will be added
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc#L124).
Depending on the plugin, the client can be created implicitly the first time it
is used, or created in a plugin custom op kernel. Created PJRT clients will be
saved in the
[global TensorFlow ResourceManager](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/tfrt/common/pjrt_util.h#L26).

For more information about PJRT, please consult the PJRT
[C header](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h),
[C++ header](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/pjrt/pjrt_client.h),
[plugin mechanism design doc](https://docs.google.com/document/d/1FieBwiflD6DT9Z38wFIrnNxrUto1AOP3UuibQ4vpvEQ/edit),
and
[integration guide](https://docs.google.com/document/d/1EL62DkASMFZbk89K6MIPNGUFBTbkC21v5-uL5fB8VsM/edit).

## Initial contribution

We will use the plugin in
[IREE samples](https://github.com/iree-org/iree-samples/tree/main/pjrt-plugin)
to seed the repo and build full PJRT support as part of the OpenXLA project.