|
1 | 1 | import csv
|
| 2 | +import re |
| 3 | +import tarfile |
2 | 4 | from asyncio import gather, get_event_loop
|
3 | 5 | from functools import partial
|
4 | 6 | from io import BytesIO, TextIOWrapper
|
5 | 7 | from itertools import islice
|
6 | 8 | from os import mkfifo, path
|
| 9 | +from ruamel import yaml |
7 | 10 | from socket import socket, AF_UNIX, SOCK_STREAM, SOCK_NONBLOCK
|
8 |
| -from zipfile import ZipFile |
| 11 | +from zipfile import ZipFile, is_zipfile |
9 | 12 |
|
10 | 13 | from jd4._compare import compare_stream
|
11 | 14 | from jd4.cgroup import wait_cgroup
|
@@ -81,9 +84,9 @@ def dos2unix(src, dst):
|
81 | 84 | buf = buf.replace(b'\r', b'')
|
82 | 85 | dst.write(buf)
|
83 | 86 |
|
84 |
| -class LegacyCase(CaseBase): |
85 |
| - def __init__(self, open_input, open_output, time_sec, mem_kb, score): |
86 |
| - super().__init__(int(time_sec * 1e9), int(mem_kb * 1024), PROCESS_LIMIT, score) |
| 87 | +class DefaultCase(CaseBase): |
| 88 | + def __init__(self, open_input, open_output, time_ns, memory_bytes, score): |
| 89 | + super().__init__(time_ns, memory_bytes, PROCESS_LIMIT, score) |
87 | 90 | self.open_input = open_input
|
88 | 91 | self.open_output = open_output
|
89 | 92 |
|
@@ -115,18 +118,69 @@ def do_stdout(self, stdout_file):
|
115 | 118 | with open(stdout_file, 'rb') as file:
|
116 | 119 | return compare_stream(BytesIO(str(self.a + self.b).encode()), file)
|
117 | 120 |
|
118 |
| -def read_legacy_cases(file): |
119 |
| - zip_file = ZipFile(file) |
120 |
| - canonical_dict = dict((name.lower(), name) for name in zip_file.namelist()) |
121 |
| - config = TextIOWrapper(zip_file.open(canonical_dict['config.ini']), |
122 |
| - encoding='utf-8', errors='replace') |
| 121 | +class FormatError(Exception): |
| 122 | + pass |
| 123 | + |
| 124 | +def read_legacy_cases(config, open): |
123 | 125 | num_cases = int(config.readline())
|
124 | 126 | for line in islice(csv.reader(config, delimiter='|'), num_cases):
|
125 |
| - input, output, time_sec_str, score_str = line[:4] |
| 127 | + input, output, time_str, score_str = line[:4] |
126 | 128 | try:
|
127 |
| - mem_kb = float(line[4]) |
| 129 | + memory_kb = float(line[4]) |
128 | 130 | except (IndexError, ValueError):
|
129 |
| - mem_kb = DEFAULT_MEM_KB |
130 |
| - open_input = partial(zip_file.open, canonical_dict[path.join('input', input.lower())]) |
131 |
| - open_output = partial(zip_file.open, canonical_dict[path.join('output', output.lower())]) |
132 |
| - yield LegacyCase(open_input, open_output, float(time_sec_str), mem_kb, int(score_str)) |
| 131 | + memory_kb = DEFAULT_MEM_KB |
| 132 | + yield DefaultCase(partial(open, path.join('input', input)), |
| 133 | + partial(open, path.join('output', output)), |
| 134 | + int(float(time_str) * 1000000000), |
| 135 | + int(memory_kb * 1024), |
| 136 | + int(score_str)) |
| 137 | + |
| 138 | +TIME_RE = re.compile(r'([0-9]+(?:\.[0-9]*)?)([mun]?)s?') |
| 139 | +TIME_UNITS = {'': 1000000000, 'm': 1000000, 'u': 1000, 'n': 1} |
| 140 | +MEMORY_RE = re.compile(r'([0-9]+(?:\.[0-9]*)?)([kmg]?)b?') |
| 141 | +MEMORY_UNITS = {'': 1, 'k': 1024, 'm': 1048576, 'g': 1073741824} |
| 142 | + |
| 143 | +def read_yaml_cases(config, open): |
| 144 | + for case in yaml.safe_load(config)['cases']: |
| 145 | + time = TIME_RE.fullmatch(case['time']) |
| 146 | + if not time: |
| 147 | + raise FormatError(case['time'], 'error parsing time') |
| 148 | + memory = MEMORY_RE.fullmatch(case['memory']) |
| 149 | + if not memory: |
| 150 | + raise FormatError(case['memory'], 'error parsing memory') |
| 151 | + yield DefaultCase( |
| 152 | + partial(open, case['input']), |
| 153 | + partial(open, case['output']), |
| 154 | + int(float(time.group(1)) * TIME_UNITS[time.group(2)]), |
| 155 | + int(float(memory.group(1)) * MEMORY_UNITS[memory.group(2)]), |
| 156 | + int(case['score'])) |
| 157 | + |
| 158 | +def read_cases(file): |
| 159 | + if is_zipfile(file): |
| 160 | + with ZipFile(file) as zip_file: |
| 161 | + canonical_dict = dict((name.lower(), name) |
| 162 | + for name in zip_file.namelist()) |
| 163 | + def open(name): |
| 164 | + try: |
| 165 | + return ZipFile(file).open(canonical_dict[name.lower()]) |
| 166 | + except KeyError: |
| 167 | + raise FileNotFoundError(name) from None |
| 168 | + elif tarfile.is_tarfile(file): |
| 169 | + def open(name): |
| 170 | + try: |
| 171 | + return tarfile.open(file).extractfile(name) |
| 172 | + except KeyError: |
| 173 | + raise FileNotFoundError(name) from None |
| 174 | + else: |
| 175 | + raise FormatError(file, 'not a zip file or tar file') |
| 176 | + try: |
| 177 | + config = TextIOWrapper(open('config.ini'), encoding='utf-8') |
| 178 | + return read_legacy_cases(config, open) |
| 179 | + except FileNotFoundError: |
| 180 | + pass |
| 181 | + try: |
| 182 | + config = open('config.yaml') |
| 183 | + return read_yaml_cases(config, open) |
| 184 | + except FileNotFoundError: |
| 185 | + pass |
| 186 | + raise FormatError('config file not found') |
0 commit comments