diff --git a/irc3/compat.py b/irc3/compat.py index 2c6385a..d0e7ad1 100644 --- a/irc3/compat.py +++ b/irc3/compat.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- import sys -import types PY34 = bool(sys.version_info[0:2] >= (3, 4)) PY35 = bool(sys.version_info[0:2] >= (3, 5)) diff --git a/irc3/plugins/feeds.py b/irc3/plugins/feeds.py index ca43493..343044d 100644 --- a/irc3/plugins/feeds.py +++ b/irc3/plugins/feeds.py @@ -4,7 +4,6 @@ import irc3 import datetime from irc3.compat import asyncio -from concurrent.futures import ThreadPoolExecutor __doc__ = ''' ========================================== @@ -22,9 +21,9 @@ irc3.plugins.feeds [irc3.plugins.feeds] - channels = #irc3 # global channel to notify - delay = 5 # delay to check feeds - directory = ~/.irc3/feeds # directory to store feeds + channels = #irc3 # global channel to notify + delay = 5 # delay to check feeds in minutes + directory = ~/.irc3/feeds # directory to store feeds hook = irc3.plugins.feeds.default_hook # dotted name to a callable fmt = [{name}] {entry.title} - {entry.link} # formatter @@ -34,7 +33,7 @@ github/irc3.fmt = [{feed.name}] New commit: {entry.title} - {entry.link} # custom channels github/irc3.channels = #irc3dev #irc3 - # custom delay + # custom delay in minutes github/irc3.delay = 10 Hook is a dotted name refering to a callable (function or class) wich take a @@ -64,7 +63,7 @@ ''' HEADERS = { - 'User-Agent': 'python-requests/irc3/feeds', + 'User-Agent': 'python-aiohttp/irc3/feeds', 'Cache-Control': 'max-age=0', 'Pragma': 'no-cache', } @@ -82,21 +81,6 @@ def dispatcher(messages): return dispatcher -def fetch(args): - """fetch a feed""" - requests = args['requests'] - for feed, filename in zip(args['feeds'], args['filenames']): - try: - resp = requests.get(feed, timeout=5, headers=HEADERS) - content = resp.content - except Exception: # pragma: no cover - pass - else: - with open(filename, 'wb') as fd: - fd.write(content) - return args['name'] - - ISO_FORMAT = "%Y-%m-%dT%H:%M:%S" @@ -108,7 +92,7 @@ def parse(feedparser, args): for filename in args['filenames']: try: - with open(filename + '.updated') as fd: + with open(filename + '.updated', encoding="UTF-8") as fd: updated = datetime.datetime.strptime( fd.read()[:len("YYYY-MM-DDTHH:MM:SS")], ISO_FORMAT ) @@ -146,8 +130,6 @@ def parse(feedparser, args): class Feeds: """Feeds plugin""" - PoolExecutor = ThreadPoolExecutor - def __init__(self, bot): bot.feeds = self self.bot = bot @@ -207,7 +189,16 @@ def __init__(self, bot): def connection_made(self): """Initialize checkings""" - self.bot.loop.call_later(10, self.update) + self.bot.create_task(self.periodically_update()) + + async def periodically_update(self): + """After a connection has been made, call update feeds periodically.""" + if not self.aiohttp or not self.feedparser: + return + await asyncio.sleep(10) + while True: + await self.update() + await asyncio.sleep(self.delay) def imports(self): """show some warnings if needed""" @@ -218,14 +209,14 @@ def imports(self): self.bot.log.critical('feedparser is not installed') self.feedparser = None try: - import requests + import aiohttp except ImportError: # pragma: no cover - self.bot.log.critical('requests is not installed') - self.requests = None + self.bot.log.critical('aiohttp is not installed') + self.aiohttp = None else: - self.requests = requests + self.aiohttp = aiohttp - def parse(self, *args): + def parse(self): """parse pre-fetched feeds and notify new entries""" entries = [] for feed in self.feeds.values(): @@ -237,33 +228,37 @@ def messages(): if entry: feed = entry.feed message = feed['fmt'].format(feed=feed, entry=entry) - for c in feed['channels']: - yield c, message + for channel in feed['channels']: + yield channel, message self.dispatcher(messages()) - def update_time(self, future): - name = future.result() - self.bot.log.debug('Feed %s fetched', name) - feed = self.feeds[name] - feed['time'] = time.time() - - def update(self): + async def update(self): """update feeds""" - loop = self.bot.loop - loop.call_later(self.delay, self.update) - now = time.time() - feeds = [dict(f, requests=self.requests) for f in self.feeds.values() - if f['time'] < now - f['delay']] - if feeds: - self.bot.log.info('Fetching feeds %s', - ', '.join([f['name'] for f in feeds])) - tasks = [] - for feed in feeds: - task = loop.run_in_executor(None, fetch, feed) - task.add_done_callback(self.update_time) - tasks.append(task) - task = self.bot.create_task( - asyncio.wait(tasks, timeout=len(feeds) * 2, loop=loop)) - task.add_done_callback(self.parse) + feeds = [feed for feed in self.feeds.values() + if feed['time'] < now - feed['delay']] + if not feeds: + return + self.bot.log.info('Fetching feeds %s', + ', '.join([f['name'] for f in feeds])) + timeout = self.aiohttp.ClientTimeout(total=5) + async with self.aiohttp.ClientSession(timeout=timeout) as session: + await asyncio.gather( + *[self.fetch(feed, session) for feed in feeds] + ) + self.parse() + + async def fetch(self, feed, session): + """fetch a feed""" + for url, filename in zip(feed['feeds'], feed['filenames']): + try: + async with session.get(url, headers=HEADERS) as resp: + with open(filename, 'wb') as file: + file.write(await resp.read()) + except Exception: # pragma: no cover + self.bot.log.exception( + "Exception while fetching feed %s", feed['name'] + ) + self.bot.log.debug('Feed %s fetched', feed['name']) + feed['time'] = time.time() diff --git a/setup.py b/setup.py index 9cb40eb..f39bc30 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,7 @@ test_requires = [ 'pytest-asyncio', 'pytest-aiohttp', + 'aiohttp', 'feedparser', 'requests', 'pysocks',