diff --git a/bot/__init__.py b/bot/__init__.py index acd17ad..b48eabb 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -73,4 +73,3 @@ class DatabasePlugin(Plugin): super().__init__(bot) self.db = bot.get_plugin(Storage) self.con = self.db.con - self.cur = self.db.cur diff --git a/bot/mcmaniac.py b/bot/mcmaniac.py index b5cbef5..9b203be 100644 --- a/bot/mcmaniac.py +++ b/bot/mcmaniac.py @@ -29,34 +29,36 @@ class McManiac(DatabasePlugin): offset = '' # Fetch result from database - self.cur.execute(''' - SELECT - item, - rank() OVER (ORDER BY id), - count(*) OVER (ROWS BETWEEN UNBOUNDED PRECEDING - AND UNBOUNDED FOLLOWING) AS total - FROM - mcmaniacs - ORDER BY - {order} - LIMIT - 1 - {offset} - '''.format(order=order, offset=offset), [index]) - result = self.cur.fetchone() + with self.con.cursor() as cur: + cur.execute(''' + SELECT + item, + rank() OVER (ORDER BY id), + count(*) OVER (ROWS BETWEEN UNBOUNDED PRECEDING + AND UNBOUNDED FOLLOWING) AS total + FROM + mcmaniacs + ORDER BY + {order} + LIMIT + 1 + {offset} + '''.format(order=order, offset=offset), [index]) + result = cur.fetchone() - if result: - return '[{rank}/{total}] {item}'.format(**result) + if result: + return '[{rank}/{total}] {item}'.format(**result) # TODO: fix regex ("McFooiaC McBariaC" adds "Mc\S+iaC") @irc3.event(r'^:(?P\S+) PRIVMSG \S+ :.*(?PMc\S+iaC).*') def save(self, mask: str, item: str): if IrcString(mask).nick != self.bot.nick: - self.cur.execute(''' - INSERT INTO - mcmaniacs (item) - VALUES - (%s) - ON CONFLICT DO NOTHING - ''', [item]) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + mcmaniacs (item) + VALUES + (%s) + ON CONFLICT DO NOTHING + ''', [item]) self.con.commit() diff --git a/bot/quotes.py b/bot/quotes.py index 948d95b..e5e6627 100644 --- a/bot/quotes.py +++ b/bot/quotes.py @@ -23,43 +23,45 @@ class Quotes(DatabasePlugin): self.bot.notice(mask.nick, '[Quotes] Error parsing nick') else: # Insert quote into database - self.cur.execute(''' - INSERT INTO - quotes (nick, item, channel, created_by) - VALUES - (%s, %s, %s, %s) - ''', [nick, quote, channel, mask.nick]) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + quotes (nick, item, channel, created_by) + VALUES + (%s, %s, %s, %s) + ''', [nick, quote, channel, mask.nick]) def delete_quote(self, nick: str, quote: str): index, order = parse_int(quote, select=False) if index: # Delete from database - self.cur.execute(''' - -- noinspection SqlResolve - WITH ranked_quotes AS ( - SELECT - id, - rank() OVER (PARTITION BY nick ORDER BY id {order}) - FROM - quotes - WHERE - lower(nick) = lower(%s) - ) + with self.con.cursor() as cur: + cur.execute(''' + -- noinspection SqlResolve + WITH ranked_quotes AS ( + SELECT + id, + rank() OVER (PARTITION BY nick ORDER BY id {order}) + FROM + quotes + WHERE + lower(nick) = lower(%s) + ) - -- noinspection SqlResolve - DELETE FROM - quotes - WHERE - id = ( - SELECT - id - FROM - ranked_quotes + -- noinspection SqlResolve + DELETE FROM + quotes WHERE - rank = %s - ) - '''.format(order=order), [nick, index]) + id = ( + SELECT + id + FROM + ranked_quotes + WHERE + rank = %s + ) + '''.format(order=order), [nick, index]) @command(options_first=True, quiet=True) def q(self, mask: IrcString, target: IrcString, args: Dict): @@ -135,30 +137,31 @@ class Quotes(DatabasePlugin): offset = '' # Fetch quote from database - self.cur.execute(''' - WITH ranked_quotes AS ( - SELECT - nick, - item, - rank() OVER (PARTITION BY nick ORDER BY id), - count(*) OVER (PARTITION BY nick) AS total - FROM - quotes - ) + with self.con.cursor() as cur: + cur.execute(''' + WITH ranked_quotes AS ( + SELECT + nick, + item, + rank() OVER (PARTITION BY nick ORDER BY id), + count(*) OVER (PARTITION BY nick) AS total + FROM + quotes + ) - SELECT - * - FROM - ranked_quotes - WHERE - {where} - ORDER BY - {order} - LIMIT - 1 - {offset} - '''.format(where=' AND '.join(where), order=order, offset=offset), values) - result = self.cur.fetchone() + SELECT + * + FROM + ranked_quotes + WHERE + {where} + ORDER BY + {order} + LIMIT + 1 + {offset} + '''.format(where=' AND '.join(where), order=order, offset=offset), values) + result = cur.fetchone() if result: return '[{rank}/{total}] <{nick}> {item}'.format(**result) diff --git a/bot/rape.py b/bot/rape.py index 6fb21d6..e950283 100644 --- a/bot/rape.py +++ b/bot/rape.py @@ -18,15 +18,16 @@ class Rape(DatabasePlugin): nick = args.get('', mask.nick) # Fetch result from database - self.cur.execute(''' - SELECT - fines - FROM - users - WHERE - lower(nick) = lower(%s) - ''', [nick]) - owes = self.cur.fetchone() + with self.con.cursor() as cur: + cur.execute(''' + SELECT + fines + FROM + users + WHERE + lower(nick) = lower(%s) + ''', [nick]) + owes = cur.fetchone() # Colorize owe amount and return string if owes: @@ -56,22 +57,23 @@ class Rape(DatabasePlugin): reason = ('raping', 'being too lewd and getting raped')[rand] # Insert or add fine to database and return total owe - self.cur.execute(''' - INSERT INTO - users (nick, fines) - VALUES - (lower(%s), %s) - ON CONFLICT (nick) DO UPDATE SET - fines = users.fines + excluded.fines - RETURNING - fines - ''', [fined, fine]) - self.con.commit() + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + users (nick, fines) + VALUES + (lower(%s), %s) + ON CONFLICT (nick) DO UPDATE SET + fines = users.fines + excluded.fines + RETURNING + fines + ''', [fined, fine]) + self.con.commit() - # Print fine and total owe - self.bot.action(target, 'fines {nick} \x02${fine}\x02 for {reason}. You owe: \x0304${total}\x03'.format( - nick=fined, - fine=fine, - reason=reason, - total=self.cur.fetchone()['fines'], - )) + # Print fine and total owe + self.bot.action(target, 'fines {nick} \x02${fine}\x02 for {reason}. You owe: \x0304${total}\x03'.format( + nick=fined, + fine=fine, + reason=reason, + total=cur.fetchone()['fines'], + )) diff --git a/bot/regex.py b/bot/regex.py index b9f86ff..740e56c 100644 --- a/bot/regex.py +++ b/bot/regex.py @@ -20,16 +20,17 @@ class Useless(DatabasePlugin): if nick == self.bot.nick: return - self.cur.execute(''' - SELECT - item - FROM - last_messages - WHERE - nick = lower(%s) - AND channel = lower(%s) - ''', [nick, target]) - result = self.cur.fetchone() + with self.con.cursor() as cur: + cur.execute(''' + SELECT + item + FROM + last_messages + WHERE + nick = lower(%s) + AND channel = lower(%s) + ''', [nick, target]) + result = cur.fetchone() if result: old = result['item'] @@ -44,13 +45,14 @@ class Useless(DatabasePlugin): """Saves the last message of a user for each channel (for regex).""" mask = IrcString(mask) - self.cur.execute(''' - INSERT INTO - last_messages (nick, host, channel, item) - VALUES - (lower(%s), %s, lower(%s), %s) - ON CONFLICT (nick, channel) DO UPDATE SET - host = excluded.host, - item = excluded.item - ''', [mask.nick, mask.host, target, msg]) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + last_messages (nick, host, channel, item) + VALUES + (lower(%s), %s, lower(%s), %s) + ON CONFLICT (nick, channel) DO UPDATE SET + host = excluded.host, + item = excluded.item + ''', [mask.nick, mask.host, target, msg]) self.con.commit() diff --git a/bot/seen.py b/bot/seen.py index 52e6feb..b70a876 100644 --- a/bot/seen.py +++ b/bot/seen.py @@ -23,15 +23,16 @@ class Seen(DatabasePlugin): return '{}, look in the mirror faggot!'.format(nick) # Fetch seen from database - self.cur.execute(''' - SELECT - seen_at, message, channel - FROM - seens - WHERE - nick = lower(%s) - ''', [nick]) - seen = self.cur.fetchone() + with self.con.cursor() as cur: + cur.execute(''' + SELECT + seen_at, message, channel + FROM + seens + WHERE + nick = lower(%s) + ''', [nick]) + seen = cur.fetchone() # No result if not seen: @@ -50,15 +51,16 @@ class Seen(DatabasePlugin): def save(self, mask: str, target: str, msg: str): mask = IrcString(mask) - self.cur.execute(''' - INSERT INTO - seens (nick, host, channel, message) - VALUES - (lower(%s), %s, %s, %s) - ON CONFLICT (nick) DO UPDATE SET - host = excluded.host, - channel = excluded.channel, - seen_at = now(), - message = excluded.message - ''', [mask.nick, mask.host, target, msg]) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + seens (nick, host, channel, message) + VALUES + (lower(%s), %s, %s, %s) + ON CONFLICT (nick) DO UPDATE SET + host = excluded.host, + channel = excluded.channel, + seen_at = now(), + message = excluded.message + ''', [mask.nick, mask.host, target, msg]) self.con.commit() diff --git a/bot/storage.py b/bot/storage.py index 9c12eeb..00c22dc 100644 --- a/bot/storage.py +++ b/bot/storage.py @@ -1,15 +1,46 @@ # -*- coding: utf-8 -*- import os +import contextlib import psycopg2 from psycopg2.extras import DictCursor from . import Plugin, Bot +class DBConn: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self._reconnect() + + def _reconnect(self): + self._con = psycopg2.connect(*self.args, **self.kwargs) + + def commit(self, *args, **kwargs): + try: + return self._con.commit(*args, **kwargs) + except psycopg2.InterfaceError: + self._reconnect() + return self._con.commit(*args, **kwargs) + + def rollback(self, *args, **kwargs): + try: + return self._con.rollback(*args, **kwargs) + except psycopg2.InterfaceError: + self._reconnect() + return self._con.rollback(*args, **kwargs) + + @contextlib.contextmanager + def cursor(self, *args, **kwargs): + try: + yield self._con.cursor(cursor_factory=DictCursor, *args, **kwargs) + except psycopg2.InterfaceError: + self._reconnect() + yield self._con.cursor(cursor_factory=DictCursor, *args, **kwargs) + class Storage(Plugin): def __init__(self, bot: Bot): super().__init__(bot) self.bot.sql = self - self.con = psycopg2.connect(os.environ['DATABASE_URI']) - self.cur = self.con.cursor(cursor_factory=DictCursor) + self.con = DBConn(os.environ['DATABASE_URI']) diff --git a/bot/tell.py b/bot/tell.py index b250bf6..31d777a 100644 --- a/bot/tell.py +++ b/bot/tell.py @@ -15,19 +15,20 @@ class Tell(DatabasePlugin): super().__init__(bot) self.tell_queue = {} - self.cur.execute(''' - SELECT - to_nick, from_nick, message, created_at - FROM - tells - ''') + with self.con.cursor() as cur: + cur.execute(''' + SELECT + to_nick, from_nick, message, created_at + FROM + tells + ''') - for res in self.cur.fetchall(): - nick = res['to_nick'].lower() + for res in cur.fetchall(): + nick = res['to_nick'].lower() - if nick not in self.tell_queue: - self.tell_queue[nick] = [] - self.tell_queue[nick].append(res[1:]) + if nick not in self.tell_queue: + self.tell_queue[nick] = [] + self.tell_queue[nick].append(res[1:]) @command def tell(self, mask: IrcString, target: IrcString, args: Dict): @@ -43,12 +44,13 @@ class Tell(DatabasePlugin): self.tell_queue[nick].append(tell[1:]) try: - self.cur.execute(''' - INSERT INTO - tells (to_nick, from_nick, message, created_at) - VALUES - (%s, %s, %s, %s) - ''', tell) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + tells (to_nick, from_nick, message, created_at) + VALUES + (%s, %s, %s, %s) + ''', tell) self.con.commit() self.bot.notice(mask.nick, "I will tell that to {} when I see them.".format(nick)) @@ -73,10 +75,11 @@ class Tell(DatabasePlugin): )) del self.tell_queue[nick] - self.cur.execute(''' - DELETE FROM - tells - WHERE - to_nick = %s - ''', [nick]) + with self.con.cursor() as cur: + cur.execute(''' + DELETE FROM + tells + WHERE + to_nick = %s + ''', [nick]) self.con.commit() diff --git a/bot/timer.py b/bot/timer.py index 35d3bbd..fe1a3c9 100644 --- a/bot/timer.py +++ b/bot/timer.py @@ -34,17 +34,18 @@ class Timer(DatabasePlugin): return 'Invalid timer delay: {}'.format(delay) try: - self.cur.execute(''' - INSERT INTO - timers (mask, target, message, delay, ends_at) - VALUES - (%s, %s, %s, %s, now() + INTERVAL %s) - RETURNING - * - ''', [mask, target, message, delay, delay]) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + timers (mask, target, message, delay, ends_at) + VALUES + (%s, %s, %s, %s, now() + INTERVAL %s) + RETURNING + * + ''', [mask, target, message, delay, delay]) self.con.commit() - asyncio.ensure_future(self.exec_timer(self.cur.fetchone())) + asyncio.ensure_future(self.exec_timer(cur.fetchone())) self.bot.notice(mask.nick, 'Timer in {delay} set: {message}'.format(delay=delay, message=message)) except Error as ex: @@ -54,18 +55,19 @@ class Timer(DatabasePlugin): def set_timers(self): """Function which queries all timers in the next hour and schedules them.""" self.log.debug('Fetching timers') - self.cur.execute(''' - SELECT - * - FROM - timers - WHERE - ends_at >= now() - AND ends_at < now() + INTERVAL '1h' - ''') + with self.con.cursor() as cur: + cur.execute(''' + SELECT + * + FROM + timers + WHERE + ends_at >= now() + AND ends_at < now() + INTERVAL '1h' + ''') - for timer in self.cur.fetchall(): - asyncio.ensure_future(self.exec_timer(timer)) + for timer in cur.fetchall(): + asyncio.ensure_future(self.exec_timer(timer)) async def exec_timer(self, timer: DictRow): """Sets the actual timer (sleeps until it fires), sends the reminder and deletes the timer from database.""" @@ -84,10 +86,11 @@ class Timer(DatabasePlugin): )) self.timers.remove(timer['id']) - self.cur.execute(''' - DELETE FROM - timers - WHERE - id = %s - ''', [timer['id']]) + with self.con.cursor() as cur: + cur.execute(''' + DELETE FROM + timers + WHERE + id = %s + ''', [timer['id']]) self.con.commit() diff --git a/bot/useless.py b/bot/useless.py index 3d6aa04..55941ad 100644 --- a/bot/useless.py +++ b/bot/useless.py @@ -143,19 +143,20 @@ class Useless(DatabasePlugin): %%kill [] """ - self.cur.execute(''' - SELECT - item - FROM - kills - ORDER BY - random() - LIMIT - 1 - ''') - self.bot.action(target, self.cur.fetchone()['item'].format( - nick=args.get('', mask.nick), - )) + with self.con.cursor() as cur: + cur.execute(''' + SELECT + item + FROM + kills + ORDER BY + random() + LIMIT + 1 + ''') + self.bot.action(target, cur.fetchone()['item'].format( + nick=args.get('', mask.nick), + )) @command def yiff(self, mask: IrcString, target: IrcString, args: Dict): @@ -163,20 +164,21 @@ class Useless(DatabasePlugin): %%yiff [] """ - self.cur.execute(''' - SELECT - item - FROM - yiffs - ORDER BY - random() - LIMIT - 1 - ''') - self.bot.action(target, self.cur.fetchone()['item'].format( - nick=args.get('', mask.nick), - yiffer=mask.nick, - )) + with self.con.cursor() as cur: + cur.execute(''' + SELECT + item + FROM + yiffs + ORDER BY + random() + LIMIT + 1 + ''') + self.bot.action(target, cur.fetchone()['item'].format( + nick=args.get('', mask.nick), + yiffer=mask.nick, + )) @command def waifu(self, mask: IrcString, target: IrcString, args: Dict): @@ -477,14 +479,15 @@ class Useless(DatabasePlugin): nick = nick[1:] try: - self.cur.execute(''' - INSERT INTO - users (nick, {0}) - VALUES - (lower(%s), %s) - ON CONFLICT (nick) DO UPDATE SET - {0} = excluded.{0} - '''.format(field), [mask.nick, nick]) + with self.con.cursor() as cur: + cur.execute(''' + INSERT INTO + users (nick, {0}) + VALUES + (lower(%s), %s) + ON CONFLICT (nick) DO UPDATE SET + {0} = excluded.{0} + '''.format(field), [mask.nick, nick]) self.con.commit() self.bot.notice(mask.nick, '{} set to: {}'.format(field.title(), nick)) @@ -492,15 +495,16 @@ class Useless(DatabasePlugin): self.log.error(ex) self.con.rollback() else: - self.cur.execute(''' - SELECT - {} - FROM - users - WHERE - lower(nick) = lower(%s) - '''.format(field), [nick]) - result = self.cur.fetchone() + with self.con.cursor() as cur: + cur.execute(''' + SELECT + {} + FROM + users + WHERE + lower(nick) = lower(%s) + '''.format(field), [nick]) + result = cur.fetchone() if result and result[field]: return '\x02[{}]\x02 {}: {}'.format(field.title(), nick, result[field])