@@ -131,6 +131,23 @@ def test_draw_boxes_with_coloured_labels():
131
131
assert_equal (result , expected )
132
132
133
133
134
+ @pytest .mark .skipif (PILLOW_VERSION < (10 , 1 ), reason = "The reference image is only valid for PIL >= 10.1" )
135
+ def test_draw_boxes_with_coloured_label_backgrounds ():
136
+ img = torch .full ((3 , 100 , 100 ), 255 , dtype = torch .uint8 )
137
+ labels = ["a" , "b" , "c" , "d" ]
138
+ colors = ["green" , "#FF00FF" , (0 , 255 , 0 ), "red" ]
139
+ label_colors = ["green" , "red" , (0 , 255 , 0 ), "#FF00FF" ]
140
+ result = utils .draw_bounding_boxes (
141
+ img , boxes , labels = labels , colors = colors , fill = True , label_colors = label_colors , fill_labels = True
142
+ )
143
+
144
+ path = os .path .join (
145
+ os .path .dirname (os .path .abspath (__file__ )), "assets" , "fakedata" , "draw_boxes_different_label_fill_colors.png"
146
+ )
147
+ expected = torch .as_tensor (np .array (Image .open (path ))).permute (2 , 0 , 1 )
148
+ assert_equal (result , expected )
149
+
150
+
134
151
@pytest .mark .parametrize ("fill" , [True , False ])
135
152
def test_draw_boxes_dtypes (fill ):
136
153
img_uint8 = torch .full ((3 , 100 , 100 ), 255 , dtype = torch .uint8 )
0 commit comments