From e03f5d0a43146a57d20cb4448262495b8d6044ed Mon Sep 17 00:00:00 2001 From: jkhsjdhjs Date: Mon, 16 Mar 2020 21:51:32 +0000 Subject: [PATCH] add auto reconnect for postgres This only works on the second database interaction, since psycopg2 only notices that the connection is gone, when a query is executed. So in the common case reconnect works as follows: - some bot method calls a cursor function like .execute(), .fetchone(), etc. - this raises an error if the connection is broken - if following code then requests a new cursor, this will also fail since psycopg2 now knows that the connection is gone - the error is caught in storage.DBConn.cursor(), a new connection will be set up of which a new cursor is yielded If the error happens in connection.commit() or .rollback() instead we can instantly reconnect since these methods are wrapped. So why not wrap the cursor methods as well? Consider the following example: A query is the last thing that was executed on a cursor. The database connection is lost. Now .fetchone() is called on the cursor. We could wrap .fetchone() and reconnect, but we'd have to use a new cursor since cursors are linked to connections. And on this new cursor .fetchone() wouldn't make any sense, since we haven't executed a query on this cursor. --- bot/__init__.py | 1 - bot/mcmaniac.py | 50 +++++++++++----------- bot/quotes.py | 107 +++++++++++++++++++++++++----------------------- bot/rape.py | 56 +++++++++++++------------ bot/regex.py | 40 +++++++++--------- bot/seen.py | 42 ++++++++++--------- bot/storage.py | 35 +++++++++++++++- bot/tell.py | 49 +++++++++++----------- bot/timer.py | 55 +++++++++++++------------ bot/useless.py | 92 +++++++++++++++++++++-------------------- 10 files changed, 289 insertions(+), 238 deletions(-) diff --git a/bot/__init__.py b/bot/__init__.py index acd17ad..b48eabb 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -73,4 +73,3 @@ class DatabasePlugin(Plugin): super().__init__(bot) self.db = bot.get_plugin(Storage) self.con = self.db.con - self.cur = self.db.cur diff --git a/bot/mcmaniac.py b/bot/mcmaniac.py index b5cbef5..9b203be 100644 --- a/bot/mcmaniac.py +++ b/bot/mcmaniac.py @@ -29,34 +29,36 @@ class McManiac(DatabasePlugin): offset = '' # Fetch result from database - self.cur.execute(''' - SELECT - item, - rank() OVER (ORDER BY id), - count(*) OVER (ROWS BETWEEN UNBOUNDED PRECEDING - AND UNBOUNDED FOLLOWING) AS total - FROM - mcmaniacs - ORDER BY - {order} - LIMIT - 1 - {offset} - '''.format(order=order, offset=offset), [index]) - result = self.cur.fetchone() + with self.con.cursor() as cur: + cur.execute(''' + SELECT + item, + rank() OVER (ORDER BY id), + count(*) OVER (ROWS BETWEEN UNBOUNDED PRECEDING + AND UNBOUNDED FOLLOWING) AS total + FROM + mcmaniacs + ORDER BY + {order} + LIMIT + 1 + {offset} + '''.format(order=order, offset=offset), [index]) + result = cur.fetchone() - if result: - return '[{rank}/{total}] {item}'.format(**result) + if result: + return '[{rank}/{total}] {item}'.format(**result) # TODO: fix regex ("McFooiaC McBariaC" adds "Mc\S+iaC") @irc3.event(r'^:(?P\S+) PRIVMSG \S+ :.*(?PMc\S+iaC).*') def save(self, mask: str, item: str): if IrcString(mask).nick != self.bot.nick: - self.cur.execute(''' - INSERT INTO - mcmaniacs (item) - VALUES - (%s) - ON CONFLICT DO NOTHING - ''', [item]) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + mcmaniacs (item) + VALUES + (%s) + ON CONFLICT DO NOTHING + ''', [item]) self.con.commit() diff --git a/bot/quotes.py b/bot/quotes.py index 948d95b..e5e6627 100644 --- a/bot/quotes.py +++ b/bot/quotes.py @@ -23,43 +23,45 @@ class Quotes(DatabasePlugin): self.bot.notice(mask.nick, '[Quotes] Error parsing nick') else: # Insert quote into database - self.cur.execute(''' - INSERT INTO - quotes (nick, item, channel, created_by) - VALUES - (%s, %s, %s, %s) - ''', [nick, quote, channel, mask.nick]) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + quotes (nick, item, channel, created_by) + VALUES + (%s, %s, %s, %s) + ''', [nick, quote, channel, mask.nick]) def delete_quote(self, nick: str, quote: str): index, order = parse_int(quote, select=False) if index: # Delete from database - self.cur.execute(''' - -- noinspection SqlResolve - WITH ranked_quotes AS ( - SELECT - id, - rank() OVER (PARTITION BY nick ORDER BY id {order}) - FROM - quotes - WHERE - lower(nick) = lower(%s) - ) + with self.con.cursor() as cur: + cur.execute(''' + -- noinspection SqlResolve + WITH ranked_quotes AS ( + SELECT + id, + rank() OVER (PARTITION BY nick ORDER BY id {order}) + FROM + quotes + WHERE + lower(nick) = lower(%s) + ) - -- noinspection SqlResolve - DELETE FROM - quotes - WHERE - id = ( - SELECT - id - FROM - ranked_quotes + -- noinspection SqlResolve + DELETE FROM + quotes WHERE - rank = %s - ) - '''.format(order=order), [nick, index]) + id = ( + SELECT + id + FROM + ranked_quotes + WHERE + rank = %s + ) + '''.format(order=order), [nick, index]) @command(options_first=True, quiet=True) def q(self, mask: IrcString, target: IrcString, args: Dict): @@ -135,30 +137,31 @@ class Quotes(DatabasePlugin): offset = '' # Fetch quote from database - self.cur.execute(''' - WITH ranked_quotes AS ( - SELECT - nick, - item, - rank() OVER (PARTITION BY nick ORDER BY id), - count(*) OVER (PARTITION BY nick) AS total - FROM - quotes - ) + with self.con.cursor() as cur: + cur.execute(''' + WITH ranked_quotes AS ( + SELECT + nick, + item, + rank() OVER (PARTITION BY nick ORDER BY id), + count(*) OVER (PARTITION BY nick) AS total + FROM + quotes + ) - SELECT - * - FROM - ranked_quotes - WHERE - {where} - ORDER BY - {order} - LIMIT - 1 - {offset} - '''.format(where=' AND '.join(where), order=order, offset=offset), values) - result = self.cur.fetchone() + SELECT + * + FROM + ranked_quotes + WHERE + {where} + ORDER BY + {order} + LIMIT + 1 + {offset} + '''.format(where=' AND '.join(where), order=order, offset=offset), values) + result = cur.fetchone() if result: return '[{rank}/{total}] <{nick}> {item}'.format(**result) diff --git a/bot/rape.py b/bot/rape.py index 6fb21d6..e950283 100644 --- a/bot/rape.py +++ b/bot/rape.py @@ -18,15 +18,16 @@ class Rape(DatabasePlugin): nick = args.get('', mask.nick) # Fetch result from database - self.cur.execute(''' - SELECT - fines - FROM - users - WHERE - lower(nick) = lower(%s) - ''', [nick]) - owes = self.cur.fetchone() + with self.con.cursor() as cur: + cur.execute(''' + SELECT + fines + FROM + users + WHERE + lower(nick) = lower(%s) + ''', [nick]) + owes = cur.fetchone() # Colorize owe amount and return string if owes: @@ -56,22 +57,23 @@ class Rape(DatabasePlugin): reason = ('raping', 'being too lewd and getting raped')[rand] # Insert or add fine to database and return total owe - self.cur.execute(''' - INSERT INTO - users (nick, fines) - VALUES - (lower(%s), %s) - ON CONFLICT (nick) DO UPDATE SET - fines = users.fines + excluded.fines - RETURNING - fines - ''', [fined, fine]) - self.con.commit() + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + users (nick, fines) + VALUES + (lower(%s), %s) + ON CONFLICT (nick) DO UPDATE SET + fines = users.fines + excluded.fines + RETURNING + fines + ''', [fined, fine]) + self.con.commit() - # Print fine and total owe - self.bot.action(target, 'fines {nick} \x02${fine}\x02 for {reason}. You owe: \x0304${total}\x03'.format( - nick=fined, - fine=fine, - reason=reason, - total=self.cur.fetchone()['fines'], - )) + # Print fine and total owe + self.bot.action(target, 'fines {nick} \x02${fine}\x02 for {reason}. You owe: \x0304${total}\x03'.format( + nick=fined, + fine=fine, + reason=reason, + total=cur.fetchone()['fines'], + )) diff --git a/bot/regex.py b/bot/regex.py index b9f86ff..740e56c 100644 --- a/bot/regex.py +++ b/bot/regex.py @@ -20,16 +20,17 @@ class Useless(DatabasePlugin): if nick == self.bot.nick: return - self.cur.execute(''' - SELECT - item - FROM - last_messages - WHERE - nick = lower(%s) - AND channel = lower(%s) - ''', [nick, target]) - result = self.cur.fetchone() + with self.con.cursor() as cur: + cur.execute(''' + SELECT + item + FROM + last_messages + WHERE + nick = lower(%s) + AND channel = lower(%s) + ''', [nick, target]) + result = cur.fetchone() if result: old = result['item'] @@ -44,13 +45,14 @@ class Useless(DatabasePlugin): """Saves the last message of a user for each channel (for regex).""" mask = IrcString(mask) - self.cur.execute(''' - INSERT INTO - last_messages (nick, host, channel, item) - VALUES - (lower(%s), %s, lower(%s), %s) - ON CONFLICT (nick, channel) DO UPDATE SET - host = excluded.host, - item = excluded.item - ''', [mask.nick, mask.host, target, msg]) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + last_messages (nick, host, channel, item) + VALUES + (lower(%s), %s, lower(%s), %s) + ON CONFLICT (nick, channel) DO UPDATE SET + host = excluded.host, + item = excluded.item + ''', [mask.nick, mask.host, target, msg]) self.con.commit() diff --git a/bot/seen.py b/bot/seen.py index 52e6feb..b70a876 100644 --- a/bot/seen.py +++ b/bot/seen.py @@ -23,15 +23,16 @@ class Seen(DatabasePlugin): return '{}, look in the mirror faggot!'.format(nick) # Fetch seen from database - self.cur.execute(''' - SELECT - seen_at, message, channel - FROM - seens - WHERE - nick = lower(%s) - ''', [nick]) - seen = self.cur.fetchone() + with self.con.cursor() as cur: + cur.execute(''' + SELECT + seen_at, message, channel + FROM + seens + WHERE + nick = lower(%s) + ''', [nick]) + seen = cur.fetchone() # No result if not seen: @@ -50,15 +51,16 @@ class Seen(DatabasePlugin): def save(self, mask: str, target: str, msg: str): mask = IrcString(mask) - self.cur.execute(''' - INSERT INTO - seens (nick, host, channel, message) - VALUES - (lower(%s), %s, %s, %s) - ON CONFLICT (nick) DO UPDATE SET - host = excluded.host, - channel = excluded.channel, - seen_at = now(), - message = excluded.message - ''', [mask.nick, mask.host, target, msg]) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + seens (nick, host, channel, message) + VALUES + (lower(%s), %s, %s, %s) + ON CONFLICT (nick) DO UPDATE SET + host = excluded.host, + channel = excluded.channel, + seen_at = now(), + message = excluded.message + ''', [mask.nick, mask.host, target, msg]) self.con.commit() diff --git a/bot/storage.py b/bot/storage.py index 9c12eeb..00c22dc 100644 --- a/bot/storage.py +++ b/bot/storage.py @@ -1,15 +1,46 @@ # -*- coding: utf-8 -*- import os +import contextlib import psycopg2 from psycopg2.extras import DictCursor from . import Plugin, Bot +class DBConn: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self._reconnect() + + def _reconnect(self): + self._con = psycopg2.connect(*self.args, **self.kwargs) + + def commit(self, *args, **kwargs): + try: + return self._con.commit(*args, **kwargs) + except psycopg2.InterfaceError: + self._reconnect() + return self._con.commit(*args, **kwargs) + + def rollback(self, *args, **kwargs): + try: + return self._con.rollback(*args, **kwargs) + except psycopg2.InterfaceError: + self._reconnect() + return self._con.rollback(*args, **kwargs) + + @contextlib.contextmanager + def cursor(self, *args, **kwargs): + try: + yield self._con.cursor(cursor_factory=DictCursor, *args, **kwargs) + except psycopg2.InterfaceError: + self._reconnect() + yield self._con.cursor(cursor_factory=DictCursor, *args, **kwargs) + class Storage(Plugin): def __init__(self, bot: Bot): super().__init__(bot) self.bot.sql = self - self.con = psycopg2.connect(os.environ['DATABASE_URI']) - self.cur = self.con.cursor(cursor_factory=DictCursor) + self.con = DBConn(os.environ['DATABASE_URI']) diff --git a/bot/tell.py b/bot/tell.py index b250bf6..31d777a 100644 --- a/bot/tell.py +++ b/bot/tell.py @@ -15,19 +15,20 @@ class Tell(DatabasePlugin): super().__init__(bot) self.tell_queue = {} - self.cur.execute(''' - SELECT - to_nick, from_nick, message, created_at - FROM - tells - ''') + with self.con.cursor() as cur: + cur.execute(''' + SELECT + to_nick, from_nick, message, created_at + FROM + tells + ''') - for res in self.cur.fetchall(): - nick = res['to_nick'].lower() + for res in cur.fetchall(): + nick = res['to_nick'].lower() - if nick not in self.tell_queue: - self.tell_queue[nick] = [] - self.tell_queue[nick].append(res[1:]) + if nick not in self.tell_queue: + self.tell_queue[nick] = [] + self.tell_queue[nick].append(res[1:]) @command def tell(self, mask: IrcString, target: IrcString, args: Dict): @@ -43,12 +44,13 @@ class Tell(DatabasePlugin): self.tell_queue[nick].append(tell[1:]) try: - self.cur.execute(''' - INSERT INTO - tells (to_nick, from_nick, message, created_at) - VALUES - (%s, %s, %s, %s) - ''', tell) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + tells (to_nick, from_nick, message, created_at) + VALUES + (%s, %s, %s, %s) + ''', tell) self.con.commit() self.bot.notice(mask.nick, "I will tell that to {} when I see them.".format(nick)) @@ -73,10 +75,11 @@ class Tell(DatabasePlugin): )) del self.tell_queue[nick] - self.cur.execute(''' - DELETE FROM - tells - WHERE - to_nick = %s - ''', [nick]) + with self.con.cursor() as cur: + cur.execute(''' + DELETE FROM + tells + WHERE + to_nick = %s + ''', [nick]) self.con.commit() diff --git a/bot/timer.py b/bot/timer.py index 35d3bbd..fe1a3c9 100644 --- a/bot/timer.py +++ b/bot/timer.py @@ -34,17 +34,18 @@ class Timer(DatabasePlugin): return 'Invalid timer delay: {}'.format(delay) try: - self.cur.execute(''' - INSERT INTO - timers (mask, target, message, delay, ends_at) - VALUES - (%s, %s, %s, %s, now() + INTERVAL %s) - RETURNING - * - ''', [mask, target, message, delay, delay]) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + timers (mask, target, message, delay, ends_at) + VALUES + (%s, %s, %s, %s, now() + INTERVAL %s) + RETURNING + * + ''', [mask, target, message, delay, delay]) self.con.commit() - asyncio.ensure_future(self.exec_timer(self.cur.fetchone())) + asyncio.ensure_future(self.exec_timer(cur.fetchone())) self.bot.notice(mask.nick, 'Timer in {delay} set: {message}'.format(delay=delay, message=message)) except Error as ex: @@ -54,18 +55,19 @@ class Timer(DatabasePlugin): def set_timers(self): """Function which queries all timers in the next hour and schedules them.""" self.log.debug('Fetching timers') - self.cur.execute(''' - SELECT - * - FROM - timers - WHERE - ends_at >= now() - AND ends_at < now() + INTERVAL '1h' - ''') + with self.con.cursor() as cur: + cur.execute(''' + SELECT + * + FROM + timers + WHERE + ends_at >= now() + AND ends_at < now() + INTERVAL '1h' + ''') - for timer in self.cur.fetchall(): - asyncio.ensure_future(self.exec_timer(timer)) + for timer in cur.fetchall(): + asyncio.ensure_future(self.exec_timer(timer)) async def exec_timer(self, timer: DictRow): """Sets the actual timer (sleeps until it fires), sends the reminder and deletes the timer from database.""" @@ -84,10 +86,11 @@ class Timer(DatabasePlugin): )) self.timers.remove(timer['id']) - self.cur.execute(''' - DELETE FROM - timers - WHERE - id = %s - ''', [timer['id']]) + with self.con.cursor() as cur: + cur.execute(''' + DELETE FROM + timers + WHERE + id = %s + ''', [timer['id']]) self.con.commit() diff --git a/bot/useless.py b/bot/useless.py index 3d6aa04..55941ad 100644 --- a/bot/useless.py +++ b/bot/useless.py @@ -143,19 +143,20 @@ class Useless(DatabasePlugin): %%kill [] """ - self.cur.execute(''' - SELECT - item - FROM - kills - ORDER BY - random() - LIMIT - 1 - ''') - self.bot.action(target, self.cur.fetchone()['item'].format( - nick=args.get('', mask.nick), - )) + with self.con.cursor() as cur: + cur.execute(''' + SELECT + item + FROM + kills + ORDER BY + random() + LIMIT + 1 + ''') + self.bot.action(target, cur.fetchone()['item'].format( + nick=args.get('', mask.nick), + )) @command def yiff(self, mask: IrcString, target: IrcString, args: Dict): @@ -163,20 +164,21 @@ class Useless(DatabasePlugin): %%yiff [] """ - self.cur.execute(''' - SELECT - item - FROM - yiffs - ORDER BY - random() - LIMIT - 1 - ''') - self.bot.action(target, self.cur.fetchone()['item'].format( - nick=args.get('', mask.nick), - yiffer=mask.nick, - )) + with self.con.cursor() as cur: + cur.execute(''' + SELECT + item + FROM + yiffs + ORDER BY + random() + LIMIT + 1 + ''') + self.bot.action(target, cur.fetchone()['item'].format( + nick=args.get('', mask.nick), + yiffer=mask.nick, + )) @command def waifu(self, mask: IrcString, target: IrcString, args: Dict): @@ -477,14 +479,15 @@ class Useless(DatabasePlugin): nick = nick[1:] try: - self.cur.execute(''' - INSERT INTO - users (nick, {0}) - VALUES - (lower(%s), %s) - ON CONFLICT (nick) DO UPDATE SET - {0} = excluded.{0} - '''.format(field), [mask.nick, nick]) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + users (nick, {0}) + VALUES + (lower(%s), %s) + ON CONFLICT (nick) DO UPDATE SET + {0} = excluded.{0} + '''.format(field), [mask.nick, nick]) self.con.commit() self.bot.notice(mask.nick, '{} set to: {}'.format(field.title(), nick)) @@ -492,15 +495,16 @@ class Useless(DatabasePlugin): self.log.error(ex) self.con.rollback() else: - self.cur.execute(''' - SELECT - {} - FROM - users - WHERE - lower(nick) = lower(%s) - '''.format(field), [nick]) - result = self.cur.fetchone() + with self.con.cursor() as cur: + cur.execute(''' + SELECT + {} + FROM + users + WHERE + lower(nick) = lower(%s) + '''.format(field), [nick]) + result = cur.fetchone() if result and result[field]: return '\x02[{}]\x02 {}: {}'.format(field.title(), nick, result[field])