Skip to content

Commit f266d96

Browse files
Add Cockatiel explainability method
Signed-off-by: Frederic Boisnard <[email protected]>
1 parent cc63413 commit f266d96

File tree

9 files changed

+1012
-0
lines changed

9 files changed

+1012
-0
lines changed
+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
Test object detection BoundingBoxesExplainer
3+
"""
4+
import numpy as np
5+
6+
from xplique.attributions import NlpOcclusion
7+
8+
def test_masks():
9+
"""Test the masks creation"""
10+
sentence = "aaa bbb ccc"
11+
words = sentence.split(" ")
12+
masks = NlpOcclusion._get_masks(words)
13+
assert masks.shape == (len(words), len(words))
14+
expected_mask = np.array([[False, True, True],
15+
[True, False, True],
16+
[True, True, False]])
17+
18+
assert np.array_equal(masks, expected_mask)
19+
20+
def test_apply_masks():
21+
"""Test if the application of a mask generate valid results"""
22+
sentence = "aaa bbb ccc"
23+
words = sentence.split(" ")
24+
masks = NlpOcclusion._get_masks(words)
25+
26+
occluded_inputs = NlpOcclusion._apply_masks(words, masks)
27+
expected_occludec_inputs = [['bbb', 'ccc'], ['aaa', 'ccc'], ['aaa', 'bbb']]
28+
assert np.array_equal(occluded_inputs, expected_occludec_inputs)
29+
30+
def test_output_shape():
31+
"""Test the output shape for several input sentences"""
32+
33+
nb_concepts = 10
34+
35+
def transform(inputs):
36+
# simulate the transorm method used in Craft/Cockatiel
37+
return np.ones((len(inputs), nb_concepts))
38+
39+
input_sentence = ["aaa bbb ccc ddd eee fff", "ggg hhh iii jjj"]
40+
for sentence in input_sentence:
41+
words = sentence.split(" ")
42+
separator = " "
43+
44+
method = NlpOcclusion(model=transform)
45+
sensitivity = method.explain(sentence, words, separator)
46+
47+
assert sensitivity.shape == (nb_concepts, len(words))

tests/nlp/__init__.py

Whitespace-only changes.

tests/nlp/test_token_extractor.py

+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
from xplique.commons.nlp import WordExtractor, SentenceExtractor
2+
from xplique.commons.nlp import ClauseExtractor, ExcerptExtractor, ExtractorFactory
3+
4+
import pytest
5+
6+
@pytest.fixture
7+
def example_sentence():
8+
return "One two three. Second sentence.Third Sentence, test1, test2; test3-test4 .GO!"\
9+
" Trust me,. sentence not starting with capital letter."\
10+
" Sentence with dots..Word, Word, word,...word ....so a sentence"
11+
12+
def test_word_extractor(example_sentence):
13+
extractor = WordExtractor()
14+
tokens, separator = extractor.extract_tokens(example_sentence)
15+
assert isinstance(tokens, list)
16+
assert isinstance(separator, str)
17+
assert separator == ' '
18+
expected_tokens = [ 'One', 'two', 'three', '.',
19+
'Second', 'sentence.Third', 'Sentence', ',', 'test1', ',', 'test2', ';',
20+
'test3-test4', '.GO', '!', 'Trust', 'me', ',', '.', 'sentence', 'not',
21+
'starting','with', 'capital', 'letter', '.', 'Sentence', 'with', 'dots',
22+
'..', 'Word', ',', 'Word', ',', 'word', ',', '...', 'word', '....', 'so',
23+
'a', 'sentence']
24+
assert tokens == expected_tokens, print('tokens:', tokens)
25+
26+
def test_word_extractor_ignore_words(example_sentence):
27+
extractor = WordExtractor(ignore_words = ['me', 'not', 'so', 'a',])
28+
tokens, separator = extractor.extract_tokens(example_sentence)
29+
assert isinstance(tokens, list)
30+
assert isinstance(separator, str)
31+
assert separator == ' '
32+
expected_tokens = [ 'One', 'two', 'three', '.',
33+
'Second', 'sentence.Third', 'Sentence', ',', 'test1', ',', 'test2', ';',
34+
'test3-test4', '.GO', '!', 'Trust', ',', '.', 'sentence', 'starting',
35+
'with','capital', 'letter', '.', 'Sentence', 'with', 'dots', '..',
36+
'Word', ',', 'Word', ',', 'word', ',', '...', 'word', '....', 'sentence']
37+
assert tokens == expected_tokens, print('tokens:', tokens)
38+
39+
def test_word_extractor_from_list(example_sentence):
40+
extractor = WordExtractor()
41+
tokens, separator = extractor.extract_tokens([example_sentence, example_sentence])
42+
assert isinstance(tokens, list)
43+
assert isinstance(separator, str)
44+
assert separator == ' '
45+
46+
def test_sentence_extractor(example_sentence):
47+
extractor = SentenceExtractor()
48+
tokens, separator = extractor.extract_tokens(example_sentence)
49+
assert isinstance(tokens, list)
50+
assert isinstance(separator, str)
51+
assert separator == '. '
52+
expected_tokens = [ 'One two three.',
53+
'Second sentence.Third Sentence, test1, test2; test3-test4 .GO!',
54+
'Trust me,.',
55+
'sentence not starting with capital letter.',
56+
'Sentence with dots..Word, Word, word,...word ....so a sentence']
57+
assert tokens == expected_tokens, print('tokens:', tokens)
58+
59+
def test_excerpt_extractor(example_sentence):
60+
extractor = ExcerptExtractor()
61+
tokens, separator = extractor.extract_tokens(example_sentence)
62+
assert isinstance(tokens, list)
63+
assert isinstance(separator, str)
64+
assert separator == ' '
65+
expected_tokens = [ 'One two three.',
66+
'Second sentence.',
67+
'Third Sentence, test1, test2; test3-test4 .',
68+
'GO!',
69+
'Trust me,.',
70+
'Sentence with dots.',
71+
'Word, Word, word,.']
72+
assert tokens == expected_tokens, print('tokens:', tokens)
73+
74+
def test_clause_extractor_close_type_none(example_sentence):
75+
clause_extractor = ClauseExtractor(clause_type = None)
76+
tokens, separator = clause_extractor.extract_tokens(example_sentence)
77+
assert isinstance(tokens, list)
78+
assert isinstance(separator, str)
79+
expected_tokens = [ 'One two three',
80+
'Second sentence.Third Sentence',
81+
'test1',
82+
'test2',
83+
'test3-test4',
84+
'GO',
85+
'Trust',
86+
'me',
87+
'sentence',
88+
'not starting',
89+
'with',
90+
'capital letter',
91+
'Sentence',
92+
'with',
93+
'dots.',
94+
'Word',
95+
'Word, word,...word',
96+
'a sentence']
97+
assert tokens == expected_tokens
98+
99+
100+
def test_clause_extractor_close_type_NP(example_sentence):
101+
clause_extractor = ClauseExtractor(clause_type = ['NP'])
102+
tokens, separator = clause_extractor.extract_tokens(example_sentence)
103+
assert isinstance(tokens, list)
104+
assert isinstance(separator, str)
105+
expected_tokens = [ 'One two three',
106+
'Second sentence.Third Sentence',
107+
'test1',
108+
'test2',
109+
'test3-test4',
110+
'me',
111+
'sentence',
112+
'capital letter',
113+
'Sentence',
114+
'dots.',
115+
'Word',
116+
'Word, word,...word',
117+
'a sentence']
118+
assert tokens == expected_tokens
119+
120+
def test_clause_extractor_close_type_ADJP(example_sentence):
121+
clause_extractor = ClauseExtractor(clause_type = ['ADJP'])
122+
tokens, separator = clause_extractor.extract_tokens(example_sentence)
123+
assert isinstance(tokens, list)
124+
assert isinstance(separator, str)
125+
print(tokens)
126+
expected_tokens = []
127+
assert tokens == expected_tokens
128+
129+
def test_extractor_factory():
130+
word_extractor = ExtractorFactory.get_extractor(extract_fct="word")
131+
assert isinstance(word_extractor, WordExtractor)
132+
133+
sentence_extractor = ExtractorFactory.get_extractor(extract_fct="sentence")
134+
assert isinstance(sentence_extractor, SentenceExtractor)
135+
136+
excerpt_extractor = ExtractorFactory.get_extractor(extract_fct="excerpt")
137+
assert isinstance(excerpt_extractor, ExcerptExtractor)
138+
139+
clause_extractor = ExtractorFactory.get_extractor(extract_fct="clause", clause_type=['NP'])
140+
assert isinstance(clause_extractor, ClauseExtractor)
141+
142+
with pytest.raises(ValueError):
143+
ExtractorFactory.get_extractor(extract_fct="invalid")

xplique/attributions/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@
1616
from .object_detector import BoundingBoxesExplainer
1717
from .global_sensitivity_analysis import SobolAttributionMethod, HsicAttributionMethod
1818
from .gradient_statistics import SmoothGrad, VarGrad, SquareGrad
19+
from .nlp_occlusion import NlpOcclusion
1920
from . import global_sensitivity_analysis

xplique/attributions/nlp_occlusion.py

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""
2+
Module related to Occlusion sensitivity method for NLP.
3+
"""
4+
5+
import numpy as np
6+
7+
from .base import BlackBoxExplainer
8+
from ..commons import Tasks
9+
from ..types import Callable, Union, Optional, OperatorSignature, List
10+
11+
class NlpOcclusion(BlackBoxExplainer):
12+
"""
13+
Occlusion class for NLP.
14+
"""
15+
def __init__(self,
16+
model: Callable,
17+
batch_size: Optional[int] = 32,
18+
operator: Optional[Union[Tasks, str, OperatorSignature]] = None):
19+
super().__init__(model, batch_size, operator)
20+
21+
@staticmethod
22+
def _get_masks(input_len: int) -> np.ndarray:
23+
"""
24+
Generate occlusion masks for a given input length.
25+
26+
Parameters
27+
----------
28+
input_len : int
29+
The length of the input for which occlusion masks are generated.
30+
Typically it will be the number of words of a sentence.
31+
32+
Returns
33+
-------
34+
occlusion_masks : np.ndarray
35+
The boolean occlusion masks, an identity matrix with False for the main diagonal.
36+
This kind of mask can be used to generate n sentences,
37+
each with a single distinct word removed.
38+
"""
39+
return np.eye(input_len) == 0
40+
41+
@staticmethod
42+
def _apply_masks(words: List[str], masks: np.ndarray) -> np.ndarray:
43+
"""
44+
Apply occlusion masks to a list of words.
45+
46+
Parameters
47+
----------
48+
words : List[str]
49+
The list of words to which occlusion masks are applied.
50+
masks : np.ndarray
51+
The boolean occlusion masks to be applied.
52+
53+
Returns
54+
-------
55+
occluded_words : np.ndarray
56+
The list of words with occlusion masks applied.
57+
"""
58+
perturbated_words = [np.array(words)[mask].tolist() for mask in masks]
59+
return perturbated_words
60+
61+
def explain(self,
62+
sentence: str,
63+
words: List[str],
64+
separator: str) -> np.ndarray:
65+
"""
66+
Generate an explanation for the input sentence, by providing the importance of each word.
67+
The importance will be computed by successively occluding each word of the sentence and
68+
studying the impact of this occlusion on the model results.
69+
70+
Parameters
71+
----------
72+
sentence : str
73+
The input sentence for which an explanation is generated.
74+
words : List[str]
75+
List of words used to generate the explanation. These words must be part of
76+
the input sentence, the importance will be computed on this list of words
77+
(i.e some words of the original sentence can be omited this way).
78+
separator : str
79+
The separator used to join the words after the occlusion step, so a full
80+
sentence can be fed to the model.
81+
82+
Returns
83+
-------
84+
explanation : np.ndarray
85+
The generated explanation of format (nb_concepts, nb_words).
86+
"""
87+
88+
# generate n sentences with a different word masked (removed) each time
89+
masks = NlpOcclusion._get_masks(len(words))
90+
perturbated_words = NlpOcclusion._apply_masks(words, masks)
91+
92+
perturbated_sentences = [sentence]
93+
perturbated_sentences.extend(
94+
[separator.join(perturbated_word) for perturbated_word in perturbated_words])
95+
96+
# transform the perturbated reviews into their concept representation
97+
# u_values has shape: ((W+1) x C)
98+
u_values = self.model(perturbated_sentences)
99+
100+
# Compute sensitivities: importances = u_value of the whole sentence - u_value of each word
101+
whole_sentence_uvalues = u_values[0,:]
102+
words_uvalues = u_values[1:,:]
103+
l_importances = (whole_sentence_uvalues - words_uvalues).transpose()
104+
l_importances /= (np.max(np.abs(l_importances)) + 1e-5)
105+
106+
return l_importances

xplique/commons/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@
1111
get_inference_function, get_gradient_functions)
1212
from .exceptions import no_gradients_available, raise_invalid_operator
1313
from .forgrad import forgrad
14+
from .nlp import TokenExtractor, WordExtractor, SentenceExtractor, ClauseExtractor, ExcerptExtractor, ExtractorFactory

0 commit comments

Comments
 (0)