|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | + |
| 3 | + |
| 4 | +import unittest |
| 5 | +import sqlite3 |
| 6 | +import re |
| 7 | + |
| 8 | +import sqlitefts as fts |
| 9 | +from sqlitefts import ranking |
| 10 | + |
| 11 | + |
| 12 | +class Tokenizer(fts.Tokenizer): |
| 13 | + |
| 14 | + _spliter = re.compile(r'\s+|\S+', re.UNICODE) |
| 15 | + _nonws = re.compile(r'\S+', re.UNICODE) |
| 16 | + |
| 17 | + def _normalize(self, token): |
| 18 | + return token.lower() |
| 19 | + |
| 20 | + def _tokenize(self, text): |
| 21 | + pos = 0 |
| 22 | + for t in self._spliter.findall(text): |
| 23 | + byteLen = len(t.encode('utf-8')) |
| 24 | + if self._nonws.match(t): |
| 25 | + yield self._normalize(t), pos, pos + byteLen |
| 26 | + pos += byteLen |
| 27 | + |
| 28 | + def tokenize(self, text): |
| 29 | + return self._tokenize(text) |
| 30 | + |
| 31 | + |
| 32 | +class TestCase(unittest.TestCase): |
| 33 | + |
| 34 | + def setUp(self): |
| 35 | + name = 'test' |
| 36 | + conn = sqlite3.connect(':memory:') |
| 37 | + conn.row_factory = sqlite3.Row |
| 38 | + |
| 39 | + fts.register_tokenizer(conn, name, fts.make_tokenizer_module(Tokenizer())) |
| 40 | + |
| 41 | + conn.execute('CREATE VIRTUAL TABLE fts3 USING FTS3(tokenize={})'.format(name)) |
| 42 | + conn.execute('CREATE VIRTUAL TABLE fts4 USING FTS4(tokenize={})'.format(name)) |
| 43 | + |
| 44 | + values = [ |
| 45 | + (u'Make thing I',), |
| 46 | + (u'Some thing φχικλψ thing',), |
| 47 | + (u'Fusce volutpat hendrerit sem. Fusce sit amet vulputate dui. ' |
| 48 | + u'Sed posuere mi a nisl aliquet tempor. Praesent tincidunt vel nunc ac pharetra.',), |
| 49 | + (u'Nam molestie euismod leo id aliquam. In hac habitasse platea dictumst.',), |
| 50 | + (u'Vivamus tincidunt feugiat tellus ac bibendum. In rhoncus dignissim suscipit.',), |
| 51 | + (u'Pellentesque hendrerit nulla rutrum luctus rutrum. Fusce hendrerit fermentum nunc at posuere.',), |
| 52 | + ] |
| 53 | + for n in ('fts3', 'fts4'): |
| 54 | + result = conn.executemany('INSERT INTO {0} VALUES(?)'.format(n), values) |
| 55 | + assert result.rowcount == len(values) |
| 56 | + |
| 57 | + conn.create_function('bm25', 2, ranking.bm25) |
| 58 | + conn.create_function('rank', 1, ranking.simple) |
| 59 | + |
| 60 | + self.testee = conn |
| 61 | + |
| 62 | + def testSimple(self): |
| 63 | + sql = ''' |
| 64 | + SELECT content, rank(matchinfo(fts3)) AS rank |
| 65 | + FROM fts3 |
| 66 | + WHERE fts3 MATCH :query |
| 67 | + ORDER BY rank DESC |
| 68 | + ''' |
| 69 | + actual = [dict(x) for x in self.testee.execute(sql, {'query': u'thing'})] |
| 70 | + |
| 71 | + self.assertEqual(2, len(actual)) |
| 72 | + self.assertEqual({ |
| 73 | + 'content': u'Some thing φχικλψ thing', |
| 74 | + 'rank': 0.6666666666666666 |
| 75 | + }, actual[0]) |
| 76 | + self.assertEqual({ |
| 77 | + 'content': u'Make thing I', |
| 78 | + 'rank': 0.3333333333333333 |
| 79 | + }, actual[1]) |
| 80 | + |
| 81 | + def testBm25(self): |
| 82 | + sql = ''' |
| 83 | + SELECT content, bm25(matchinfo(fts4, 'pcxnal'), 0) AS rank |
| 84 | + FROM fts4 |
| 85 | + WHERE fts4 MATCH :query |
| 86 | + ORDER BY rank DESC |
| 87 | + ''' |
| 88 | + actual = [dict(x) for x in self.testee.execute(sql, {'query': u'thing'})] |
| 89 | + |
| 90 | + self.assertEqual(2, len(actual)) |
| 91 | + self.assertEqual({ |
| 92 | + 'content': u'Some thing φχικλψ thing', |
| 93 | + 'rank': 0.9722786938230542 |
| 94 | + }, actual[0]) |
| 95 | + self.assertEqual({ |
| 96 | + 'content': u'Make thing I', |
| 97 | + 'rank': 0.8236501036844982 |
| 98 | + }, actual[1]) |
0 commit comments