from transformers import ViTModel, AutoImageProcessor
from PIL import Image
import time
import torch
import os
import numpy as np
import logging

import torch_xla

from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config
from neuronx_distributed_inference.models.config import NeuronConfig
from neuronx_distributed_inference.utils.accuracy import check_accuracy_embeddings
from neuronx_distributed_inference.utils.benchmark import LatencyCollector
from neuronx_distributed_inference.models.vit.modeling_vit import NeuronViTForImageEncoding, ViTInferenceConfig

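# Number of timed iterations and local paths. MODEL_PATH is assumed to hold a
# locally downloaded google/vit-huge-patch14-224-in21k checkpoint; TRACED_MODEL_PATH
# is where the compiled Neuron artifacts are written by compile() and read by load().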
NUM_BENCHMARK_ITER = 10
MODEL_PATH = "/home/ubuntu/model_hf/google--vit-huge-patch14-224-in21k/"
TRACED_MODEL_PATH = "/home/ubuntu/model_hf/google--vit-huge-patch14-224-in21k/traced_model/"

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

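# setup_debug_env keeps execution on device (no CPU fallback), emits IR/HLO debug
# metadata, enables the Neuron softmax fusion pass, and fixes the seed so runs
# are reproducible.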
def setup_debug_env():
    os.environ["XLA_FALLBACK_CPU"] = "0"
    os.environ["XLA_IR_DEBUG"] = "1"
    os.environ["XLA_HLO_DEBUG"] = "1"
    os.environ["NEURON_FUSE_SOFTMAX"] = "1"
    torch_xla._XLAC._set_ir_debug(True)
    torch.manual_seed(0)


def run_vit_encoding(validate_accuracy=True):
    # Define configs
    neuron_config = NeuronConfig(
        tp_degree=32,
        torch_dtype=torch.float32,
    )
    inference_config = ViTInferenceConfig(
        neuron_config=neuron_config,
        load_config=load_pretrained_config(MODEL_PATH),
        use_mask_token=False,
        add_pooling_layer=False,
        interpolate_pos_encoding=False
    )
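    # Note: tp_degree=32 above assumes 32 NeuronCores are available (e.g. a
    # trn1.32xlarge); adjust it to match the instance you run on.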

    # Input image
    image_file = "dog.jpg"  # [512, 512]
    with open(image_file, "rb") as f:
        image = Image.open(f).convert("RGB")
    print(f"Input image size {image.size}")
    # Preprocess input image
    image_processor = AutoImageProcessor.from_pretrained(MODEL_PATH)
    pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
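    # For this 224x224 / patch-14 checkpoint the processor resizes and normalizes
    # the image, so pixel_values is [1, 3, 224, 224]; the encoder output below is
    # expected to be [1, 257, 1280] (256 patch tokens + 1 CLS token, hidden size 1280).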

    # Get neuron model
    neuron_model = NeuronViTForImageEncoding(model_path=MODEL_PATH, config=inference_config)
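    # NeuronViTForImageEncoding wraps the Hugging Face ViT encoder for NxD Inference:
    # compile() below compiles it for the NeuronCores and saves the traced artifacts
    # to TRACED_MODEL_PATH, which load() then loads onto the device.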

    # Compile model on Neuron
    compile_start_time = time.time()
    neuron_model.compile(TRACED_MODEL_PATH)
    compile_elapsed_time = time.time() - compile_start_time
    print(f"Compilation time taken {compile_elapsed_time}s")

    # Load model on Neuron
    neuron_model.load(TRACED_MODEL_PATH)
    print("Done loading neuron model")

    # Run NxDI implementation on Neuron
    neuron_latency_collector = LatencyCollector()
    for i in range(NUM_BENCHMARK_ITER):
        neuron_latency_collector.pre_hook()
        neuron_output = neuron_model(pixel_values)[0]  # NeuronViTModel output: (sequence_output,) or (sequence_output, pooled_output)
        neuron_latency_collector.hook()
    print(f"Got neuron output {neuron_output.shape} {neuron_output}")
    # Benchmark report
    for p in [25, 50, 90, 99]:
        latency = np.percentile(neuron_latency_collector.latency_list, p) * 1000
        print(f"Neuron inference latency_ms_p{p}: {latency}")

    # The section below is optional; use it to validate end-to-end accuracy
    # against a golden output computed on CPU.
    if validate_accuracy:
        # Get CPU model
        cpu_model = ViTModel.from_pretrained(MODEL_PATH)
        print(f"cpu model {cpu_model}")

        # Get golden output by running the original implementation on CPU
        cpu_latency_collector = LatencyCollector()
        for i in range(NUM_BENCHMARK_ITER):
            cpu_latency_collector.pre_hook()
            golden_output = cpu_model(pixel_values).last_hidden_state
            cpu_latency_collector.hook()
        print(f"expected_output {golden_output.shape} {golden_output}")
        # Benchmark report
        for p in [25, 50, 90, 99]:
            latency = np.percentile(cpu_latency_collector.latency_list, p) * 1000
            print(f"CPU inference latency_ms_p{p}: {latency}")

        # Compare output embeddings
        passed, max_err = check_accuracy_embeddings(neuron_output, golden_output, plot_outputs=True, atol=1e-5, rtol=1e-5)
        print(f"Golden and Neuron outputs match: {passed}, max relative error: {max_err}")


if __name__ == "__main__":
    # Set flags for debugging
    setup_debug_env()

    run_vit_encoding(validate_accuracy=True)