Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
FROM python:3.10 as requirements-stage

WORKDIR /tmp
RUN pip install poetry
RUN pip install poetry==1.8.5
COPY ./pyproject.toml ./poetry.lock* /tmp/
RUN poetry export -f requirements.txt --output requirements.txt --without-hashes

Expand Down
2 changes: 1 addition & 1 deletion app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"model": {
"path": os.environ.get(
"SEQ2SEQ_MODEL_PATH",
"msalnikov/kgqa_sqwd-tunned_t5-large-ssm-nq",
"s-nlp/t5_large_ssm_nq_mintaka",
),
"route_postfix": "sqwd_tunned/t5_large_ssm_nq",
},
Expand Down
27 changes: 27 additions & 0 deletions app/kgqa/entity_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from nltk.stem.porter import PorterStemmer

from .ner import NerToSentenceInsertion
from .utils.utils import get_wd_search_results
from .mgenre import build_mgenre_pipeline, MGENREPipeline


class EntitiesSelection:
Expand Down Expand Up @@ -45,3 +47,28 @@ def _check_label_fn(self, label, entities_list):
if label == entity:
return True
return False


class EntityLinker:
    """End-to-end entity linking for a question:
    NER span tagging -> mGENRE candidate generation -> candidate
    selection -> Wikidata id lookup.
    """

    def __init__(self, ner: NerToSentenceInsertion, mgenre: MGENREPipeline, entity_selection: EntitiesSelection):
        self.ner = ner  # inserts NER markup into the raw question text
        self.mgenre = mgenre  # mGENRE pipeline producing entity-label candidates
        self.entity_selection = entity_selection  # reconciles mGENRE candidates with NER labels

    def extract_entities_from_question(self, question_text: str):
        """Return Wikidata entity ids for entities mentioned in *question_text*.

        Labels that yield no usable Wikidata search hit are silently dropped.
        """
        question_with_ner, all_question_entities = self.ner.entity_labeling(question_text, True)
        mgenre_predicted_entities = self.mgenre(question_with_ner)
        question_entities = self.entity_selection(
            all_question_entities, mgenre_predicted_entities
        )

        # BUGFIX: the previous code indexed `[0]` unconditionally, which raises
        # IndexError when the Wikidata search returns an empty hit list for a
        # label. Guard against empty results as well as None ids.
        search_results = [get_wd_search_results(label, 1) for label in question_entities]
        return [hits[0] for hits in search_results if hits and hits[0] is not None]


# Module-level singletons shared by the pipeline routers (app.pipelines.*),
# which import `ner`, `mgenre`, `entity_selection` and `entity_linker` from here.
# NOTE(review): NerToSentenceInsertion("/data/ner/") presumably loads a model
# from disk at import time — confirm this heavy work is intended on module import.
ner = NerToSentenceInsertion("/data/ner/")
mgenre = build_mgenre_pipeline()
entity_selection = EntitiesSelection(ner.model)
entity_linker = EntityLinker(ner, mgenre, entity_selection)
4 changes: 2 additions & 2 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from app.models.base import WikidataSSPRequest
from app.pipelines import seq2seq
from app.pipelines import act_selection
# from app.pipelines import m3m
from app.pipelines import m3m
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
Expand All @@ -38,7 +38,7 @@

app.include_router(act_selection.router)
app.include_router(seq2seq.router)
# app.include_router(m3m.router)
app.include_router(m3m.router)


@app.get("/", response_class=HTMLResponse)
Expand Down
4 changes: 4 additions & 0 deletions app/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ class PipelineResponce(BaseModel):
answers: list[str]


class Seq2SeqPipelineResponce(PipelineResponce):
    """Response of the seq2seq pipeline: inherited `answers` plus the ids of
    entities linked in the question text."""

    # Wikidata entity ids extracted from the question by the entity linker.
    question_entities: list[str]


class EntityNeighboursResponce(BaseModel):
entity: str
property: str
Expand Down
22 changes: 3 additions & 19 deletions app/pipelines/act_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,14 @@

from fastapi import APIRouter
from joblib import Parallel, delayed
from pywikidata import Entity

from app.kgqa.act_selection import QuestionToRankInstanceOf, QuestionToRankInstanceOfSimple, QuestionToRankInstanceOfSimpleWithDescriptionMatching
from app.kgqa.entity_linking import EntitiesSelection
from app.kgqa.mgenre import build_mgenre_pipeline
from app.kgqa.ner import NerToSentenceInsertion
from app.kgqa.entity_linking import ner, mgenre, entity_selection, entity_linker
from app.kgqa.utils.utils import get_wd_search_results
from app.models.base import Entity as EntityResponce
from app.models.base import Question as QuestionRequest
from app.models.base import ACTPipelineResponce, ACTPipelineResponceWithDescriptionScore, QuestionEntitiesResponce, EntityNeighboursResponce
from app.pipelines.seq2seq import seq2seq

ner = NerToSentenceInsertion("/data/ner/")
mgenre = build_mgenre_pipeline()
entity_selection = EntitiesSelection(ner.model)


router = APIRouter(
prefix="/pipeline/act_selection",
Expand Down Expand Up @@ -53,15 +45,7 @@ def raw_seq2seq(question: str) -> list[str]:


def _prepare_question_entities_and_answer_candidate_helper(question: QuestionRequest):
question_wit_ner, all_question_entities = ner.entity_labeling(question.text, True)
mgenre_predicted_entities = mgenre(question_wit_ner)
question_entities = entity_selection(
all_question_entities, mgenre_predicted_entities
)

question_entities = [get_wd_search_results(label, 1)[0] for label in question_entities]
question_entities = [idx for idx in question_entities if idx is not None]

question_entities = entity_linker.extract_entities_from_question(question.text)
seq2seq_results = seq2seq(question.text)
answers_candidates = Parallel(n_jobs=-2)(
delayed(get_wd_search_results)(label, 1) for label in seq2seq_results
Expand All @@ -72,7 +56,7 @@ def _prepare_question_entities_and_answer_candidate_helper(question: QuestionReq


@lru_cache(maxsize=1024)
@router.post("/main")
@router.post("/main/")
def pipeline(question: QuestionRequest) -> ACTPipelineResponce:
question_entities, answers_candidates = _prepare_question_entities_and_answer_candidate_helper(question)

Expand Down
11 changes: 6 additions & 5 deletions app/pipelines/seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import torch
from fastapi import APIRouter
from joblib import Parallel, delayed
from pywikidata import Entity

from app.config import seq2seq as seq2seq_config
from app.kgqa.entity_linking import entity_linker
from app.kgqa.seq2seq import build_seq2seq_pipeline
from app.kgqa.utils.utils import label_to_entity_idx
from app.models.base import Entity as EntityResponce
from app.models.base import PipelineResponce
from app.models.base import Seq2SeqPipelineResponce
from app.models.base import Question as QuestionRequest

router = APIRouter(
Expand All @@ -23,13 +22,15 @@
torch.device(seq2seq_config["device"]),
)


@lru_cache(maxsize=1024)
@router.post("/")
def seq2seq_pipeline(question: QuestionRequest) -> PipelineResponce:
def seq2seq_pipeline(question: QuestionRequest) -> Seq2SeqPipelineResponce:
question_entities = entity_linker.extract_entities_from_question(question.text)
seq2seq_results = seq2seq(question.text)

corr_entities = Parallel(n_jobs=-2)(
delayed(label_to_entity_idx)(label) for label in seq2seq_results
)
corr_entities = [e for e in corr_entities if e is not None]
return PipelineResponce(answers=corr_entities[:60])
return Seq2SeqPipelineResponce(answers=corr_entities[:60], question_entities=question_entities)
157 changes: 88 additions & 69 deletions app/templates/index_search.html
Original file line number Diff line number Diff line change
Expand Up @@ -90,25 +90,20 @@
}
}

function process_pipeline_data(pipeline_data) {
const queryString = window.location.search;
const urlParams = new URLSearchParams(queryString);
const pipeline = urlParams.get('pipeline');

function process_pipeline_data(pipeline_data, pipeline) {
graph_svg = $(document.createElement('div')).addClass("my-3");
$("#details").append(graph_svg);

if (pipeline == 'm3m' || pipeline == 'm3m_subj_question_matching') {
if (pipeline === 'm3m' || pipeline === 'm3m_subj_question_matching') {
$("#details").append(
"<b>"+m3m_uncertanity_to_decision(pipeline_data['uncertenity'])+"</b>"
);
}

if (pipeline == 'act_selection' || pipeline == 'act_selection_simple' || pipeline == 'act_selection_simple_description') {
if (pipeline === 'seq2seq' || pipeline === 'act_selection' || pipeline === 'act_selection_simple' || pipeline === 'act_selection_simple_description') {
question_entities = $(document.createElement('div'));
question_entities.append('Question entities: ');
pipeline_data['question_entities'].forEach(obj => {
let entity_idx = obj['entity'];
let entity_idx = pipeline === 'seq2seq' ? obj : obj['entity'];
$.ajax({
url: 'wikidata/entities/' + entity_idx + '/label',
method: 'GET',
Expand All @@ -121,43 +116,45 @@
});
$("#details").append(question_entities);

answer_instance_of = $('<div class="m-1">Final Instance Of: </div>');
answer_instance_of_count = $('<table class="table-auto"></table>');
th = $(document.createElement('thead'))
th.append('<tr class="border-b"><th>Instance Of Entity</th><th>Count</th></tr>');
answer_instance_of_count.append(th);
answer_instance_of_count_body = $("<tbody></tbody>");
answer_instance_of_count.append(answer_instance_of_count_body);
if (pipeline !== 'seq2seq') {
answer_instance_of = $('<div>Final Instance Of: </div>');
answer_instance_of_count = $('<table class="table-auto"></table>');
th = $(document.createElement('thead'))
th.append('<tr class="border-b"><th>Instance Of Entity</th><th>Count</th></tr>');
answer_instance_of_count.append(th);
answer_instance_of_count_body = $("<tbody></tbody>");
answer_instance_of_count.append(answer_instance_of_count_body);

$("#details").append(answer_instance_of);
$("#details").append(answer_instance_of_count);
$("#details").append(answer_instance_of);
$("#details").append(answer_instance_of_count);

Object.entries(pipeline_data['answer_instance_of_count']).forEach(function(entry, idx, _) {
const [entity_idx, count] = entry;
$.ajax({
url: 'wikidata/entities/' + entity_idx + '/label',
method: 'GET',
dataType: 'json',
cache: true,
success: function(edata) {
tr = $(document.createElement('tr'));
tr.addClass('border-b');
tr.addClass('answer_instance_of_item');
tr.append("<td>" + one_instace_of_to_html(entity_idx, edata['label']) + "</td>");
tr.append("<td>" + String(count) + "</td>");
tr.attr('order', 10000000 - count);
answer_instance_of_count_body.append(tr);

if (pipeline_data['answer_instance_of'].includes(entity_idx)) {
answer_instance_of.append(one_instace_of_to_html(entity_idx, edata['label']));
Object.entries(pipeline_data['answer_instance_of_count']).forEach(function (entry, idx, _) {
const [entity_idx, count] = entry;
$.ajax({
url: 'wikidata/entities/' + entity_idx + '/label',
method: 'GET',
dataType: 'json',
cache: true,
success: function (edata) {
tr = $(document.createElement('tr'));
tr.addClass('border-b');
tr.addClass('answer_instance_of_item');
tr.append("<td>" + one_instace_of_to_html(entity_idx, edata['label']) + "</td>");
tr.append("<td>" + String(count) + "</td>");
tr.attr('order', 10000000 - count);
answer_instance_of_count_body.append(tr);

if (pipeline_data['answer_instance_of'].includes(entity_idx)) {
answer_instance_of.append(one_instace_of_to_html(entity_idx, edata['label']));
}
},
complete: function (edata) {
let instance_of_count_objcts = getSorted('.answer_instance_of_item', 'order');
instance_of_count_objcts.detach().appendTo(answer_instance_of_count_body);
}
},
complete: function(edata) {
let instance_of_count_objcts = getSorted('.answer_instance_of_item', 'order');
instance_of_count_objcts.detach().appendTo(answer_instance_of_count_body);
}
})
});
})
});
}
}

answers = [];
Expand Down Expand Up @@ -213,39 +210,55 @@
$("#top_answer_card").append(body);

if (
pipeline == 'act_selection' ||
pipeline == 'act_selection_simple' ||
pipeline == 'act_selection_simple_description' ||
pipeline == 'm3m' ||
pipeline == 'm3m_subj_question_matching'
pipeline === 'seq2seq' ||
pipeline === 'act_selection' ||
pipeline === 'act_selection_simple' ||
pipeline === 'act_selection_simple_description' ||
pipeline === 'm3m' ||
pipeline === 'm3m_subj_question_matching'
) {
let question_entities_idx = [];
if (pipeline == 'act_selection' || pipeline == 'act_selection_simple' || pipeline == 'act_selection_simple_description') {
if (pipeline === 'act_selection' || pipeline === 'act_selection_simple' || pipeline === 'act_selection_simple_description') {
pipeline_data['question_entities'].forEach(obj => {
question_entities_idx.push(String(obj["entity"]));
})
} else if (pipeline === 'seq2seq') {
pipeline_data['question_entities'].forEach(idx => {
question_entities_idx.push(String(idx));
})
} else {
question_entities_idx.push(pipeline_data['triples'][0][0]);
}
$.ajax({
url: '/wikidata/entities/ssp/graph/svg',
method: 'POST',
data: JSON.stringify({question_entities_idx: question_entities_idx, answer_idx: String(data['idx'])}),
contentType: "application/json",
dataType: 'json',
cache: true,
beforeSend: function() {
if ( $( "#loading_spinner2" ).length < 1) {
graph_svg.append(loading_spinner(2));
}
},
success: function (svg_data) {
graph_svg.append(String(svg_data));
},
complete: function() {
$("#loading_spinner2").remove();
},
});
const problem_text = question_entities_idx.length === 0 ? "Can't visualize graph\nNo question entities detected" : "";
if (problem_text) {
pre = $(document.createElement('pre'));
pre.text(problem_text);
graph_svg.append(pre);
} else {
graph_svg.empty();
$.ajax({
url: '/wikidata/entities/ssp/graph/svg',
method: 'POST',
data: JSON.stringify({
question_entities_idx: question_entities_idx,
answer_idx: String(data['idx'])
}),
contentType: "application/json",
dataType: 'json',
cache: true,
beforeSend: function () {
if ($("#loading_spinner2").length < 1) {
graph_svg.append(loading_spinner(2));
}
},
success: function (svg_data) {
graph_svg.append(String(svg_data));
},
complete: function () {
$("#loading_spinner2").remove();
},
});
}
}
}

Expand Down Expand Up @@ -301,6 +314,12 @@
}

$(document).ready(function() {
const queryString = window.location.search;
const urlParams = new URLSearchParams(queryString);
const pipeline = urlParams.get('pipeline');
if (!pipeline) {
window.location.replace(`${window.location.href}&pipeline=seq2seq`);
}
$.ajax({
url: '/pipeline/' + pipeline_switcher('{{pipeline}}') + '/',
method: 'POST',
Expand All @@ -309,7 +328,7 @@
dataType: 'json',
cache: true,
success: function (data) {
process_pipeline_data(data);
process_pipeline_data(data, pipeline);
},
})
})
Expand Down
2 changes: 2 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ services:
build: ./gap
environment:
- TOKENIZERS_PARALLELISM=false
ports:
- 8084:8082
Loading