Skip to content

Commit c4d78d9

Browse files
committed
added ranking related things from https://gist.github.com/saaj/fdc8e6351d07fbb1a511
a part of issue #2
1 parent f33f1e5 commit c4d78d9

File tree

2 files changed

+199
-0
lines changed

2 files changed

+199
-0
lines changed

sqlitefts/ranking.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# -*- coding: utf-8 -*-
2+
'''
3+
Ranking code based on:
4+
https://github.com/coleifer/peewee/blob/master/playhouse/sqlite_ext.py
5+
'''
6+
7+
8+
import struct
9+
import math
10+
11+
12+
def parseMatchInfo(buf):
    '''
    Decode a raw matchinfo blob into a list of unsigned integers.

    The blob is consumed in consecutive 4-byte chunks, each unpacked as a
    native unsigned int; see http://sqlite.org/fts3.html#matchinfo
    '''
    unpack_word = struct.Struct('@I').unpack
    return [unpack_word(buf[start:start + 4])[0]
            for start in range(0, len(buf), 4)]
16+
17+
18+
def simple(raw_match_info):
    '''
    Rank a row from matchinfo called with its default format 'pcx'.

    Modelled on the example rank function at
    http://sqlite.org/fts3.html#appendix_a : for every phrase/column pair
    that has hits in this row, add (hits in this row / hits in all rows).
    '''
    info = parseMatchInfo(raw_match_info)
    phrase_count, column_count = info[:2]
    total = 0.0
    # The x-section starts at index 2 and holds one triple per
    # (phrase, column) pair, so a flat walk in steps of 3 visits the pairs
    # in phrase-major order.
    for offset in range(2, 2 + 3 * phrase_count * column_count, 3):
        hits_this_row, hits_all_rows = info[offset:offset + 2]
        if hits_this_row > 0:
            total += float(hits_this_row) / hits_all_rows
    return total
34+
35+
36+
def bm25(raw_match_info, column_index, k1=1.2, b=0.75):
    """
    FTS4-only ranking function: BM25 score of one column of a matched row.

    Usage:

        # Format string *must* be pcxnal
        # Second parameter to bm25 specifies the index of the column, on
        # the table being queried.

        bm25(matchinfo(document_tbl, 'pcxnal'), 1) AS rank

    :param raw_match_info: blob produced by matchinfo(tbl, 'pcxnal')
    :param column_index: zero-based index of the column to score
    :param k1: weight of the length-normalisation term in the denominator
    :param b: how strongly doc_length / avg_length affects the score
    :returns: the per-phrase scores summed up, as a float
    """
    match_info = parseMatchInfo(raw_match_info)
    score = 0.0
    # p, 1 --> num terms
    # c, 1 --> num cols
    # x, (3 * p * c) --> for each phrase/column,
    #     term_freq for this column
    #     term_freq for all columns
    #     total documents containing this term
    # n, 1 --> total rows in table
    # a, c --> for each column, avg number of tokens in this column
    # l, c --> for each column, length of value for this column (in this row)
    # s, c --> ignore
    p, c = match_info[:2]
    n_idx = 2 + (3 * p * c)
    a_idx = n_idx + 1
    l_idx = a_idx + c
    total_docs = match_info[n_idx]
    avg_lengths = match_info[a_idx: a_idx + c]
    doc_lengths = match_info[l_idx: l_idx + c]

    avg_length = float(avg_lengths[column_index])
    doc_length = float(doc_lengths[column_index])
    if avg_length == 0:
        D = 0
    else:
        D = 1 - b + (b * (doc_length / avg_length))
    # Same for every phrase, so computed once outside the loop.
    length_norm = k1 * D

    for phrase in range(p):
        # The x-section is laid out phrase-major, three values per pair:
        # p, c, p0c01, p0c02, p0c03, p0c11, p0c12, p0c13, p1c01, p1c02, p1c03..
        # Each phrase therefore occupies 3 * c slots, and the triple for
        # column <i> of this phrase starts at:
        x_idx = 2 + (3 * (column_index + phrase * c))
        term_freq = float(match_info[x_idx])
        term_matches = float(match_info[x_idx + 2])

        # The `max` check here is based on a suggestion in the Wikipedia
        # article. For terms that are common to a majority of documents, the
        # idf function can return negative values. Applying the max() here
        # weeds out those values.
        idf = max(
            math.log(
                (total_docs - term_matches + 0.5) /
                (term_matches + 0.5)),
            0)

        denom = term_freq + length_norm
        if denom == 0:
            rhs = 0
        else:
            rhs = (term_freq * (k1 + 1)) / denom

        score += (idf * rhs)

    return score

tests/test_ranking.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# -*- coding: utf-8 -*-
2+
3+
4+
import unittest
5+
import sqlite3
6+
import re
7+
8+
import sqlitefts as fts
9+
from sqlitefts import ranking
10+
11+
12+
class Tokenizer(fts.Tokenizer):
    '''
    Whitespace-splitting tokenizer used by the ranking tests.

    Tokens are lower-cased and reported with their start/end offsets
    measured in UTF-8 bytes.
    '''

    # Splits text into alternating runs of whitespace and non-whitespace.
    _spliter = re.compile(r'\s+|\S+', re.UNICODE)
    # Recognises the non-whitespace runs, i.e. the actual tokens.
    _nonws = re.compile(r'\S+', re.UNICODE)

    def _normalize(self, token):
        '''Return the lower-cased form of token.'''
        return token.lower()

    def _tokenize(self, text):
        '''Yield (normalized token, start, end) with offsets in UTF-8 bytes.'''
        start = 0
        for chunk in self._spliter.findall(text):
            end = start + len(chunk.encode('utf-8'))
            if self._nonws.match(chunk):
                yield self._normalize(chunk), start, end
            # Whitespace runs are not emitted but still advance the offset.
            start = end

    def tokenize(self, text):
        '''Return the token generator for text.'''
        return self._tokenize(text)
30+
31+
32+
class TestCase(unittest.TestCase):
    '''Runs the ranking functions against small in-memory FTS3/FTS4 tables.'''

    def setUp(self):
        '''Create the tables, load the sample rows and register the SQL functions.'''
        name = 'test'
        conn = sqlite3.connect(':memory:')
        # Registered right away so the connection is closed even when the
        # rest of setUp, or the test itself, fails.
        self.addCleanup(conn.close)
        conn.row_factory = sqlite3.Row

        fts.register_tokenizer(conn, name, fts.make_tokenizer_module(Tokenizer()))

        conn.execute('CREATE VIRTUAL TABLE fts3 USING FTS3(tokenize={})'.format(name))
        conn.execute('CREATE VIRTUAL TABLE fts4 USING FTS4(tokenize={})'.format(name))

        values = [
            (u'Make thing I',),
            (u'Some thing φχικλψ thing',),
            (u'Fusce volutpat hendrerit sem. Fusce sit amet vulputate dui. '
             u'Sed posuere mi a nisl aliquet tempor. Praesent tincidunt vel nunc ac pharetra.',),
            (u'Nam molestie euismod leo id aliquam. In hac habitasse platea dictumst.',),
            (u'Vivamus tincidunt feugiat tellus ac bibendum. In rhoncus dignissim suscipit.',),
            (u'Pellentesque hendrerit nulla rutrum luctus rutrum. Fusce hendrerit fermentum nunc at posuere.',),
        ]
        for n in ('fts3', 'fts4'):
            result = conn.executemany('INSERT INTO {0} VALUES(?)'.format(n), values)
            # A unittest assertion rather than a bare assert, which would be
            # stripped when Python runs with -O.
            self.assertEqual(len(values), result.rowcount)

        conn.create_function('bm25', 2, ranking.bm25)
        conn.create_function('rank', 1, ranking.simple)

        self.testee = conn

    def _assertRanked(self, expected, actual):
        '''
        Check that the rows in actual match the (content, rank) pairs in
        expected, in order. Ranks are compared approximately because they
        are floating-point results.
        '''
        self.assertEqual(len(expected), len(actual))
        for (content, rank), row in zip(expected, actual):
            self.assertEqual(content, row['content'])
            self.assertAlmostEqual(rank, row['rank'])

    def testSimple(self):
        '''rank() over default matchinfo orders the FTS3 rows by hit ratio.'''
        sql = '''
            SELECT content, rank(matchinfo(fts3)) AS rank
            FROM fts3
            WHERE fts3 MATCH :query
            ORDER BY rank DESC
        '''
        actual = [dict(x) for x in self.testee.execute(sql, {'query': u'thing'})]

        self._assertRanked([
            (u'Some thing φχικλψ thing', 0.6666666666666666),
            (u'Make thing I', 0.3333333333333333),
        ], actual)

    def testBm25(self):
        '''bm25() over 'pcxnal' matchinfo orders the FTS4 rows by score.'''
        sql = '''
            SELECT content, bm25(matchinfo(fts4, 'pcxnal'), 0) AS rank
            FROM fts4
            WHERE fts4 MATCH :query
            ORDER BY rank DESC
        '''
        actual = [dict(x) for x in self.testee.execute(sql, {'query': u'thing'})]

        self._assertRanked([
            (u'Some thing φχικλψ thing', 0.9722786938230542),
            (u'Make thing I', 0.8236501036844982),
        ], actual)

0 commit comments

Comments
 (0)