|
| 1 | +import unittest |
| 2 | +import csv |
| 3 | +import numpy as np |
| 4 | +from io import StringIO |
| 5 | +from unittest.mock import Mock, patch |
| 6 | +from djl_python.inputs import Input |
| 7 | +from djl_python.outputs import Output |
| 8 | +from djl_python.encode_decode import decode, encode, decode_csv, encode_csv |
| 9 | + |
| 10 | + |
| 11 | +class TestEncodeDecode(unittest.TestCase): |
| 12 | + |
| 13 | + def setUp(self): |
| 14 | + self.mock_input = Mock(spec=Input) |
| 15 | + self.mock_output = Mock(spec=Output) |
| 16 | + self.mock_output.add_as_string = Mock() |
| 17 | + self.mock_output.add_as_json = Mock() |
| 18 | + self.mock_output.add_as_numpy = Mock() |
| 19 | + self.mock_output.add_as_npz = Mock() |
| 20 | + self.mock_output.add_property = Mock() |
| 21 | + |
| 22 | + def test_decode_csv_valid_inputs_header(self): |
| 23 | + csv_content = "inputs,other\ntest input 1,value1\ntest input 2,value2" |
| 24 | + self.mock_input.get_as_string.return_value = csv_content |
| 25 | + |
| 26 | + result = decode_csv(self.mock_input) |
| 27 | + |
| 28 | + expected = {"inputs": ["test input 1", "test input 2"]} |
| 29 | + self.assertEqual(result, expected) |
| 30 | + |
| 31 | + def test_decode_csv_valid_question_header(self): |
| 32 | + csv_content = "question,context\nWhat is AI?,Technology context\nWhat is ML?,Machine learning context" |
| 33 | + self.mock_input.get_as_string.return_value = csv_content |
| 34 | + |
| 35 | + result = decode_csv(self.mock_input) |
| 36 | + |
| 37 | + expected = { |
| 38 | + "inputs": [{ |
| 39 | + "question": "What is AI?", |
| 40 | + "context": "Technology context" |
| 41 | + }, { |
| 42 | + "question": "What is ML?", |
| 43 | + "context": "Machine learning context" |
| 44 | + }] |
| 45 | + } |
| 46 | + self.assertEqual(result, expected) |
| 47 | + |
| 48 | + def test_decode_csv_invalid_header(self): |
| 49 | + csv_content = "invalid,header\nvalue1,value2" |
| 50 | + self.mock_input.get_as_string.return_value = csv_content |
| 51 | + |
| 52 | + with self.assertRaises(ValueError) as context: |
| 53 | + decode_csv(self.mock_input) |
| 54 | + |
| 55 | + self.assertIn("correct CSV with Header columns", |
| 56 | + str(context.exception)) |
| 57 | + |
| 58 | + def test_encode_csv_list_of_dicts(self): |
| 59 | + content = [{ |
| 60 | + "name": "Axl", |
| 61 | + "age": "30" |
| 62 | + }, { |
| 63 | + "name": "Fiona", |
| 64 | + "age": "25" |
| 65 | + }] |
| 66 | + |
| 67 | + result = encode_csv(content) |
| 68 | + |
| 69 | + reader = csv.DictReader(StringIO(result)) |
| 70 | + rows = list(reader) |
| 71 | + self.assertEqual(len(rows), 2) |
| 72 | + self.assertEqual(rows[0]["name"], "Axl") |
| 73 | + self.assertEqual(rows[1]["name"], "Fiona") |
| 74 | + |
| 75 | + def test_encode_csv_single_dict(self): |
| 76 | + content = [{"name": "Axl", "age": "30"}] |
| 77 | + |
| 78 | + result = encode_csv(content) |
| 79 | + |
| 80 | + reader = csv.DictReader(StringIO(result)) |
| 81 | + rows = list(reader) |
| 82 | + self.assertEqual(len(rows), 1) |
| 83 | + self.assertEqual(rows[0]["name"], "Axl") |
| 84 | + |
| 85 | + def test_decode_no_content_type(self): |
| 86 | + self.mock_input.get_as_bytes.return_value = None |
| 87 | + |
| 88 | + result = decode(self.mock_input, None) |
| 89 | + |
| 90 | + expected = {"inputs": ""} |
| 91 | + self.assertEqual(result, expected) |
| 92 | + |
| 93 | + def test_decode_no_content_type_with_json(self): |
| 94 | + test_data = {"test": "data"} |
| 95 | + self.mock_input.get_as_bytes.return_value = b'{"test": "data"}' |
| 96 | + self.mock_input.get_as_json.return_value = test_data |
| 97 | + |
| 98 | + result = decode(self.mock_input, None) |
| 99 | + |
| 100 | + self.assertEqual(result, test_data) |
| 101 | + |
| 102 | + def test_decode_application_json(self): |
| 103 | + test_data = {"message": "hello"} |
| 104 | + self.mock_input.get_as_json.return_value = test_data |
| 105 | + |
| 106 | + result = decode(self.mock_input, "application/json") |
| 107 | + |
| 108 | + self.assertEqual(result, test_data) |
| 109 | + |
| 110 | + def test_decode_text_csv(self): |
| 111 | + csv_content = "inputs\ntest input" |
| 112 | + self.mock_input.get_as_string.return_value = csv_content |
| 113 | + |
| 114 | + with patch('djl_python.encode_decode.decode_csv') as mock_decode_csv: |
| 115 | + mock_decode_csv.return_value = {"inputs": ["test input"]} |
| 116 | + result = decode(self.mock_input, "text/csv") |
| 117 | + |
| 118 | + mock_decode_csv.assert_called_once_with(self.mock_input) |
| 119 | + self.assertEqual(result, {"inputs": ["test input"]}) |
| 120 | + |
| 121 | + def test_decode_text_plain(self): |
| 122 | + text_content = "Hello world" |
| 123 | + self.mock_input.get_as_string.return_value = text_content |
| 124 | + |
| 125 | + result = decode(self.mock_input, "text/plain") |
| 126 | + |
| 127 | + expected = {"inputs": ["Hello world"]} |
| 128 | + self.assertEqual(result, expected) |
| 129 | + |
| 130 | + def test_decode_image_content_type(self): |
| 131 | + image_data = b"fake_image_data" |
| 132 | + self.mock_input.get_as_image.return_value = image_data |
| 133 | + |
| 134 | + result = decode(self.mock_input, "image/jpeg") |
| 135 | + |
| 136 | + expected = {"inputs": image_data} |
| 137 | + self.assertEqual(result, expected) |
| 138 | + |
| 139 | + def test_decode_audio_content_type(self): |
| 140 | + audio_data = b"fake_audio_data" |
| 141 | + self.mock_input.get_as_bytes.return_value = audio_data |
| 142 | + |
| 143 | + result = decode(self.mock_input, "audio/wav") |
| 144 | + |
| 145 | + expected = {"inputs": audio_data} |
| 146 | + self.assertEqual(result, expected) |
| 147 | + |
| 148 | + def test_decode_tensor_npz(self): |
| 149 | + tensor_data = [np.array([1, 2, 3])] |
| 150 | + self.mock_input.get_as_npz.return_value = tensor_data |
| 151 | + |
| 152 | + result = decode(self.mock_input, "tensor/npz") |
| 153 | + |
| 154 | + expected = {"inputs": tensor_data} |
| 155 | + self.assertEqual(result, expected) |
| 156 | + |
| 157 | + def test_decode_tensor_ndlist(self): |
| 158 | + tensor_data = [np.array([1, 2, 3])] |
| 159 | + self.mock_input.get_as_numpy.return_value = tensor_data |
| 160 | + |
| 161 | + result = decode(self.mock_input, "tensor/ndlist") |
| 162 | + |
| 163 | + expected = {"inputs": tensor_data} |
| 164 | + self.assertEqual(result, expected) |
| 165 | + |
| 166 | + def test_decode_application_x_npy(self): |
| 167 | + tensor_data = [np.array([1, 2, 3])] |
| 168 | + self.mock_input.get_as_numpy.return_value = tensor_data |
| 169 | + |
| 170 | + result = decode(self.mock_input, "application/x-npy") |
| 171 | + |
| 172 | + expected = {"inputs": tensor_data} |
| 173 | + self.assertEqual(result, expected) |
| 174 | + |
| 175 | + def test_decode_form_urlencoded(self): |
| 176 | + form_data = "key1=value1&key2=value2" |
| 177 | + self.mock_input.get_as_string.return_value = form_data |
| 178 | + |
| 179 | + result = decode(self.mock_input, "application/x-www-form-urlencoded") |
| 180 | + |
| 181 | + expected = {"inputs": form_data} |
| 182 | + self.assertEqual(result, expected) |
| 183 | + |
| 184 | + def test_decode_octet_stream(self): |
| 185 | + binary_data = b"binary_data" |
| 186 | + self.mock_input.get_as_bytes.return_value = binary_data |
| 187 | + |
| 188 | + result = decode(self.mock_input, "application/octet-stream") |
| 189 | + |
| 190 | + expected = {"inputs": binary_data} |
| 191 | + self.assertEqual(result, expected) |
| 192 | + |
| 193 | + def test_decode_with_key(self): |
| 194 | + test_data = {"test": "data"} |
| 195 | + self.mock_input.get_as_json.return_value = test_data |
| 196 | + |
| 197 | + result = decode(self.mock_input, "application/json", key="test_key") |
| 198 | + |
| 199 | + self.mock_input.get_as_json.assert_called_with(key="test_key") |
| 200 | + self.assertEqual(result, test_data) |
| 201 | + |
| 202 | + def test_encode_default_json(self): |
| 203 | + prediction = {"result": "success"} |
| 204 | + |
| 205 | + encode(self.mock_output, prediction, None) |
| 206 | + |
| 207 | + self.mock_output.add_as_json.assert_called_once_with(prediction, |
| 208 | + key=None) |
| 209 | + self.mock_output.add_property.assert_called_once_with( |
| 210 | + "Content-Type", "application/json") |
| 211 | + |
| 212 | + def test_encode_application_json(self): |
| 213 | + prediction = {"result": "success"} |
| 214 | + |
| 215 | + encode(self.mock_output, prediction, "application/json") |
| 216 | + |
| 217 | + self.mock_output.add_as_json.assert_called_once_with(prediction, |
| 218 | + key=None) |
| 219 | + self.mock_output.add_property.assert_called_once_with( |
| 220 | + "Content-Type", "application/json") |
| 221 | + |
| 222 | + def test_encode_text_csv(self): |
| 223 | + prediction = [{"name": "Axl", "age": "30"}] |
| 224 | + |
| 225 | + with patch('djl_python.encode_decode.encode_csv') as mock_encode_csv: |
| 226 | + mock_encode_csv.return_value = "name,age\nAxl,30\n" |
| 227 | + encode(self.mock_output, prediction, "text/csv") |
| 228 | + |
| 229 | + mock_encode_csv.assert_called_once_with(prediction) |
| 230 | + self.mock_output.add_as_string.assert_called_once_with( |
| 231 | + "name,age\nAxl,30\n", key=None) |
| 232 | + self.mock_output.add_property.assert_called_once_with( |
| 233 | + "Content-Type", "text/csv") |
| 234 | + |
| 235 | + def test_encode_tensor_npz(self): |
| 236 | + prediction = [np.array([1, 2, 3])] |
| 237 | + |
| 238 | + encode(self.mock_output, prediction, "tensor/npz") |
| 239 | + |
| 240 | + self.mock_output.add_as_npz.assert_called_once_with(prediction, |
| 241 | + key=None) |
| 242 | + self.mock_output.add_property.assert_called_once_with( |
| 243 | + "Content-Type", "tensor/npz") |
| 244 | + |
| 245 | + def test_encode_other_content_type(self): |
| 246 | + prediction = [np.array([1, 2, 3])] |
| 247 | + |
| 248 | + encode(self.mock_output, prediction, "custom/type") |
| 249 | + |
| 250 | + self.mock_output.add_as_numpy.assert_called_once_with(prediction, |
| 251 | + key=None) |
| 252 | + self.mock_output.add_property.assert_called_once_with( |
| 253 | + "Content-Type", "tensor/ndlist") |
| 254 | + |
| 255 | + def test_encode_with_key(self): |
| 256 | + prediction = {"result": "success"} |
| 257 | + |
| 258 | + encode(self.mock_output, |
| 259 | + prediction, |
| 260 | + "application/json", |
| 261 | + key="test_key") |
| 262 | + |
| 263 | + self.mock_output.add_as_json.assert_called_once_with(prediction, |
| 264 | + key="test_key") |
| 265 | + |
| 266 | + |
| 267 | +if __name__ == '__main__': |
| 268 | + unittest.main() |
0 commit comments