diff --git a/x12306/__init__.py b/x12306/__init__.py index 82b68f5..af7a0bb 100644 --- a/x12306/__init__.py +++ b/x12306/__init__.py @@ -5,7 +5,6 @@ @file: __init__.py @time: 2019-02-08 """ - import click from .settings import settings @@ -28,6 +27,7 @@ @click.option("--proxies-file", help="代理列表文件") @click.option("--stations-file", help="站点信息文件") @click.option("--cdn-file", help="CDN文件") +@click.option("--csv-file", help="保存为csv文件") def main( from_station, to_station, @@ -43,6 +43,7 @@ def main( proxies_file, stations_file, cdn_file, + csv_file ): """ 12306查票助手 https://github.com/0xHJK/x12306 @@ -64,6 +65,7 @@ def main( proxies_file=proxies_file, stations_file=stations_file, cdn_file=cdn_file, + csv_file=csv_file ) print("\n-----------------------") diff --git a/x12306/settings.py b/x12306/settings.py index 14ee203..1b54068 100644 --- a/x12306/settings.py +++ b/x12306/settings.py @@ -102,6 +102,8 @@ def __init__(self): self.stations_file = DEFAULT_STATIONS_FILE # CDN文件 --cdn-file self.cdn_file = DEFAULT_CDN_FILE + # CDN文件 --csv-file + self.csv_file = None self.headers = DEFAULT_HEADERS self.init_url = DEFAULT_BASE_URL + "/init" diff --git a/x12306/train.py b/x12306/train.py index 92f0fd7..859f7db 100644 --- a/x12306/train.py +++ b/x12306/train.py @@ -13,6 +13,7 @@ import re import requests import prettytable as pt +import csv from .settings import settings from .utils import colorize @@ -148,6 +149,13 @@ def echo(self): for train in self.trains_list: tb.add_row(train.row) print(tb) + if settings.csv_file: + with open(settings.csv_file, mode='w+', encoding='utf-8', newline='') as file: + csv_writer = csv.writer(file) + for train in self.trains_list: + csv_writer.writerow(train.row) + print("已写入",settings.csv_file) + def cleanup(self): """处理trains_list,排序和删除无效数据""" @@ -172,6 +180,13 @@ def update(self): settings.date, settings.trains_no_list, ) + elif settings.zzmode: + self.trains_list = self._query_trains_zzmode( + settings.fs_code, + settings.ts_code, + settings.date, + settings.trains_no_list, + ) else: self.trains_list = self._query_trains( settings.fs_code, @@ -304,3 +319,28 @@ def _query_trains_zmode(self, fs_code, ts_code, date, trains_no_list) -> list: ) return list(set(trains_list)) + def _query_trains_zzmode(self, fs_code, ts_code, date, trains_no_list) -> list: + """ + 高级查询模式,会查询从出发站到沿途所有站的车次情况 + 仅被内部调用,调用前处理好参数 + :param fs_code: 出发地编码 + :param ts_code: 目的地编码 + :param date: 日期 + :param trains_no_list: 限制车次 + :return: Train对象列表 + """ + trains_list = self._query_trains(fs_code, ts_code, date, trains_no_list) + trains_no_list = [train.no for train in trains_list] + stations_list = [] + + for train in trains_list: + stations_list += self._query_stations(train) + stations_list = list(set(stations_list)) + for station in stations_list: + nfs_code = settings.stations_dict.get(station, "") + if ts_code: + trains_list += self._query_trains_zmode( + nfs_code, ts_code, date, trains_no_list + ) + + return list(set(trains_list))