Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions run_validation/main_task/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
## How to Run

### Prerequisites and setup

Download necessary files
- `bash download.sh`

Expand All @@ -8,28 +10,42 @@ Create and activate a virtual environment
- `source env/bin/activate`

Install requirements

- `pip install --upgrade pip && pip install -r requirements.txt`

Compile protocol buffers

- `bash compile_protos.sh`

Run Passage Validator service in background
- `python3 passage_validator_servicer.py files/all_hashes.sqlite3`
### Running the validation script

Run main validation script (in another terminal but within same virtual env)
- `python3 main.py CAST [path to run file] [--skip_passage_validation]`
Run the Passage Validator service in the background: `python3 passage_validator_servicer.py files/all_hashes.sqlite3`

NOTE: `--skip_passage_validation` is an optional argument that skips passage validation if added. If used, passage_validator does not need to be run in the background.
Run the main validation script (in another terminal but within the same virtual env). The script has several parameters you can view by running `python3 main.py -h`.

To generate a trec run file, ideally after main script runs successfully
- `python3 generate_run.py [path to run file]`
Some examples:

### Tests
```shell
# Run with default parameters
python3 main.py CAST <run file path>
# Run without having the validator service available
python3 main.py CAST <run file path> --skip_passage_validation
# Abort the run if more than 50 validation warnings are generated
python3 main.py CAST <run file path> -m 50
# Abort the run if any gRPC errors occur contacting the validation service
python3 main.py CAST <run file path> -s
# Set a 10s timeout for gRPC calls to the validation service
python3 main.py CAST <run file path> -t 10
```

To run the normal set of tests:
The script logs to stdout and to a file in the current working directory named `<run_file>.errlog` (e.g. a run file named `sample_run.json` will have logs saved to `sample_run.json.errlog`).

- `pytest`
### Generating a TREC run file

To generate a TREC run file, ideally after the main script runs successfully: `python3 generate_run.py <run file path>`

### Tests

To run the full set (including some slower tests):
To run the normal set of tests: `pytest`

- `pytest --runslow`
To run the full set (including some slower tests): `pytest --runslow`
9 changes: 8 additions & 1 deletion run_validation/main_task/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ def load_run_file(run_file_path: str) -> CastRun:
with open(run_file_path, 'r', encoding='utf-8') as run_file:
try:
run = json.load(run_file)
# check for expected attributes
if 'run_name' not in run or 'run_type' not in run:
raise Exception('Missing run_name/run_type entry')

if 'turns' not in run:
raise Exception('Missing turns entry')

run = ParseDict(run, CastRun())
except Exception as e:
logger.error(f'Run file not in the right format ({e})')
Expand All @@ -82,7 +89,7 @@ def validate_turn(turn: Turn, turn_lookup_set: dict, service_stub: PassageValida
try:
warning_count = validate_passages(service_stub, logger, warning_count, turn)
except grpc.RpcError as rpce:
logger.warning(f'A gRPC error occurred when validating passages ({rpce.code().name})')
logger.warning(f'A gRPC error occurred when validating passages (name={rpce.code().name}, message={rpce.details()})')
service_errors += 1

# check response and provenance
Expand Down
7 changes: 5 additions & 2 deletions run_validation/main_task/passage_id_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __enter__(self):
return self

# Context-manager exit hook: closes the instance's cursor (self.cur) first,
# then calls self.close(). Nothing is returned, so an exception raised inside
# the `with` body is not suppressed and propagates to the caller.
def __exit__(self, type, value, traceback):
self.cur.close()
self.close()

def populate(self, hash_file: str, batch_size: int, print_interval: int) -> bool:
Expand Down Expand Up @@ -113,13 +114,15 @@ def populate(self, hash_file: str, batch_size: int, print_interval: int) -> bool

def validate(self, ids: [str]) -> [bool]:
    """
    Check which of the given passage IDs are present in the database.

    Args:
        ids: the passage IDs to look up.

    Returns:
        A list of booleans in the same order as `ids`; an entry is True
        when a row matching that ID was found and False otherwise.
    """
    results = []
    # The query text does not depend on the ID being checked, so build it once.
    query = (f'SELECT {PassageIDDatabase.COL_NAME} FROM {PassageIDDatabase.TABLE_NAME} '
             f'WHERE {PassageIDDatabase.COL_NAME} = ?')

    # Use a cursor that belongs to this call instead of the shared self.cur.
    # NOTE(review): presumably this keeps concurrent requests from sharing
    # cursor state - confirm against the servicer's threading model.
    cur = self.db.cursor()
    try:
        for passage_id in ids:
            # The ID is bound as a parameter, never interpolated into the SQL.
            cur.execute(query, (passage_id, ))
            found = cur.fetchone() is not None
            results.append(found)
            logger.debug(f'Validate {passage_id} = {found}')
    finally:
        # Release the cursor even if one of the lookups raises.
        cur.close()

    return results

def close(self) -> bool:
Expand Down
4 changes: 3 additions & 1 deletion run_validation/main_task/passage_validator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sys

import grpc

from compiled_protobufs.passage_validator_pb2 import PassageValidationRequest, PassageValidationResult, PassageValidation
from compiled_protobufs.passage_validator_pb2_grpc import PassageValidatorServicer

Expand All @@ -20,7 +22,7 @@ def __init__(self, db_path: str, expected_rows: int) -> None:
print('>> Service ready')

def validate_passages(self, passage_validation_request: PassageValidationRequest,
context) -> PassageValidationResult:
context: grpc.ServicerContext) -> PassageValidationResult:
"""
Takes in a list of passage ids and checks if they appear in the database
"""
Expand Down
26 changes: 26 additions & 0 deletions run_validation/main_task/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,32 @@ def test_validate_invalid_run_file(tmp_path):
assert pytest_exc.type == SystemExit
assert pytest_exc.value.code == 255

def test_validate_run_file_missing_run_name(tmp_path, run_file_path):
    """A run file without a `run_name` entry makes load_run_file exit with code 255."""
    incomplete_run = tmp_path / 'missing_run_name.json'
    incomplete_run.write_text('{ "run_type": "manual", "turns": [] }')

    with pytest.raises(SystemExit) as exit_info:
        load_run_file(incomplete_run)

    assert exit_info.type == SystemExit
    assert exit_info.value.code == 255

def test_validate_run_file_missing_turns(tmp_path, run_file_path):
    """A run file without a `turns` entry makes load_run_file exit with code 255."""
    incomplete_run = tmp_path / 'missing_turns.json'
    incomplete_run.write_text('{ "run_type": "manual", "run_name": "missing_turns"}')

    with pytest.raises(SystemExit) as exit_info:
        load_run_file(incomplete_run)

    assert exit_info.type == SystemExit
    assert exit_info.value.code == 255

def test_validate_missing_run_file():
with pytest.raises(OSError):
_ = load_run_file('foobar')
Expand Down
28 changes: 28 additions & 0 deletions run_validation/main_task/tests/test_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import time
import multiprocessing
import random

import pytest

from passage_validator import PassageValidator
from main import validate

def test_service_startup(servicer_params_test):
pv = PassageValidator(*servicer_params_test)
Expand All @@ -12,3 +17,26 @@ def test_service_startup_invalid_rows(servicer_params_test):

assert(pytest_exc.type == SystemExit)
assert(pytest_exc.value.code == 255)

def validate_wrapper(run_file, file_root, max_warnings, skip_validation, strict, start_delay):
    """
    Sleep for `start_delay` seconds, then run `validate` on the given run file.

    Returns the (turns_validated, warning_count, service_errors) triple
    produced by `validate`.
    """
    # Wait before starting so that callers can stagger several clients.
    time.sleep(start_delay)

    outcome = validate(run_file, file_root, max_warnings, skip_validation, strict)
    turns_validated, warning_count, service_errors = outcome
    return turns_validated, warning_count, service_errors

@pytest.mark.slow
def test_service_multiple_clients(default_validate_args, grpc_server_full):
    """
    Run 25 validation clients in parallel worker processes and check that
    every one of them reports the same (205, 0, 1) result.
    """
    num_clients = 25
    args = default_validate_args

    # One argument tuple per client; the last element is a random start
    # delay in [0, 1) seconds so the clients do not all begin at once.
    validation_args = []
    for _ in range(num_clients):
        validation_args.append((
            args.path_to_run_file,
            args.fileroot,
            args.max_warnings,
            args.skip_passage_validation,
            args.strict,
            random.random(),
        ))

    with multiprocessing.Pool(processes=num_clients) as pool:
        results = pool.starmap(validate_wrapper, validation_args)

    # starmap returns one result per argument tuple, in submission order.
    for client_result in results:
        assert client_result == (205, 0, 1)