Skip to content

Commit f213efc

Browse files
add suggested updates
1 parent 9bac18f commit f213efc

File tree

5 files changed

+112
-457
lines changed

5 files changed

+112
-457
lines changed

conftest.py

Lines changed: 66 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
import os
2-
from pathlib import Path
32

43
import keras
54
import pytest
65

7-
from keras_hub.src.utils.openvino_utils import get_openvino_skip_reason
8-
from keras_hub.src.utils.openvino_utils import setup_openvino_test_config
9-
106

117
def pytest_addoption(parser):
128
parser.addoption(
@@ -33,16 +29,27 @@ def pytest_addoption(parser):
3329
default=False,
3430
help="fail if a gpu is not present",
3531
)
36-
parser.addoption(
37-
"--auto_skip_training",
38-
action="store_true",
39-
default=True,
40-
help="automatically skip tests with "
41-
"training methods on non-trainable backends",
42-
)
4332

4433

4534
def pytest_configure(config):
35+
# Monkey-patch training methods for OpenVINO backend
36+
if keras.config.backend() == "openvino":
37+
# Store original methods in case we need to restore them
38+
if not hasattr(keras.Model, "_original_compile"):
39+
keras.Model._original_compile = keras.Model.compile
40+
keras.Model._original_fit = keras.Model.fit
41+
keras.Model._original_train_on_batch = keras.Model.train_on_batch
42+
43+
keras.Model.compile = lambda *args, **kwargs: pytest.skip(
44+
"Model.compile() not supported on OpenVINO backend"
45+
)
46+
keras.Model.fit = lambda *args, **kwargs: pytest.skip(
47+
"Model.fit() not supported on OpenVINO backend"
48+
)
49+
keras.Model.train_on_batch = lambda *args, **kwargs: pytest.skip(
50+
"Model.train_on_batch() not supported on OpenVINO backend"
51+
)
52+
4653
# Verify that device has GPU and detected by backend
4754
if config.getoption("--check_gpu"):
4855
found_gpu = False
@@ -84,12 +91,9 @@ def pytest_configure(config):
8491

8592

8693
def pytest_collection_modifyitems(config, items):
87-
openvino_supported_paths = None
88-
8994
run_extra_large_tests = config.getoption("--run_extra_large")
9095
# Run large tests for --run_extra_large or --run_large.
9196
run_large_tests = config.getoption("--run_large") or run_extra_large_tests
92-
auto_skip_training = config.getoption("--auto_skip_training")
9397

9498
# Messages to annotate skipped tests with.
9599
skip_large = pytest.mark.skipif(
@@ -124,21 +128,58 @@ def pytest_collection_modifyitems(config, items):
124128
if "kaggle_key_required" in item.keywords:
125129
item.add_marker(kaggle_key_required)
126130

127-
# OpenVINO-specific skipping logic - whitelist-based approach
131+
# OpenVINO-specific test skipping
128132
if keras.config.backend() == "openvino":
129-
# OpenVINO backend configuration
130-
if openvino_supported_paths is None:
131-
openvino_supported_paths = setup_openvino_test_config(
132-
str(Path(__file__).parent)
133+
test_name = item.name.split("[")[0]
134+
test_path = str(item.fspath)
135+
136+
# OpenVINO supported test paths
137+
openvino_supported_paths = [
138+
"keras-hub/integration_tests",
139+
"keras_hub/src/models/gemma",
140+
"keras_hub/src/models/gpt2",
141+
"keras_hub/src/models/mistral",
142+
"keras_hub/src/samplers/serialization_test.py",
143+
"keras_hub/src/tests/doc_tests/docstring_test.py",
144+
"keras_hub/src/tokenizers",
145+
"keras_hub/src/utils",
146+
]
147+
148+
# Skip specific problematic test methods
149+
specific_skipping_tests = {
150+
"test_backbone_basics": "Requires trainable backend",
151+
"test_score_loss": "Non-implemented roll operation",
152+
"test_layer_behaviors": "Requires trainable backend",
153+
}
154+
155+
if test_name in specific_skipping_tests:
156+
item.add_marker(
157+
pytest.mark.skipif(
158+
True,
159+
reason="OpenVINO: "
160+
f"{specific_skipping_tests[test_name]}",
161+
)
133162
)
134-
skip_reason = get_openvino_skip_reason(
135-
item,
136-
openvino_supported_paths,
137-
auto_skip_training,
163+
continue
164+
165+
parts = test_path.replace("\\", "/").split("/")
166+
try:
167+
keras_hub_idx = parts.index("keras_hub")
168+
relative_test_path = "/".join(parts[keras_hub_idx:])
169+
except ValueError:
170+
relative_test_path = test_path
171+
172+
is_whitelisted = any(
173+
relative_test_path == supported_path
174+
or relative_test_path.startswith(supported_path + "/")
175+
for supported_path in openvino_supported_paths
138176
)
139-
if skip_reason:
177+
178+
if not is_whitelisted:
140179
item.add_marker(
141-
pytest.mark.skipif(True, reason=f"OpenVINO: {skip_reason}")
180+
pytest.mark.skipif(
181+
True, reason="OpenVINO: File/directory not in whitelist"
182+
)
142183
)
143184

144185

keras_hub/src/samplers/beam_sampler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,15 @@ def unflatten_beams(x):
9595
)
9696
log_probs = flatten_beams(ops.repeat(log_probs, batch_size, axis=0))
9797

98-
def cond(prompt, cache, index, log_probs):
98+
def cond(prompt, cache, index, mask, log_probs):
9999
if stop_token_ids is None:
100100
return True
101101
# Stop if all sequences have produced a *new* stop token.
102102
end_tokens = any_equal(prompt, stop_token_ids, ~mask)
103103
prompt_done = ops.any(end_tokens, axis=-1)
104104
return ops.logical_not(ops.all(prompt_done))
105105

106-
def body(prompt, cache, index, log_probs):
106+
def body(prompt, cache, index, mask, log_probs):
107107
# Compute the softmax distribution for the next token.
108108
logits, _, cache = next(prompt, cache, index)
109109
vocab_size = ops.shape(logits)[-1]
@@ -150,12 +150,12 @@ def gather_beams(x):
150150
next_token = next_token[:, None]
151151
prompt = ops.slice_update(prompt, [0, index], next_token)
152152
# Return the iteration of the loop state.
153-
return (prompt, cache, index + 1, log_probs)
153+
return (prompt, cache, index + 1, mask, log_probs)
154154

155-
prompt, _, _, log_probs = self.run_loop(
155+
prompt, _, _, _, log_probs = self.run_loop(
156156
cond=cond,
157157
body=body,
158-
loop_vars=(prompt, cache, index, log_probs),
158+
loop_vars=(prompt, cache, index, mask, log_probs),
159159
maximum_iterations=(max_length - index),
160160
model=model,
161161
)

0 commit comments

Comments
 (0)