Skip to content

Commit cf9dd80

Browse files
committed
Fixes
Signed-off-by: Krzysztof Lecki <[email protected]>
1 parent d324b97 commit cf9dd80

File tree

2 files changed

+18
-14
lines changed

2 files changed

+18
-14
lines changed

dali/test/python/test_operator_arithmetic_ops.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -544,8 +544,11 @@ def test_arithmetic_ops():
544544
def test_ternary_ops_big():
545545
for kinds in selected_ternary_input_kinds:
546546
for (op, op_desc) in ternary_operations:
547-
for types_in in [(np.int32, np.int32, np.int32), (np.int32, np.int8, np.int16),
548-
(np.int32, np.uint8, np.float32)]:
547+
for types_in in [
548+
(np.int32, np.int32, np.int32),
549+
(np.int32, np.int8, np.int16),
550+
(np.int32, np.uint8, np.float32),
551+
]:
549552
yield check_ternary_op, kinds, types_in, op, shape_big, op_desc
550553

551554

@@ -565,9 +568,12 @@ def test_ternary_ops_selected():
565568
def test_ternary_ops_kinds():
566569
for kinds in ternary_input_kinds:
567570
for (op, op_desc) in ternary_operations:
568-
for types_in in [(np.int32, np.int32, np.int32), (np.float32, np.int32, np.int32),
569-
(np.uint8, np.float32, np.float32),
570-
(np.int32, np.float32, np.float32)]:
571+
for types_in in [
572+
(np.int32, np.int32, np.int32),
573+
(np.float32, np.int32, np.int32),
574+
(np.uint8, np.float32, np.float32),
575+
(np.int32, np.float32, np.float32),
576+
]:
571577
yield check_ternary_op, kinds, types_in, op, shape_small, op_desc
572578

573579

@@ -600,8 +606,6 @@ def test_bitwise_ops():
600606

601607
def check_comparsion_op(kinds, types, op, shape, _):
602608
# Comparisons - should always return bool
603-
left_type, right_type = types
604-
left_kind, right_kind = kinds
605609
iterator = iter(ExternalInputIterator(batch_size, shape, types, kinds))
606610
pipe = ExprOpPipeline(kinds, types, iterator, op, batch_size=batch_size, num_threads=2,
607611
device_id=0)

dali/test/python/test_operator_readers_webdataset_requirements.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_return_empty():
3434
extract_dir = generate_temp_extract(tar_file_path)
3535
equivalent_files = glob(extract_dir.name + "/*")
3636
equivalent_files = sorted(equivalent_files,
37-
key=(lambda s: int(s[s.rfind("/") + 1:s.rfind(".")])))
37+
key=(lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")]))) # noqa: 203
3838

3939
compare_pipelines(
4040
webdataset_raw_pipeline(
@@ -62,9 +62,9 @@ def test_skip_sample():
6262
extract_dir = generate_temp_extract(tar_file_path)
6363
equivalent_files = list(
6464
filter(
65-
lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]) < 2500,
65+
lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")]) < 2500, # noqa: 203
6666
sorted(glob(extract_dir.name + "/*"),
67-
key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")])),
67+
key=lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")])), # noqa: 203
6868
))
6969

7070
compare_pipelines(
@@ -120,7 +120,7 @@ def test_different_components():
120120
extract_dir = generate_temp_extract(tar_file_path)
121121
equivalent_files = glob(extract_dir.name + "/*")
122122
equivalent_files = sorted(equivalent_files,
123-
key=(lambda s: int(s[s.rfind("/") + 1:s.rfind(".")])))
123+
key=(lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")]))) # noqa: 203
124124

125125
compare_pipelines(
126126
webdataset_raw_pipeline(
@@ -176,7 +176,7 @@ def test_wds_sharding():
176176
equivalent_files = sum(
177177
list(
178178
sorted(glob(extract_dir.name +
179-
"/*"), key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]))
179+
"/*"), key=lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")])) # noqa: 203
180180
for extract_dir in extract_dirs),
181181
[],
182182
)
@@ -210,7 +210,7 @@ def test_sharding():
210210

211211
extract_dir = generate_temp_extract(tar_file_path)
212212
equivalent_files = sorted(glob(extract_dir.name + "/*"),
213-
key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]))
213+
key=lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")])) # noqa: 203
214214

215215
num_shards = 100
216216
for shard_id in range(num_shards):
@@ -287,7 +287,7 @@ def test_index_generation():
287287
equivalent_files = sum(
288288
list(
289289
sorted(glob(extract_dir.name +
290-
"/*"), key=lambda s: int(s[s.rfind("/") + 1:s.rfind(".")]))
290+
"/*"), key=lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")])) # noqa: 203
291291
for extract_dir in extract_dirs),
292292
[],
293293
)

0 commit comments

Comments
 (0)