Skip to content

Commit 2c49b67

Browse files
authored
Merge pull request #54 from vijos/yaml
implement yaml-based format (#38)
2 parents 7b90fce + 414dc4d commit 2c49b67

File tree

5 files changed

+87
-21
lines changed

5 files changed

+87
-21
lines changed

jd4/case.py

+69-15
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import csv
2+
import re
3+
import tarfile
24
from asyncio import gather, get_event_loop
35
from functools import partial
46
from io import BytesIO, TextIOWrapper
57
from itertools import islice
68
from os import mkfifo, path
9+
from ruamel import yaml
710
from socket import socket, AF_UNIX, SOCK_STREAM, SOCK_NONBLOCK
8-
from zipfile import ZipFile
11+
from zipfile import ZipFile, is_zipfile
912

1013
from jd4._compare import compare_stream
1114
from jd4.cgroup import wait_cgroup
@@ -81,9 +84,9 @@ def dos2unix(src, dst):
8184
buf = buf.replace(b'\r', b'')
8285
dst.write(buf)
8386

84-
class LegacyCase(CaseBase):
85-
def __init__(self, open_input, open_output, time_sec, mem_kb, score):
86-
super().__init__(int(time_sec * 1e9), int(mem_kb * 1024), PROCESS_LIMIT, score)
87+
class DefaultCase(CaseBase):
88+
def __init__(self, open_input, open_output, time_ns, memory_bytes, score):
89+
super().__init__(time_ns, memory_bytes, PROCESS_LIMIT, score)
8790
self.open_input = open_input
8891
self.open_output = open_output
8992

@@ -115,18 +118,69 @@ def do_stdout(self, stdout_file):
115118
with open(stdout_file, 'rb') as file:
116119
return compare_stream(BytesIO(str(self.a + self.b).encode()), file)
117120

118-
def read_legacy_cases(file):
119-
zip_file = ZipFile(file)
120-
canonical_dict = dict((name.lower(), name) for name in zip_file.namelist())
121-
config = TextIOWrapper(zip_file.open(canonical_dict['config.ini']),
122-
encoding='utf-8', errors='replace')
121+
class FormatError(Exception):
122+
pass
123+
124+
def read_legacy_cases(config, open):
123125
num_cases = int(config.readline())
124126
for line in islice(csv.reader(config, delimiter='|'), num_cases):
125-
input, output, time_sec_str, score_str = line[:4]
127+
input, output, time_str, score_str = line[:4]
126128
try:
127-
mem_kb = float(line[4])
129+
memory_kb = float(line[4])
128130
except (IndexError, ValueError):
129-
mem_kb = DEFAULT_MEM_KB
130-
open_input = partial(zip_file.open, canonical_dict[path.join('input', input.lower())])
131-
open_output = partial(zip_file.open, canonical_dict[path.join('output', output.lower())])
132-
yield LegacyCase(open_input, open_output, float(time_sec_str), mem_kb, int(score_str))
131+
memory_kb = DEFAULT_MEM_KB
132+
yield DefaultCase(partial(open, path.join('input', input)),
133+
partial(open, path.join('output', output)),
134+
int(float(time_str) * 1000000000),
135+
int(memory_kb * 1024),
136+
int(score_str))
137+
138+
TIME_RE = re.compile(r'([0-9]+(?:\.[0-9]*)?)([mun]?)s?')
139+
TIME_UNITS = {'': 1000000000, 'm': 1000000, 'u': 1000, 'n': 1}
140+
MEMORY_RE = re.compile(r'([0-9]+(?:\.[0-9]*)?)([kmg]?)b?')
141+
MEMORY_UNITS = {'': 1, 'k': 1024, 'm': 1048576, 'g': 1073741824}
142+
143+
def read_yaml_cases(config, open):
144+
for case in yaml.safe_load(config)['cases']:
145+
time = TIME_RE.fullmatch(case['time'])
146+
if not time:
147+
raise FormatError(case['time'], 'error parsing time')
148+
memory = MEMORY_RE.fullmatch(case['memory'])
149+
if not memory:
150+
raise FormatError(case['memory'], 'error parsing memory')
151+
yield DefaultCase(
152+
partial(open, case['input']),
153+
partial(open, case['output']),
154+
int(float(time.group(1)) * TIME_UNITS[time.group(2)]),
155+
int(float(memory.group(1)) * MEMORY_UNITS[memory.group(2)]),
156+
int(case['score']))
157+
158+
def read_cases(file):
159+
if is_zipfile(file):
160+
with ZipFile(file) as zip_file:
161+
canonical_dict = dict((name.lower(), name)
162+
for name in zip_file.namelist())
163+
def open(name):
164+
try:
165+
return ZipFile(file).open(canonical_dict[name.lower()])
166+
except KeyError:
167+
raise FileNotFoundError(name) from None
168+
elif tarfile.is_tarfile(file):
169+
def open(name):
170+
try:
171+
return tarfile.open(file).extractfile(name)
172+
except KeyError:
173+
raise FileNotFoundError(name) from None
174+
else:
175+
raise FormatError(file, 'not a zip file or tar file')
176+
try:
177+
config = TextIOWrapper(open('config.ini'), encoding='utf-8')
178+
return read_legacy_cases(config, open)
179+
except FileNotFoundError:
180+
pass
181+
try:
182+
config = open('config.yaml')
183+
return read_yaml_cases(config, open)
184+
except FileNotFoundError:
185+
pass
186+
raise FormatError('config file not found')

jd4/case_test.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from os import path
22
from unittest import main, TestCase
33

4-
from jd4.case import read_legacy_cases
4+
from jd4.case import read_cases
55

66
class CaseTest(TestCase):
77
def test_legacy_case(self):
88
count = 0
9-
for case in read_legacy_cases(path.join(path.dirname(__file__),
10-
'testdata/1000.zip')):
9+
for case in read_cases(path.join(path.dirname(__file__),
10+
'testdata/aplusb-legacy.zip')):
1111
self.assertEqual(case.time_limit_ns, 1000000000)
1212
self.assertEqual(case.memory_limit_bytes, 16777216)
1313
self.assertEqual(case.score, 10)
@@ -16,5 +16,17 @@ def test_legacy_case(self):
1616
count += 1
1717
self.assertEqual(count, 10)
1818

19+
def test_yaml_case(self):
20+
count = 0
21+
for case in read_cases(path.join(path.dirname(__file__),
22+
'testdata/aplusb.tar.gz')):
23+
self.assertEqual(case.time_limit_ns, 1000000000)
24+
self.assertEqual(case.memory_limit_bytes, 33554432)
25+
self.assertEqual(case.score, 10)
26+
self.assertEqual(sum(map(int, case.open_input().read().split())),
27+
int(case.open_output().read()))
28+
count += 1
29+
self.assertEqual(count, 10)
30+
1931
if __name__ == '__main__':
2032
main()

jd4/integration_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from os import path
44
from unittest import TestCase, main
55

6-
from jd4.case import read_legacy_cases, APlusBCase
6+
from jd4.case import read_cases, APlusBCase
77
from jd4.cgroup import try_init_cgroup
88
from jd4.log import logger
99
from jd4.pool import pool_build, pool_judge
@@ -18,8 +18,8 @@ class LanguageTest(TestCase):
1818
@classmethod
1919
def setUpClass(cls):
2020
try_init_cgroup()
21-
cls.cases = list(read_legacy_cases(path.join(path.dirname(__file__),
22-
'testdata/1000.zip')))
21+
cls.cases = list(read_cases(path.join(path.dirname(__file__),
22+
'testdata/aplusb.tar.gz')))
2323

2424
def do_lang(self, lang, code):
2525
package, message, time_usage_ns, memory_usage_bytes = \
File renamed without changes.

jd4/testdata/aplusb.tar.gz

682 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)