50
50
###################################################################################
51
51
# Defining the Neural Network
52
52
# ---------------------------
53
- #
53
+ #
54
54
# We will use the same neural network structure as the regional compilation recipe.
55
55
#
56
56
# We will use a network composed of repeated layers. This mimics a
@@ -93,12 +93,12 @@ def forward(self, x):
93
93
##################################################################################
94
94
# Compiling the model ahead-of-time
95
95
# ---------------------------------
96
- #
96
+ #
97
97
# Since we're compiling the model ahead-of-time, we need to prepare representative
98
98
# input examples, that we expect the model to see during actual deployments.
99
- #
99
+ #
100
100
# Let's create an instance of ``Model`` and pass it some sample input data.
101
- #
101
+ #
102
102
103
103
model = Model ().cuda ()
104
104
input = torch .randn (10 , 10 , device = "cuda" )
@@ -123,7 +123,7 @@ def forward(self, x):
123
123
######################################################################################
124
124
# Compiling _regions_ of the model ahead-of-time
125
125
# ----------------------------------------------
126
- #
126
+ #
127
127
# Compiling model regions ahead-of-time, on the other hand, requires a few key changes.
128
128
#
129
129
# Since the compute pattern is shared by all the blocks that
@@ -141,13 +141,13 @@ def forward(self, x):
141
141
142
142
###################################################
143
143
# An exported program (``torch.export.ExportedProgram``) contains the Tensor computation,
144
- # a ``state_dict`` containing tensor values of all lifted parameters and buffers alongside
144
+ # a ``state_dict`` containing tensor values of all lifted parameters and buffers alongside
145
145
# other metadata. We specify the ``aot_inductor.package_constants_in_so`` to be ``False`` to
146
146
# not serialize the model parameters in the generated artifact.
147
147
#
148
148
# Now, when loading the compiled binary, we can reuse the existing parameters of
149
149
# each block. This lets us take advantage of the compiled binary obtained above.
150
- #
150
+ #
151
151
152
152
for layer in model .layers :
153
153
compiled_layer = torch ._inductor .aoti_load_package (path )
@@ -187,17 +187,17 @@ def measure_compile_time(input, regional=False):
187
187
def aot_compile_load_model (regional = False ) -> torch .nn .Module :
188
188
input = torch .randn (10 , 10 , device = "cuda" )
189
189
model = Model ().cuda ()
190
-
190
+
191
191
inductor_configs = {}
192
192
if regional :
193
193
inductor_configs = {"aot_inductor.package_constants_in_so" : False }
194
-
194
+
195
195
# Reset the compiler caches to ensure no reuse between different runs
196
196
torch .compiler .reset ()
197
197
with torch ._inductor .utils .fresh_inductor_cache ():
198
198
path = torch ._inductor .aoti_compile_and_package (
199
199
torch .export .export (
200
- model .layers [0 ] if regional else model ,
200
+ model .layers [0 ] if regional else model ,
201
201
args = (input ,)
202
202
),
203
203
inductor_configs = inductor_configs ,
@@ -224,16 +224,16 @@ def aot_compile_load_model(regional=False) -> torch.nn.Module:
224
224
assert regional_compilation_latency < full_model_compilation_latency
225
225
226
226
############################################################################
227
- # There may also be layers in a model incompatible with compilation. So,
227
+ # There may also be layers in a model incompatible with compilation. So,
228
228
# full compilation will result in a fragmented computation graph resulting
229
229
# in potential latency degradation. In these cases, regional compilation
230
230
# can be beneficial.
231
- #
231
+ #
232
232
233
233
############################################################################
234
234
# Conclusion
235
235
# -----------
236
236
#
237
- # This recipe shows how to control the cold start time when compiling your
237
+ # This recipe shows how to control the cold start time when compiling your
238
238
# model ahead-of-time. This becomes effective when your model has repeated
239
239
# blocks, which is typically seen in large generative models.
0 commit comments