Skip to content

Commit 8101ae3

Browse files
committed
Merge branch 'fix-older-python-compatibility'
2 parents e1eb199 + 31b743c commit 8101ae3

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

drpai/ei2gst_drpai.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import tflite_runtime.interpreter as tflite
55
import numpy as np
66
from loguru import logger as logging
7+
from typing import Tuple, List
78

89

910
def csv_2_bytearray(s: str) -> bytearray:
@@ -32,7 +33,7 @@ def csv_2_bytearray(s: str) -> bytearray:
3233
return r
3334

3435

35-
def get_grid_anchors(interpreter: tflite.Interpreter, grids: list[int]) -> tuple[int, list]:
36+
def get_grid_anchors(interpreter: tflite.Interpreter, grids: List[int]) -> Tuple[int, list]:
3637
"""
3738
Retrieves grid anchors from the TensorFlow Lite interpreter.
3839
@@ -116,8 +117,8 @@ def __arrayname_2_filename(self, array_name: str) -> str:
116117
>>> self.__arrayname_2_filename('unsigned char ei_ei_addrmap_intm_txt[] = {')
117118
'model/model_addrmap_intm.txt'
118119
"""
119-
array_name = array_name.removeprefix("unsigned char ")
120-
array_name = array_name.removeprefix("ei_")
120+
array_name = array_name.replace("unsigned char ", "", 1)
121+
array_name = array_name.replace("ei_", "", 1)
121122
array_name = array_name.replace("ei_", f"{self.model_name}_")
122123
array_name = array_name.replace("[]", "")
123124
array_name = array_name.replace("=", "")
@@ -170,7 +171,7 @@ def gen_drpai_model_files(self):
170171
# Store it in the `var_list` dictionary.
171172
line_sections = line.split(" ")
172173
key = line_sections[-3]
173-
value = line_sections[-1].removesuffix("\n").removesuffix(";")
174+
value = line_sections[-1].replace("\n", "").replace(";", "")
174175
self.var_list[key] = value
175176
if key.endswith("_len"):
176177
# Ensure the length of the last array is correct.
@@ -187,7 +188,7 @@ def gen_drpai_model_files(self):
187188
gc.collect()
188189
else:
189190
# The C array has not ended yet, so convert hex values to binary and append.
190-
self.var_list[output_file_path] += csv_2_bytearray(line.removesuffix("\n"))
191+
self.var_list[output_file_path] += csv_2_bytearray(line.replace("\n", ""))
191192

192193
def read_variables(self):
193194
"""
@@ -224,7 +225,7 @@ def read_variables(self):
224225
if len(line_sections) >= 2:
225226
# The line probably looks like `type name;`, `type * name;`, `type *name;` or `} name;`
226227
# We only care about the last word which is a name.
227-
name = line_sections[-1].removeprefix("*").removesuffix(";")
228+
name = line_sections[-1].replace("*", "").replace(";", "")
228229
if line_sections[0] == "}":
229230
# The struct definition is finished
230231
# Let's add all the variable names to the dictionary with a prefix of the struct name.
@@ -239,7 +240,7 @@ def read_variables(self):
239240
file_path = f"{self.working_directory}/model-parameters/model_variables.h"
240241
logging.info("Reading file: " + file_path)
241242
struct_variables = list() # A list of variable names when they are grouped in a structure
242-
with (open(file_path, "rt") as f):
243+
with open(file_path, "rt") as f:
243244
# Read the C header file line by line.
244245
for line in f:
245246
# By splitting the line into words, we can look for C language keywords
@@ -255,7 +256,7 @@ def read_variables(self):
255256
if "[]" in line:
256257
# It is defining an array
257258
# Let's extract its contents into a list and save it in the `var_list`
258-
name = line_sections[line_sections.index("=")-1].removesuffix("[]")
259+
name = line_sections[line_sections.index("=")-1].replace("[]", "")
259260
value = line[line.find("{")+1: line.find("}")]
260261
value = value.replace("\"", "").replace(" ", "")
261262
self.var_list[name] = value.split(",")
@@ -280,13 +281,13 @@ def read_variables(self):
280281
if line_sections[0] == "const":
281282
# Delete the const keyword
282283
del line_sections[0]
283-
self.var_list[line_sections[1]] = line_sections[-1].removesuffix(";")
284+
self.var_list[line_sections[1]] = line_sections[-1].replace(";", "")
284285
else:
285286
# We are in the middle of a struct initialization.
286287
# The order of values match the order of variable names at the declaration time.
287288
# Variable names are already selected in `struct_variables`.
288289
# Let's assign values to the first item and remove. It will eventually become empty.
289-
value = line_sections[0].removesuffix(",").removeprefix("\"").removesuffix("\"")
290+
value = line_sections[0].replace(",", "").replace("\"", "").replace("\"", "")
290291
self.var_list[struct_variables[0]] = value
291292
del struct_variables[0]
292293

@@ -408,7 +409,7 @@ def gen_postprocess_params_txt(self):
408409
model_version = "5"
409410

410411
# Retrieve and validate the IoU threshold from var_list dictionary
411-
iou_threshold = self.var_list["ei_object_detection_nms_config_t_iou_threshold"].removesuffix("f")
412+
iou_threshold = self.var_list["ei_object_detection_nms_config_t_iou_threshold"][:-1]
412413
assert float(iou_threshold) >= 0, \
413414
"The loaded model_variables.h doesn't have the required output iou_threshold."
414415

0 commit comments

Comments
 (0)