add auto reconnect for postgres

This only works on the second database interaction, since psycopg2 only notices that
the connection is gone, when a query is executed.

So in the common case reconnect works as follows:
- some bot method calls a cursor function like .execute(), .fetchone(), etc.
  - this raises an error if the connection is broken
  - if following code then requests a new cursor, this will also fail since psycopg2
    now knows that the connection is gone
  - the error is caught in storage.DBConn.cursor(), a new connection will be set up
    of which a new cursor is yielded
If the error happens in connection.commit() or .rollback() instead we can instantly
reconnect since these methods are wrapped.

So why not wrap the cursor methods as well?
Consider the following example:
A query is the last thing that was executed on a cursor.
The database connection is lost.
Now .fetchone() is called on the cursor.
We could wrap .fetchone() and reconnect, but we'd have to use a new cursor since
cursors are linked to connections. And on this new cursor .fetchone() wouldn't
make any sense, since we haven't executed a query on this cursor.
This commit is contained in:
jkhsjdhjs 2020-03-16 21:51:32 +00:00
parent 84396cad99
commit e03f5d0a43
10 changed files with 289 additions and 238 deletions

View File

@ -73,4 +73,3 @@ class DatabasePlugin(Plugin):
super().__init__(bot) super().__init__(bot)
self.db = bot.get_plugin(Storage) self.db = bot.get_plugin(Storage)
self.con = self.db.con self.con = self.db.con
self.cur = self.db.cur

View File

@ -29,34 +29,36 @@ class McManiac(DatabasePlugin):
offset = '' offset = ''
# Fetch result from database # Fetch result from database
self.cur.execute(''' with self.con.cursor() as cur:
SELECT cur.execute('''
item, SELECT
rank() OVER (ORDER BY id), item,
count(*) OVER (ROWS BETWEEN UNBOUNDED PRECEDING rank() OVER (ORDER BY id),
AND UNBOUNDED FOLLOWING) AS total count(*) OVER (ROWS BETWEEN UNBOUNDED PRECEDING
FROM AND UNBOUNDED FOLLOWING) AS total
mcmaniacs FROM
ORDER BY mcmaniacs
{order} ORDER BY
LIMIT {order}
1 LIMIT
{offset} 1
'''.format(order=order, offset=offset), [index]) {offset}
result = self.cur.fetchone() '''.format(order=order, offset=offset), [index])
result = cur.fetchone()
if result: if result:
return '[{rank}/{total}] {item}'.format(**result) return '[{rank}/{total}] {item}'.format(**result)
# TODO: fix regex ("McFooiaC McBariaC" adds "Mc\S+iaC") # TODO: fix regex ("McFooiaC McBariaC" adds "Mc\S+iaC")
@irc3.event(r'^:(?P<mask>\S+) PRIVMSG \S+ :.*(?P<item>Mc\S+iaC).*') @irc3.event(r'^:(?P<mask>\S+) PRIVMSG \S+ :.*(?P<item>Mc\S+iaC).*')
def save(self, mask: str, item: str): def save(self, mask: str, item: str):
if IrcString(mask).nick != self.bot.nick: if IrcString(mask).nick != self.bot.nick:
self.cur.execute(''' with self.con.cursor() as cur:
INSERT INTO cur.execute('''
mcmaniacs (item) INSERT INTO
VALUES mcmaniacs (item)
(%s) VALUES
ON CONFLICT DO NOTHING (%s)
''', [item]) ON CONFLICT DO NOTHING
''', [item])
self.con.commit() self.con.commit()

View File

@ -23,43 +23,45 @@ class Quotes(DatabasePlugin):
self.bot.notice(mask.nick, '[Quotes] Error parsing nick') self.bot.notice(mask.nick, '[Quotes] Error parsing nick')
else: else:
# Insert quote into database # Insert quote into database
self.cur.execute(''' with self.con.cursor() as cur:
INSERT INTO cur.execute('''
quotes (nick, item, channel, created_by) INSERT INTO
VALUES quotes (nick, item, channel, created_by)
(%s, %s, %s, %s) VALUES
''', [nick, quote, channel, mask.nick]) (%s, %s, %s, %s)
''', [nick, quote, channel, mask.nick])
def delete_quote(self, nick: str, quote: str): def delete_quote(self, nick: str, quote: str):
index, order = parse_int(quote, select=False) index, order = parse_int(quote, select=False)
if index: if index:
# Delete from database # Delete from database
self.cur.execute(''' with self.con.cursor() as cur:
-- noinspection SqlResolve cur.execute('''
WITH ranked_quotes AS ( -- noinspection SqlResolve
SELECT WITH ranked_quotes AS (
id, SELECT
rank() OVER (PARTITION BY nick ORDER BY id {order}) id,
FROM rank() OVER (PARTITION BY nick ORDER BY id {order})
quotes FROM
WHERE quotes
lower(nick) = lower(%s) WHERE
) lower(nick) = lower(%s)
)
-- noinspection SqlResolve -- noinspection SqlResolve
DELETE FROM DELETE FROM
quotes quotes
WHERE
id = (
SELECT
id
FROM
ranked_quotes
WHERE WHERE
rank = %s id = (
) SELECT
'''.format(order=order), [nick, index]) id
FROM
ranked_quotes
WHERE
rank = %s
)
'''.format(order=order), [nick, index])
@command(options_first=True, quiet=True) @command(options_first=True, quiet=True)
def q(self, mask: IrcString, target: IrcString, args: Dict): def q(self, mask: IrcString, target: IrcString, args: Dict):
@ -135,30 +137,31 @@ class Quotes(DatabasePlugin):
offset = '' offset = ''
# Fetch quote from database # Fetch quote from database
self.cur.execute(''' with self.con.cursor() as cur:
WITH ranked_quotes AS ( cur.execute('''
SELECT WITH ranked_quotes AS (
nick, SELECT
item, nick,
rank() OVER (PARTITION BY nick ORDER BY id), item,
count(*) OVER (PARTITION BY nick) AS total rank() OVER (PARTITION BY nick ORDER BY id),
FROM count(*) OVER (PARTITION BY nick) AS total
quotes FROM
) quotes
)
SELECT SELECT
* *
FROM FROM
ranked_quotes ranked_quotes
WHERE WHERE
{where} {where}
ORDER BY ORDER BY
{order} {order}
LIMIT LIMIT
1 1
{offset} {offset}
'''.format(where=' AND '.join(where), order=order, offset=offset), values) '''.format(where=' AND '.join(where), order=order, offset=offset), values)
result = self.cur.fetchone() result = cur.fetchone()
if result: if result:
return '[{rank}/{total}] <{nick}> {item}'.format(**result) return '[{rank}/{total}] <{nick}> {item}'.format(**result)

View File

@ -18,15 +18,16 @@ class Rape(DatabasePlugin):
nick = args.get('<nick>', mask.nick) nick = args.get('<nick>', mask.nick)
# Fetch result from database # Fetch result from database
self.cur.execute(''' with self.con.cursor() as cur:
SELECT cur.execute('''
fines SELECT
FROM fines
users FROM
WHERE users
lower(nick) = lower(%s) WHERE
''', [nick]) lower(nick) = lower(%s)
owes = self.cur.fetchone() ''', [nick])
owes = cur.fetchone()
# Colorize owe amount and return string # Colorize owe amount and return string
if owes: if owes:
@ -56,22 +57,23 @@ class Rape(DatabasePlugin):
reason = ('raping', 'being too lewd and getting raped')[rand] reason = ('raping', 'being too lewd and getting raped')[rand]
# Insert or add fine to database and return total owe # Insert or add fine to database and return total owe
self.cur.execute(''' with self.con.cursor() as cur:
INSERT INTO cur.execute('''
users (nick, fines) INSERT INTO
VALUES users (nick, fines)
(lower(%s), %s) VALUES
ON CONFLICT (nick) DO UPDATE SET (lower(%s), %s)
fines = users.fines + excluded.fines ON CONFLICT (nick) DO UPDATE SET
RETURNING fines = users.fines + excluded.fines
fines RETURNING
''', [fined, fine]) fines
self.con.commit() ''', [fined, fine])
self.con.commit()
# Print fine and total owe # Print fine and total owe
self.bot.action(target, 'fines {nick} \x02${fine}\x02 for {reason}. You owe: \x0304${total}\x03'.format( self.bot.action(target, 'fines {nick} \x02${fine}\x02 for {reason}. You owe: \x0304${total}\x03'.format(
nick=fined, nick=fined,
fine=fine, fine=fine,
reason=reason, reason=reason,
total=self.cur.fetchone()['fines'], total=cur.fetchone()['fines'],
)) ))

View File

@ -20,16 +20,17 @@ class Useless(DatabasePlugin):
if nick == self.bot.nick: if nick == self.bot.nick:
return return
self.cur.execute(''' with self.con.cursor() as cur:
SELECT cur.execute('''
item SELECT
FROM item
last_messages FROM
WHERE last_messages
nick = lower(%s) WHERE
AND channel = lower(%s) nick = lower(%s)
''', [nick, target]) AND channel = lower(%s)
result = self.cur.fetchone() ''', [nick, target])
result = cur.fetchone()
if result: if result:
old = result['item'] old = result['item']
@ -44,13 +45,14 @@ class Useless(DatabasePlugin):
"""Saves the last message of a user for each channel (for regex).""" """Saves the last message of a user for each channel (for regex)."""
mask = IrcString(mask) mask = IrcString(mask)
self.cur.execute(''' with self.con.cursor() as cur:
INSERT INTO cur.execute('''
last_messages (nick, host, channel, item) INSERT INTO
VALUES last_messages (nick, host, channel, item)
(lower(%s), %s, lower(%s), %s) VALUES
ON CONFLICT (nick, channel) DO UPDATE SET (lower(%s), %s, lower(%s), %s)
host = excluded.host, ON CONFLICT (nick, channel) DO UPDATE SET
item = excluded.item host = excluded.host,
''', [mask.nick, mask.host, target, msg]) item = excluded.item
''', [mask.nick, mask.host, target, msg])
self.con.commit() self.con.commit()

View File

@ -23,15 +23,16 @@ class Seen(DatabasePlugin):
return '{}, look in the mirror faggot!'.format(nick) return '{}, look in the mirror faggot!'.format(nick)
# Fetch seen from database # Fetch seen from database
self.cur.execute(''' with self.con.cursor() as cur:
SELECT cur.execute('''
seen_at, message, channel SELECT
FROM seen_at, message, channel
seens FROM
WHERE seens
nick = lower(%s) WHERE
''', [nick]) nick = lower(%s)
seen = self.cur.fetchone() ''', [nick])
seen = cur.fetchone()
# No result # No result
if not seen: if not seen:
@ -50,15 +51,16 @@ class Seen(DatabasePlugin):
def save(self, mask: str, target: str, msg: str): def save(self, mask: str, target: str, msg: str):
mask = IrcString(mask) mask = IrcString(mask)
self.cur.execute(''' with self.con.cursor() as cur:
INSERT INTO cur.execute('''
seens (nick, host, channel, message) INSERT INTO
VALUES seens (nick, host, channel, message)
(lower(%s), %s, %s, %s) VALUES
ON CONFLICT (nick) DO UPDATE SET (lower(%s), %s, %s, %s)
host = excluded.host, ON CONFLICT (nick) DO UPDATE SET
channel = excluded.channel, host = excluded.host,
seen_at = now(), channel = excluded.channel,
message = excluded.message seen_at = now(),
''', [mask.nick, mask.host, target, msg]) message = excluded.message
''', [mask.nick, mask.host, target, msg])
self.con.commit() self.con.commit()

View File

@ -1,15 +1,46 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import contextlib
import psycopg2 import psycopg2
from psycopg2.extras import DictCursor from psycopg2.extras import DictCursor
from . import Plugin, Bot from . import Plugin, Bot
class DBConn:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self._reconnect()
def _reconnect(self):
self._con = psycopg2.connect(*self.args, **self.kwargs)
def commit(self, *args, **kwargs):
try:
return self._con.commit(*args, **kwargs)
except psycopg2.InterfaceError:
self._reconnect()
return self._con.commit(*args, **kwargs)
def rollback(self, *args, **kwargs):
try:
return self._con.rollback(*args, **kwargs)
except psycopg2.InterfaceError:
self._reconnect()
return self._con.rollback(*args, **kwargs)
@contextlib.contextmanager
def cursor(self, *args, **kwargs):
try:
yield self._con.cursor(cursor_factory=DictCursor, *args, **kwargs)
except psycopg2.InterfaceError:
self._reconnect()
yield self._con.cursor(cursor_factory=DictCursor, *args, **kwargs)
class Storage(Plugin): class Storage(Plugin):
def __init__(self, bot: Bot): def __init__(self, bot: Bot):
super().__init__(bot) super().__init__(bot)
self.bot.sql = self self.bot.sql = self
self.con = psycopg2.connect(os.environ['DATABASE_URI']) self.con = DBConn(os.environ['DATABASE_URI'])
self.cur = self.con.cursor(cursor_factory=DictCursor)

View File

@ -15,19 +15,20 @@ class Tell(DatabasePlugin):
super().__init__(bot) super().__init__(bot)
self.tell_queue = {} self.tell_queue = {}
self.cur.execute(''' with self.con.cursor() as cur:
SELECT cur.execute('''
to_nick, from_nick, message, created_at SELECT
FROM to_nick, from_nick, message, created_at
tells FROM
''') tells
''')
for res in self.cur.fetchall(): for res in cur.fetchall():
nick = res['to_nick'].lower() nick = res['to_nick'].lower()
if nick not in self.tell_queue: if nick not in self.tell_queue:
self.tell_queue[nick] = [] self.tell_queue[nick] = []
self.tell_queue[nick].append(res[1:]) self.tell_queue[nick].append(res[1:])
@command @command
def tell(self, mask: IrcString, target: IrcString, args: Dict): def tell(self, mask: IrcString, target: IrcString, args: Dict):
@ -43,12 +44,13 @@ class Tell(DatabasePlugin):
self.tell_queue[nick].append(tell[1:]) self.tell_queue[nick].append(tell[1:])
try: try:
self.cur.execute(''' with self.con.cursor() as cur:
INSERT INTO cur.execute('''
tells (to_nick, from_nick, message, created_at) INSERT INTO
VALUES tells (to_nick, from_nick, message, created_at)
(%s, %s, %s, %s) VALUES
''', tell) (%s, %s, %s, %s)
''', tell)
self.con.commit() self.con.commit()
self.bot.notice(mask.nick, "I will tell that to {} when I see them.".format(nick)) self.bot.notice(mask.nick, "I will tell that to {} when I see them.".format(nick))
@ -73,10 +75,11 @@ class Tell(DatabasePlugin):
)) ))
del self.tell_queue[nick] del self.tell_queue[nick]
self.cur.execute(''' with self.con.cursor() as cur:
DELETE FROM cur.execute('''
tells DELETE FROM
WHERE tells
to_nick = %s WHERE
''', [nick]) to_nick = %s
''', [nick])
self.con.commit() self.con.commit()

View File

@ -34,17 +34,18 @@ class Timer(DatabasePlugin):
return 'Invalid timer delay: {}'.format(delay) return 'Invalid timer delay: {}'.format(delay)
try: try:
self.cur.execute(''' with self.con.cursor() as cur:
INSERT INTO cur.execute('''
timers (mask, target, message, delay, ends_at) INSERT INTO
VALUES timers (mask, target, message, delay, ends_at)
(%s, %s, %s, %s, now() + INTERVAL %s) VALUES
RETURNING (%s, %s, %s, %s, now() + INTERVAL %s)
* RETURNING
''', [mask, target, message, delay, delay]) *
''', [mask, target, message, delay, delay])
self.con.commit() self.con.commit()
asyncio.ensure_future(self.exec_timer(self.cur.fetchone())) asyncio.ensure_future(self.exec_timer(cur.fetchone()))
self.bot.notice(mask.nick, 'Timer in {delay} set: {message}'.format(delay=delay, message=message)) self.bot.notice(mask.nick, 'Timer in {delay} set: {message}'.format(delay=delay, message=message))
except Error as ex: except Error as ex:
@ -54,18 +55,19 @@ class Timer(DatabasePlugin):
def set_timers(self): def set_timers(self):
"""Function which queries all timers in the next hour and schedules them.""" """Function which queries all timers in the next hour and schedules them."""
self.log.debug('Fetching timers') self.log.debug('Fetching timers')
self.cur.execute(''' with self.con.cursor() as cur:
SELECT cur.execute('''
* SELECT
FROM *
timers FROM
WHERE timers
ends_at >= now() WHERE
AND ends_at < now() + INTERVAL '1h' ends_at >= now()
''') AND ends_at < now() + INTERVAL '1h'
''')
for timer in self.cur.fetchall(): for timer in cur.fetchall():
asyncio.ensure_future(self.exec_timer(timer)) asyncio.ensure_future(self.exec_timer(timer))
async def exec_timer(self, timer: DictRow): async def exec_timer(self, timer: DictRow):
"""Sets the actual timer (sleeps until it fires), sends the reminder and deletes the timer from database.""" """Sets the actual timer (sleeps until it fires), sends the reminder and deletes the timer from database."""
@ -84,10 +86,11 @@ class Timer(DatabasePlugin):
)) ))
self.timers.remove(timer['id']) self.timers.remove(timer['id'])
self.cur.execute(''' with self.con.cursor() as cur:
DELETE FROM cur.execute('''
timers DELETE FROM
WHERE timers
id = %s WHERE
''', [timer['id']]) id = %s
''', [timer['id']])
self.con.commit() self.con.commit()

View File

@ -143,19 +143,20 @@ class Useless(DatabasePlugin):
%%kill [<nick>] %%kill [<nick>]
""" """
self.cur.execute(''' with self.con.cursor() as cur:
SELECT cur.execute('''
item SELECT
FROM item
kills FROM
ORDER BY kills
random() ORDER BY
LIMIT random()
1 LIMIT
''') 1
self.bot.action(target, self.cur.fetchone()['item'].format( ''')
nick=args.get('<nick>', mask.nick), self.bot.action(target, cur.fetchone()['item'].format(
)) nick=args.get('<nick>', mask.nick),
))
@command @command
def yiff(self, mask: IrcString, target: IrcString, args: Dict): def yiff(self, mask: IrcString, target: IrcString, args: Dict):
@ -163,20 +164,21 @@ class Useless(DatabasePlugin):
%%yiff [<nick>] %%yiff [<nick>]
""" """
self.cur.execute(''' with self.con.cursor() as cur:
SELECT cur.execute('''
item SELECT
FROM item
yiffs FROM
ORDER BY yiffs
random() ORDER BY
LIMIT random()
1 LIMIT
''') 1
self.bot.action(target, self.cur.fetchone()['item'].format( ''')
nick=args.get('<nick>', mask.nick), self.bot.action(target, cur.fetchone()['item'].format(
yiffer=mask.nick, nick=args.get('<nick>', mask.nick),
)) yiffer=mask.nick,
))
@command @command
def waifu(self, mask: IrcString, target: IrcString, args: Dict): def waifu(self, mask: IrcString, target: IrcString, args: Dict):
@ -477,14 +479,15 @@ class Useless(DatabasePlugin):
nick = nick[1:] nick = nick[1:]
try: try:
self.cur.execute(''' with self.con.cursor() as cur:
INSERT INTO cur.execute('''
users (nick, {0}) INSERT INTO
VALUES users (nick, {0})
(lower(%s), %s) VALUES
ON CONFLICT (nick) DO UPDATE SET (lower(%s), %s)
{0} = excluded.{0} ON CONFLICT (nick) DO UPDATE SET
'''.format(field), [mask.nick, nick]) {0} = excluded.{0}
'''.format(field), [mask.nick, nick])
self.con.commit() self.con.commit()
self.bot.notice(mask.nick, '{} set to: {}'.format(field.title(), nick)) self.bot.notice(mask.nick, '{} set to: {}'.format(field.title(), nick))
@ -492,15 +495,16 @@ class Useless(DatabasePlugin):
self.log.error(ex) self.log.error(ex)
self.con.rollback() self.con.rollback()
else: else:
self.cur.execute(''' with self.con.cursor() as cur:
SELECT cur.execute('''
{} SELECT
FROM {}
users FROM
WHERE users
lower(nick) = lower(%s) WHERE
'''.format(field), [nick]) lower(nick) = lower(%s)
result = self.cur.fetchone() '''.format(field), [nick])
result = cur.fetchone()
if result and result[field]: if result and result[field]:
return '\x02[{}]\x02 {}: {}'.format(field.title(), nick, result[field]) return '\x02[{}]\x02 {}: {}'.format(field.title(), nick, result[field])