add auto reconnect for postgres

This only works on the second database interaction, since psycopg2 only notices that
the connection is gone, when a query is executed.

So in the common case reconnect works as follows:
- some bot method calls a cursor function like .execute(), .fetchone(), etc.
  - this raises an error if the connection is broken
  - if following code then requests a new cursor, this will also fail since psycopg2
    now knows that the connection is gone
  - the error is caught in storage.DBConn.cursor(), a new connection will be set up
    of which a new cursor is yielded
If the error happens in connection.commit() or .rollback() instead we can instantly
reconnect since these methods are wrapped.

So why not wrap the cursor methods as well?
Consider the following example:
A query is the last thing that was executed on a cursor.
The database connection is lost.
Now .fetchone() is called on the cursor.
We could wrap .fetchone() and reconnect, but we'd have to use a new cursor since
cursors are linked to connections. And on this new cursor .fetchone() wouldn't
make any sense, since we haven't executed a query on this cursor.
This commit is contained in:
jkhsjdhjs 2020-03-16 21:51:32 +00:00
parent 84396cad99
commit e03f5d0a43
10 changed files with 289 additions and 238 deletions

View File

@ -73,4 +73,3 @@ class DatabasePlugin(Plugin):
super().__init__(bot)
self.db = bot.get_plugin(Storage)
self.con = self.db.con
self.cur = self.db.cur

View File

@ -29,7 +29,8 @@ class McManiac(DatabasePlugin):
offset = ''
# Fetch result from database
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
SELECT
item,
rank() OVER (ORDER BY id),
@ -43,7 +44,7 @@ class McManiac(DatabasePlugin):
1
{offset}
'''.format(order=order, offset=offset), [index])
result = self.cur.fetchone()
result = cur.fetchone()
if result:
return '[{rank}/{total}] {item}'.format(**result)
@ -52,7 +53,8 @@ class McManiac(DatabasePlugin):
@irc3.event(r'^:(?P<mask>\S+) PRIVMSG \S+ :.*(?P<item>Mc\S+iaC).*')
def save(self, mask: str, item: str):
if IrcString(mask).nick != self.bot.nick:
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
mcmaniacs (item)
VALUES

View File

@ -23,7 +23,8 @@ class Quotes(DatabasePlugin):
self.bot.notice(mask.nick, '[Quotes] Error parsing nick')
else:
# Insert quote into database
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
quotes (nick, item, channel, created_by)
VALUES
@ -35,7 +36,8 @@ class Quotes(DatabasePlugin):
if index:
# Delete from database
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
-- noinspection SqlResolve
WITH ranked_quotes AS (
SELECT
@ -135,7 +137,8 @@ class Quotes(DatabasePlugin):
offset = ''
# Fetch quote from database
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
WITH ranked_quotes AS (
SELECT
nick,
@ -158,7 +161,7 @@ class Quotes(DatabasePlugin):
1
{offset}
'''.format(where=' AND '.join(where), order=order, offset=offset), values)
result = self.cur.fetchone()
result = cur.fetchone()
if result:
return '[{rank}/{total}] <{nick}> {item}'.format(**result)

View File

@ -18,7 +18,8 @@ class Rape(DatabasePlugin):
nick = args.get('<nick>', mask.nick)
# Fetch result from database
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
SELECT
fines
FROM
@ -26,7 +27,7 @@ class Rape(DatabasePlugin):
WHERE
lower(nick) = lower(%s)
''', [nick])
owes = self.cur.fetchone()
owes = cur.fetchone()
# Colorize owe amount and return string
if owes:
@ -56,7 +57,8 @@ class Rape(DatabasePlugin):
reason = ('raping', 'being too lewd and getting raped')[rand]
# Insert or add fine to database and return total owe
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
users (nick, fines)
VALUES
@ -73,5 +75,5 @@ class Rape(DatabasePlugin):
nick=fined,
fine=fine,
reason=reason,
total=self.cur.fetchone()['fines'],
total=cur.fetchone()['fines'],
))

View File

@ -20,7 +20,8 @@ class Useless(DatabasePlugin):
if nick == self.bot.nick:
return
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
SELECT
item
FROM
@ -29,7 +30,7 @@ class Useless(DatabasePlugin):
nick = lower(%s)
AND channel = lower(%s)
''', [nick, target])
result = self.cur.fetchone()
result = cur.fetchone()
if result:
old = result['item']
@ -44,7 +45,8 @@ class Useless(DatabasePlugin):
"""Saves the last message of a user for each channel (for regex)."""
mask = IrcString(mask)
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
last_messages (nick, host, channel, item)
VALUES

View File

@ -23,7 +23,8 @@ class Seen(DatabasePlugin):
return '{}, look in the mirror faggot!'.format(nick)
# Fetch seen from database
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
SELECT
seen_at, message, channel
FROM
@ -31,7 +32,7 @@ class Seen(DatabasePlugin):
WHERE
nick = lower(%s)
''', [nick])
seen = self.cur.fetchone()
seen = cur.fetchone()
# No result
if not seen:
@ -50,7 +51,8 @@ class Seen(DatabasePlugin):
def save(self, mask: str, target: str, msg: str):
mask = IrcString(mask)
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
seens (nick, host, channel, message)
VALUES

View File

@ -1,15 +1,46 @@
# -*- coding: utf-8 -*-
import os
import contextlib
import psycopg2
from psycopg2.extras import DictCursor
from . import Plugin, Bot
class DBConn:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self._reconnect()
def _reconnect(self):
self._con = psycopg2.connect(*self.args, **self.kwargs)
def commit(self, *args, **kwargs):
try:
return self._con.commit(*args, **kwargs)
except psycopg2.InterfaceError:
self._reconnect()
return self._con.commit(*args, **kwargs)
def rollback(self, *args, **kwargs):
try:
return self._con.rollback(*args, **kwargs)
except psycopg2.InterfaceError:
self._reconnect()
return self._con.rollback(*args, **kwargs)
@contextlib.contextmanager
def cursor(self, *args, **kwargs):
try:
yield self._con.cursor(cursor_factory=DictCursor, *args, **kwargs)
except psycopg2.InterfaceError:
self._reconnect()
yield self._con.cursor(cursor_factory=DictCursor, *args, **kwargs)
class Storage(Plugin):
def __init__(self, bot: Bot):
super().__init__(bot)
self.bot.sql = self
self.con = psycopg2.connect(os.environ['DATABASE_URI'])
self.cur = self.con.cursor(cursor_factory=DictCursor)
self.con = DBConn(os.environ['DATABASE_URI'])

View File

@ -15,14 +15,15 @@ class Tell(DatabasePlugin):
super().__init__(bot)
self.tell_queue = {}
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
SELECT
to_nick, from_nick, message, created_at
FROM
tells
''')
for res in self.cur.fetchall():
for res in cur.fetchall():
nick = res['to_nick'].lower()
if nick not in self.tell_queue:
@ -43,7 +44,8 @@ class Tell(DatabasePlugin):
self.tell_queue[nick].append(tell[1:])
try:
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
tells (to_nick, from_nick, message, created_at)
VALUES
@ -73,7 +75,8 @@ class Tell(DatabasePlugin):
))
del self.tell_queue[nick]
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
DELETE FROM
tells
WHERE

View File

@ -34,7 +34,8 @@ class Timer(DatabasePlugin):
return 'Invalid timer delay: {}'.format(delay)
try:
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
timers (mask, target, message, delay, ends_at)
VALUES
@ -44,7 +45,7 @@ class Timer(DatabasePlugin):
''', [mask, target, message, delay, delay])
self.con.commit()
asyncio.ensure_future(self.exec_timer(self.cur.fetchone()))
asyncio.ensure_future(self.exec_timer(cur.fetchone()))
self.bot.notice(mask.nick, 'Timer in {delay} set: {message}'.format(delay=delay, message=message))
except Error as ex:
@ -54,7 +55,8 @@ class Timer(DatabasePlugin):
def set_timers(self):
"""Function which queries all timers in the next hour and schedules them."""
self.log.debug('Fetching timers')
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
SELECT
*
FROM
@ -64,7 +66,7 @@ class Timer(DatabasePlugin):
AND ends_at < now() + INTERVAL '1h'
''')
for timer in self.cur.fetchall():
for timer in cur.fetchall():
asyncio.ensure_future(self.exec_timer(timer))
async def exec_timer(self, timer: DictRow):
@ -84,7 +86,8 @@ class Timer(DatabasePlugin):
))
self.timers.remove(timer['id'])
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
DELETE FROM
timers
WHERE

View File

@ -143,7 +143,8 @@ class Useless(DatabasePlugin):
%%kill [<nick>]
"""
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
SELECT
item
FROM
@ -153,7 +154,7 @@ class Useless(DatabasePlugin):
LIMIT
1
''')
self.bot.action(target, self.cur.fetchone()['item'].format(
self.bot.action(target, cur.fetchone()['item'].format(
nick=args.get('<nick>', mask.nick),
))
@ -163,7 +164,8 @@ class Useless(DatabasePlugin):
%%yiff [<nick>]
"""
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
SELECT
item
FROM
@ -173,7 +175,7 @@ class Useless(DatabasePlugin):
LIMIT
1
''')
self.bot.action(target, self.cur.fetchone()['item'].format(
self.bot.action(target, cur.fetchone()['item'].format(
nick=args.get('<nick>', mask.nick),
yiffer=mask.nick,
))
@ -477,7 +479,8 @@ class Useless(DatabasePlugin):
nick = nick[1:]
try:
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
users (nick, {0})
VALUES
@ -492,7 +495,8 @@ class Useless(DatabasePlugin):
self.log.error(ex)
self.con.rollback()
else:
self.cur.execute('''
with self.con.cursor() as cur:
cur.execute('''
SELECT
{}
FROM
@ -500,7 +504,7 @@ class Useless(DatabasePlugin):
WHERE
lower(nick) = lower(%s)
'''.format(field), [nick])
result = self.cur.fetchone()
result = cur.fetchone()
if result and result[field]:
return '\x02[{}]\x02 {}: {}'.format(field.title(), nick, result[field])