add auto reconnect for postgres

This only works on the second database interaction, since psycopg2 only notices that
the connection is gone, when a query is executed.

So in the common case reconnect works as follows:
- some bot method calls a cursor function like .execute(), .fetchone(), etc.
  - this raises an error if the connection is broken
  - if following code then requests a new cursor, this will also fail since psycopg2
    now knows that the connection is gone
  - the error is caught in storage.DBConn.cursor(), a new connection will be set up
    of which a new cursor is yielded
If the error happens in connection.commit() or .rollback() instead we can instantly
reconnect since these methods are wrapped.

So why not wrap the cursor methods as well?
Consider the following example:
A query is the last thing that was executed on a cursor.
The database connection is lost.
Now .fetchone() is called on the cursor.
We could wrap .fetchone() and reconnect, but we'd have to use a new cursor since
cursors are linked to connections. And on this new cursor .fetchone() wouldn't
make any sense, since we haven't executed a query on this cursor.
This commit is contained in:
jkhsjdhjs 2020-03-16 21:51:32 +00:00
parent 84396cad99
commit e03f5d0a43
10 changed files with 289 additions and 238 deletions

View File

@ -73,4 +73,3 @@ class DatabasePlugin(Plugin):
super().__init__(bot) super().__init__(bot)
self.db = bot.get_plugin(Storage) self.db = bot.get_plugin(Storage)
self.con = self.db.con self.con = self.db.con
self.cur = self.db.cur

View File

@ -29,7 +29,8 @@ class McManiac(DatabasePlugin):
offset = '' offset = ''
# Fetch result from database # Fetch result from database
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
SELECT SELECT
item, item,
rank() OVER (ORDER BY id), rank() OVER (ORDER BY id),
@ -43,7 +44,7 @@ class McManiac(DatabasePlugin):
1 1
{offset} {offset}
'''.format(order=order, offset=offset), [index]) '''.format(order=order, offset=offset), [index])
result = self.cur.fetchone() result = cur.fetchone()
if result: if result:
return '[{rank}/{total}] {item}'.format(**result) return '[{rank}/{total}] {item}'.format(**result)
@ -52,7 +53,8 @@ class McManiac(DatabasePlugin):
@irc3.event(r'^:(?P<mask>\S+) PRIVMSG \S+ :.*(?P<item>Mc\S+iaC).*') @irc3.event(r'^:(?P<mask>\S+) PRIVMSG \S+ :.*(?P<item>Mc\S+iaC).*')
def save(self, mask: str, item: str): def save(self, mask: str, item: str):
if IrcString(mask).nick != self.bot.nick: if IrcString(mask).nick != self.bot.nick:
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
INSERT INTO INSERT INTO
mcmaniacs (item) mcmaniacs (item)
VALUES VALUES

View File

@ -23,7 +23,8 @@ class Quotes(DatabasePlugin):
self.bot.notice(mask.nick, '[Quotes] Error parsing nick') self.bot.notice(mask.nick, '[Quotes] Error parsing nick')
else: else:
# Insert quote into database # Insert quote into database
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
INSERT INTO INSERT INTO
quotes (nick, item, channel, created_by) quotes (nick, item, channel, created_by)
VALUES VALUES
@ -35,7 +36,8 @@ class Quotes(DatabasePlugin):
if index: if index:
# Delete from database # Delete from database
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
-- noinspection SqlResolve -- noinspection SqlResolve
WITH ranked_quotes AS ( WITH ranked_quotes AS (
SELECT SELECT
@ -135,7 +137,8 @@ class Quotes(DatabasePlugin):
offset = '' offset = ''
# Fetch quote from database # Fetch quote from database
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
WITH ranked_quotes AS ( WITH ranked_quotes AS (
SELECT SELECT
nick, nick,
@ -158,7 +161,7 @@ class Quotes(DatabasePlugin):
1 1
{offset} {offset}
'''.format(where=' AND '.join(where), order=order, offset=offset), values) '''.format(where=' AND '.join(where), order=order, offset=offset), values)
result = self.cur.fetchone() result = cur.fetchone()
if result: if result:
return '[{rank}/{total}] <{nick}> {item}'.format(**result) return '[{rank}/{total}] <{nick}> {item}'.format(**result)

View File

@ -18,7 +18,8 @@ class Rape(DatabasePlugin):
nick = args.get('<nick>', mask.nick) nick = args.get('<nick>', mask.nick)
# Fetch result from database # Fetch result from database
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
SELECT SELECT
fines fines
FROM FROM
@ -26,7 +27,7 @@ class Rape(DatabasePlugin):
WHERE WHERE
lower(nick) = lower(%s) lower(nick) = lower(%s)
''', [nick]) ''', [nick])
owes = self.cur.fetchone() owes = cur.fetchone()
# Colorize owe amount and return string # Colorize owe amount and return string
if owes: if owes:
@ -56,7 +57,8 @@ class Rape(DatabasePlugin):
reason = ('raping', 'being too lewd and getting raped')[rand] reason = ('raping', 'being too lewd and getting raped')[rand]
# Insert or add fine to database and return total owe # Insert or add fine to database and return total owe
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
INSERT INTO INSERT INTO
users (nick, fines) users (nick, fines)
VALUES VALUES
@ -73,5 +75,5 @@ class Rape(DatabasePlugin):
nick=fined, nick=fined,
fine=fine, fine=fine,
reason=reason, reason=reason,
total=self.cur.fetchone()['fines'], total=cur.fetchone()['fines'],
)) ))

View File

@ -20,7 +20,8 @@ class Useless(DatabasePlugin):
if nick == self.bot.nick: if nick == self.bot.nick:
return return
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
SELECT SELECT
item item
FROM FROM
@ -29,7 +30,7 @@ class Useless(DatabasePlugin):
nick = lower(%s) nick = lower(%s)
AND channel = lower(%s) AND channel = lower(%s)
''', [nick, target]) ''', [nick, target])
result = self.cur.fetchone() result = cur.fetchone()
if result: if result:
old = result['item'] old = result['item']
@ -44,7 +45,8 @@ class Useless(DatabasePlugin):
"""Saves the last message of a user for each channel (for regex).""" """Saves the last message of a user for each channel (for regex)."""
mask = IrcString(mask) mask = IrcString(mask)
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
INSERT INTO INSERT INTO
last_messages (nick, host, channel, item) last_messages (nick, host, channel, item)
VALUES VALUES

View File

@ -23,7 +23,8 @@ class Seen(DatabasePlugin):
return '{}, look in the mirror faggot!'.format(nick) return '{}, look in the mirror faggot!'.format(nick)
# Fetch seen from database # Fetch seen from database
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
SELECT SELECT
seen_at, message, channel seen_at, message, channel
FROM FROM
@ -31,7 +32,7 @@ class Seen(DatabasePlugin):
WHERE WHERE
nick = lower(%s) nick = lower(%s)
''', [nick]) ''', [nick])
seen = self.cur.fetchone() seen = cur.fetchone()
# No result # No result
if not seen: if not seen:
@ -50,7 +51,8 @@ class Seen(DatabasePlugin):
def save(self, mask: str, target: str, msg: str): def save(self, mask: str, target: str, msg: str):
mask = IrcString(mask) mask = IrcString(mask)
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
INSERT INTO INSERT INTO
seens (nick, host, channel, message) seens (nick, host, channel, message)
VALUES VALUES

View File

@ -1,15 +1,46 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import contextlib
import psycopg2 import psycopg2
from psycopg2.extras import DictCursor from psycopg2.extras import DictCursor
from . import Plugin, Bot from . import Plugin, Bot
class DBConn:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self._reconnect()
def _reconnect(self):
self._con = psycopg2.connect(*self.args, **self.kwargs)
def commit(self, *args, **kwargs):
try:
return self._con.commit(*args, **kwargs)
except psycopg2.InterfaceError:
self._reconnect()
return self._con.commit(*args, **kwargs)
def rollback(self, *args, **kwargs):
try:
return self._con.rollback(*args, **kwargs)
except psycopg2.InterfaceError:
self._reconnect()
return self._con.rollback(*args, **kwargs)
@contextlib.contextmanager
def cursor(self, *args, **kwargs):
try:
yield self._con.cursor(cursor_factory=DictCursor, *args, **kwargs)
except psycopg2.InterfaceError:
self._reconnect()
yield self._con.cursor(cursor_factory=DictCursor, *args, **kwargs)
class Storage(Plugin): class Storage(Plugin):
def __init__(self, bot: Bot): def __init__(self, bot: Bot):
super().__init__(bot) super().__init__(bot)
self.bot.sql = self self.bot.sql = self
self.con = psycopg2.connect(os.environ['DATABASE_URI']) self.con = DBConn(os.environ['DATABASE_URI'])
self.cur = self.con.cursor(cursor_factory=DictCursor)

View File

@ -15,14 +15,15 @@ class Tell(DatabasePlugin):
super().__init__(bot) super().__init__(bot)
self.tell_queue = {} self.tell_queue = {}
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
SELECT SELECT
to_nick, from_nick, message, created_at to_nick, from_nick, message, created_at
FROM FROM
tells tells
''') ''')
for res in self.cur.fetchall(): for res in cur.fetchall():
nick = res['to_nick'].lower() nick = res['to_nick'].lower()
if nick not in self.tell_queue: if nick not in self.tell_queue:
@ -43,7 +44,8 @@ class Tell(DatabasePlugin):
self.tell_queue[nick].append(tell[1:]) self.tell_queue[nick].append(tell[1:])
try: try:
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
INSERT INTO INSERT INTO
tells (to_nick, from_nick, message, created_at) tells (to_nick, from_nick, message, created_at)
VALUES VALUES
@ -73,7 +75,8 @@ class Tell(DatabasePlugin):
)) ))
del self.tell_queue[nick] del self.tell_queue[nick]
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
DELETE FROM DELETE FROM
tells tells
WHERE WHERE

View File

@ -34,7 +34,8 @@ class Timer(DatabasePlugin):
return 'Invalid timer delay: {}'.format(delay) return 'Invalid timer delay: {}'.format(delay)
try: try:
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
INSERT INTO INSERT INTO
timers (mask, target, message, delay, ends_at) timers (mask, target, message, delay, ends_at)
VALUES VALUES
@ -44,7 +45,7 @@ class Timer(DatabasePlugin):
''', [mask, target, message, delay, delay]) ''', [mask, target, message, delay, delay])
self.con.commit() self.con.commit()
asyncio.ensure_future(self.exec_timer(self.cur.fetchone())) asyncio.ensure_future(self.exec_timer(cur.fetchone()))
self.bot.notice(mask.nick, 'Timer in {delay} set: {message}'.format(delay=delay, message=message)) self.bot.notice(mask.nick, 'Timer in {delay} set: {message}'.format(delay=delay, message=message))
except Error as ex: except Error as ex:
@ -54,7 +55,8 @@ class Timer(DatabasePlugin):
def set_timers(self): def set_timers(self):
"""Function which queries all timers in the next hour and schedules them.""" """Function which queries all timers in the next hour and schedules them."""
self.log.debug('Fetching timers') self.log.debug('Fetching timers')
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
SELECT SELECT
* *
FROM FROM
@ -64,7 +66,7 @@ class Timer(DatabasePlugin):
AND ends_at < now() + INTERVAL '1h' AND ends_at < now() + INTERVAL '1h'
''') ''')
for timer in self.cur.fetchall(): for timer in cur.fetchall():
asyncio.ensure_future(self.exec_timer(timer)) asyncio.ensure_future(self.exec_timer(timer))
async def exec_timer(self, timer: DictRow): async def exec_timer(self, timer: DictRow):
@ -84,7 +86,8 @@ class Timer(DatabasePlugin):
)) ))
self.timers.remove(timer['id']) self.timers.remove(timer['id'])
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
DELETE FROM DELETE FROM
timers timers
WHERE WHERE

View File

@ -143,7 +143,8 @@ class Useless(DatabasePlugin):
%%kill [<nick>] %%kill [<nick>]
""" """
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
SELECT SELECT
item item
FROM FROM
@ -153,7 +154,7 @@ class Useless(DatabasePlugin):
LIMIT LIMIT
1 1
''') ''')
self.bot.action(target, self.cur.fetchone()['item'].format( self.bot.action(target, cur.fetchone()['item'].format(
nick=args.get('<nick>', mask.nick), nick=args.get('<nick>', mask.nick),
)) ))
@ -163,7 +164,8 @@ class Useless(DatabasePlugin):
%%yiff [<nick>] %%yiff [<nick>]
""" """
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
SELECT SELECT
item item
FROM FROM
@ -173,7 +175,7 @@ class Useless(DatabasePlugin):
LIMIT LIMIT
1 1
''') ''')
self.bot.action(target, self.cur.fetchone()['item'].format( self.bot.action(target, cur.fetchone()['item'].format(
nick=args.get('<nick>', mask.nick), nick=args.get('<nick>', mask.nick),
yiffer=mask.nick, yiffer=mask.nick,
)) ))
@ -477,7 +479,8 @@ class Useless(DatabasePlugin):
nick = nick[1:] nick = nick[1:]
try: try:
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
INSERT INTO INSERT INTO
users (nick, {0}) users (nick, {0})
VALUES VALUES
@ -492,7 +495,8 @@ class Useless(DatabasePlugin):
self.log.error(ex) self.log.error(ex)
self.con.rollback() self.con.rollback()
else: else:
self.cur.execute(''' with self.con.cursor() as cur:
cur.execute('''
SELECT SELECT
{} {}
FROM FROM
@ -500,7 +504,7 @@ class Useless(DatabasePlugin):
WHERE WHERE
lower(nick) = lower(%s) lower(nick) = lower(%s)
'''.format(field), [nick]) '''.format(field), [nick])
result = self.cur.fetchone() result = cur.fetchone()
if result and result[field]: if result and result[field]:
return '\x02[{}]\x02 {}: {}'.format(field.title(), nick, result[field]) return '\x02[{}]\x02 {}: {}'.format(field.title(), nick, result[field])