add auto reconnect for postgres

This only works on the second database interaction, since psycopg2 only notices that
the connection is gone, when a query is executed.

So in the common case reconnect works as follows:
- some bot method calls a cursor function like .execute(), .fetchone(), etc.
  - this raises an error if the connection is broken
  - if following code then requests a new cursor, this will also fail since psycopg2
    now knows that the connection is gone
  - the error is caught in storage.DBConn.cursor(), a new connection will be set up
    of which a new cursor is yielded
If the error happens in connection.commit() or .rollback() instead we can instantly
reconnect since these methods are wrapped.

So why not wrap the cursor methods as well?
Consider the following example:
A query is the last thing that was executed on a cursor.
The database connection is lost.
Now .fetchone() is called on the cursor.
We could wrap .fetchone() and reconnect, but we'd have to use a new cursor since
cursors are linked to connections. And on this new cursor .fetchone() wouldn't
make any sense, since we haven't executed a query on this cursor.
This commit is contained in:
jkhsjdhjs 2020-03-16 21:51:32 +00:00
parent 84396cad99
commit e03f5d0a43
10 changed files with 289 additions and 238 deletions

View File

@ -73,4 +73,3 @@ class DatabasePlugin(Plugin):
super().__init__(bot)
self.db = bot.get_plugin(Storage)
self.con = self.db.con
self.cur = self.db.cur

View File

@ -29,34 +29,36 @@ class McManiac(DatabasePlugin):
offset = ''
# Fetch result from database
self.cur.execute('''
SELECT
item,
rank() OVER (ORDER BY id),
count(*) OVER (ROWS BETWEEN UNBOUNDED PRECEDING
AND UNBOUNDED FOLLOWING) AS total
FROM
mcmaniacs
ORDER BY
{order}
LIMIT
1
{offset}
'''.format(order=order, offset=offset), [index])
result = self.cur.fetchone()
with self.con.cursor() as cur:
cur.execute('''
SELECT
item,
rank() OVER (ORDER BY id),
count(*) OVER (ROWS BETWEEN UNBOUNDED PRECEDING
AND UNBOUNDED FOLLOWING) AS total
FROM
mcmaniacs
ORDER BY
{order}
LIMIT
1
{offset}
'''.format(order=order, offset=offset), [index])
result = cur.fetchone()
if result:
return '[{rank}/{total}] {item}'.format(**result)
if result:
return '[{rank}/{total}] {item}'.format(**result)
# TODO: fix regex ("McFooiaC McBariaC" adds "Mc\S+iaC")
@irc3.event(r'^:(?P<mask>\S+) PRIVMSG \S+ :.*(?P<item>Mc\S+iaC).*')
def save(self, mask: str, item: str):
if IrcString(mask).nick != self.bot.nick:
self.cur.execute('''
INSERT INTO
mcmaniacs (item)
VALUES
(%s)
ON CONFLICT DO NOTHING
''', [item])
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
mcmaniacs (item)
VALUES
(%s)
ON CONFLICT DO NOTHING
''', [item])
self.con.commit()

View File

@ -23,43 +23,45 @@ class Quotes(DatabasePlugin):
self.bot.notice(mask.nick, '[Quotes] Error parsing nick')
else:
# Insert quote into database
self.cur.execute('''
INSERT INTO
quotes (nick, item, channel, created_by)
VALUES
(%s, %s, %s, %s)
''', [nick, quote, channel, mask.nick])
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
quotes (nick, item, channel, created_by)
VALUES
(%s, %s, %s, %s)
''', [nick, quote, channel, mask.nick])
def delete_quote(self, nick: str, quote: str):
index, order = parse_int(quote, select=False)
if index:
# Delete from database
self.cur.execute('''
-- noinspection SqlResolve
WITH ranked_quotes AS (
SELECT
id,
rank() OVER (PARTITION BY nick ORDER BY id {order})
FROM
quotes
WHERE
lower(nick) = lower(%s)
)
with self.con.cursor() as cur:
cur.execute('''
-- noinspection SqlResolve
WITH ranked_quotes AS (
SELECT
id,
rank() OVER (PARTITION BY nick ORDER BY id {order})
FROM
quotes
WHERE
lower(nick) = lower(%s)
)
-- noinspection SqlResolve
DELETE FROM
quotes
WHERE
id = (
SELECT
id
FROM
ranked_quotes
-- noinspection SqlResolve
DELETE FROM
quotes
WHERE
rank = %s
)
'''.format(order=order), [nick, index])
id = (
SELECT
id
FROM
ranked_quotes
WHERE
rank = %s
)
'''.format(order=order), [nick, index])
@command(options_first=True, quiet=True)
def q(self, mask: IrcString, target: IrcString, args: Dict):
@ -135,30 +137,31 @@ class Quotes(DatabasePlugin):
offset = ''
# Fetch quote from database
self.cur.execute('''
WITH ranked_quotes AS (
SELECT
nick,
item,
rank() OVER (PARTITION BY nick ORDER BY id),
count(*) OVER (PARTITION BY nick) AS total
FROM
quotes
)
with self.con.cursor() as cur:
cur.execute('''
WITH ranked_quotes AS (
SELECT
nick,
item,
rank() OVER (PARTITION BY nick ORDER BY id),
count(*) OVER (PARTITION BY nick) AS total
FROM
quotes
)
SELECT
*
FROM
ranked_quotes
WHERE
{where}
ORDER BY
{order}
LIMIT
1
{offset}
'''.format(where=' AND '.join(where), order=order, offset=offset), values)
result = self.cur.fetchone()
SELECT
*
FROM
ranked_quotes
WHERE
{where}
ORDER BY
{order}
LIMIT
1
{offset}
'''.format(where=' AND '.join(where), order=order, offset=offset), values)
result = cur.fetchone()
if result:
return '[{rank}/{total}] <{nick}> {item}'.format(**result)

View File

@ -18,15 +18,16 @@ class Rape(DatabasePlugin):
nick = args.get('<nick>', mask.nick)
# Fetch result from database
self.cur.execute('''
SELECT
fines
FROM
users
WHERE
lower(nick) = lower(%s)
''', [nick])
owes = self.cur.fetchone()
with self.con.cursor() as cur:
cur.execute('''
SELECT
fines
FROM
users
WHERE
lower(nick) = lower(%s)
''', [nick])
owes = cur.fetchone()
# Colorize owe amount and return string
if owes:
@ -56,22 +57,23 @@ class Rape(DatabasePlugin):
reason = ('raping', 'being too lewd and getting raped')[rand]
# Insert or add fine to database and return total owe
self.cur.execute('''
INSERT INTO
users (nick, fines)
VALUES
(lower(%s), %s)
ON CONFLICT (nick) DO UPDATE SET
fines = users.fines + excluded.fines
RETURNING
fines
''', [fined, fine])
self.con.commit()
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
users (nick, fines)
VALUES
(lower(%s), %s)
ON CONFLICT (nick) DO UPDATE SET
fines = users.fines + excluded.fines
RETURNING
fines
''', [fined, fine])
self.con.commit()
# Print fine and total owe
self.bot.action(target, 'fines {nick} \x02${fine}\x02 for {reason}. You owe: \x0304${total}\x03'.format(
nick=fined,
fine=fine,
reason=reason,
total=self.cur.fetchone()['fines'],
))
# Print fine and total owe
self.bot.action(target, 'fines {nick} \x02${fine}\x02 for {reason}. You owe: \x0304${total}\x03'.format(
nick=fined,
fine=fine,
reason=reason,
total=cur.fetchone()['fines'],
))

View File

@ -20,16 +20,17 @@ class Useless(DatabasePlugin):
if nick == self.bot.nick:
return
self.cur.execute('''
SELECT
item
FROM
last_messages
WHERE
nick = lower(%s)
AND channel = lower(%s)
''', [nick, target])
result = self.cur.fetchone()
with self.con.cursor() as cur:
cur.execute('''
SELECT
item
FROM
last_messages
WHERE
nick = lower(%s)
AND channel = lower(%s)
''', [nick, target])
result = cur.fetchone()
if result:
old = result['item']
@ -44,13 +45,14 @@ class Useless(DatabasePlugin):
"""Saves the last message of a user for each channel (for regex)."""
mask = IrcString(mask)
self.cur.execute('''
INSERT INTO
last_messages (nick, host, channel, item)
VALUES
(lower(%s), %s, lower(%s), %s)
ON CONFLICT (nick, channel) DO UPDATE SET
host = excluded.host,
item = excluded.item
''', [mask.nick, mask.host, target, msg])
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
last_messages (nick, host, channel, item)
VALUES
(lower(%s), %s, lower(%s), %s)
ON CONFLICT (nick, channel) DO UPDATE SET
host = excluded.host,
item = excluded.item
''', [mask.nick, mask.host, target, msg])
self.con.commit()

View File

@ -23,15 +23,16 @@ class Seen(DatabasePlugin):
return '{}, look in the mirror faggot!'.format(nick)
# Fetch seen from database
self.cur.execute('''
SELECT
seen_at, message, channel
FROM
seens
WHERE
nick = lower(%s)
''', [nick])
seen = self.cur.fetchone()
with self.con.cursor() as cur:
cur.execute('''
SELECT
seen_at, message, channel
FROM
seens
WHERE
nick = lower(%s)
''', [nick])
seen = cur.fetchone()
# No result
if not seen:
@ -50,15 +51,16 @@ class Seen(DatabasePlugin):
def save(self, mask: str, target: str, msg: str):
mask = IrcString(mask)
self.cur.execute('''
INSERT INTO
seens (nick, host, channel, message)
VALUES
(lower(%s), %s, %s, %s)
ON CONFLICT (nick) DO UPDATE SET
host = excluded.host,
channel = excluded.channel,
seen_at = now(),
message = excluded.message
''', [mask.nick, mask.host, target, msg])
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
seens (nick, host, channel, message)
VALUES
(lower(%s), %s, %s, %s)
ON CONFLICT (nick) DO UPDATE SET
host = excluded.host,
channel = excluded.channel,
seen_at = now(),
message = excluded.message
''', [mask.nick, mask.host, target, msg])
self.con.commit()

View File

@ -1,15 +1,46 @@
# -*- coding: utf-8 -*-
import os
import contextlib
import psycopg2
from psycopg2.extras import DictCursor
from . import Plugin, Bot
class DBConn:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self._reconnect()
def _reconnect(self):
self._con = psycopg2.connect(*self.args, **self.kwargs)
def commit(self, *args, **kwargs):
try:
return self._con.commit(*args, **kwargs)
except psycopg2.InterfaceError:
self._reconnect()
return self._con.commit(*args, **kwargs)
def rollback(self, *args, **kwargs):
try:
return self._con.rollback(*args, **kwargs)
except psycopg2.InterfaceError:
self._reconnect()
return self._con.rollback(*args, **kwargs)
@contextlib.contextmanager
def cursor(self, *args, **kwargs):
try:
yield self._con.cursor(cursor_factory=DictCursor, *args, **kwargs)
except psycopg2.InterfaceError:
self._reconnect()
yield self._con.cursor(cursor_factory=DictCursor, *args, **kwargs)
class Storage(Plugin):
def __init__(self, bot: Bot):
super().__init__(bot)
self.bot.sql = self
self.con = psycopg2.connect(os.environ['DATABASE_URI'])
self.cur = self.con.cursor(cursor_factory=DictCursor)
self.con = DBConn(os.environ['DATABASE_URI'])

View File

@ -15,19 +15,20 @@ class Tell(DatabasePlugin):
super().__init__(bot)
self.tell_queue = {}
self.cur.execute('''
SELECT
to_nick, from_nick, message, created_at
FROM
tells
''')
with self.con.cursor() as cur:
cur.execute('''
SELECT
to_nick, from_nick, message, created_at
FROM
tells
''')
for res in self.cur.fetchall():
nick = res['to_nick'].lower()
for res in cur.fetchall():
nick = res['to_nick'].lower()
if nick not in self.tell_queue:
self.tell_queue[nick] = []
self.tell_queue[nick].append(res[1:])
if nick not in self.tell_queue:
self.tell_queue[nick] = []
self.tell_queue[nick].append(res[1:])
@command
def tell(self, mask: IrcString, target: IrcString, args: Dict):
@ -43,12 +44,13 @@ class Tell(DatabasePlugin):
self.tell_queue[nick].append(tell[1:])
try:
self.cur.execute('''
INSERT INTO
tells (to_nick, from_nick, message, created_at)
VALUES
(%s, %s, %s, %s)
''', tell)
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
tells (to_nick, from_nick, message, created_at)
VALUES
(%s, %s, %s, %s)
''', tell)
self.con.commit()
self.bot.notice(mask.nick, "I will tell that to {} when I see them.".format(nick))
@ -73,10 +75,11 @@ class Tell(DatabasePlugin):
))
del self.tell_queue[nick]
self.cur.execute('''
DELETE FROM
tells
WHERE
to_nick = %s
''', [nick])
with self.con.cursor() as cur:
cur.execute('''
DELETE FROM
tells
WHERE
to_nick = %s
''', [nick])
self.con.commit()

View File

@ -34,17 +34,18 @@ class Timer(DatabasePlugin):
return 'Invalid timer delay: {}'.format(delay)
try:
self.cur.execute('''
INSERT INTO
timers (mask, target, message, delay, ends_at)
VALUES
(%s, %s, %s, %s, now() + INTERVAL %s)
RETURNING
*
''', [mask, target, message, delay, delay])
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
timers (mask, target, message, delay, ends_at)
VALUES
(%s, %s, %s, %s, now() + INTERVAL %s)
RETURNING
*
''', [mask, target, message, delay, delay])
self.con.commit()
asyncio.ensure_future(self.exec_timer(self.cur.fetchone()))
asyncio.ensure_future(self.exec_timer(cur.fetchone()))
self.bot.notice(mask.nick, 'Timer in {delay} set: {message}'.format(delay=delay, message=message))
except Error as ex:
@ -54,18 +55,19 @@ class Timer(DatabasePlugin):
def set_timers(self):
"""Function which queries all timers in the next hour and schedules them."""
self.log.debug('Fetching timers')
self.cur.execute('''
SELECT
*
FROM
timers
WHERE
ends_at >= now()
AND ends_at < now() + INTERVAL '1h'
''')
with self.con.cursor() as cur:
cur.execute('''
SELECT
*
FROM
timers
WHERE
ends_at >= now()
AND ends_at < now() + INTERVAL '1h'
''')
for timer in self.cur.fetchall():
asyncio.ensure_future(self.exec_timer(timer))
for timer in cur.fetchall():
asyncio.ensure_future(self.exec_timer(timer))
async def exec_timer(self, timer: DictRow):
"""Sets the actual timer (sleeps until it fires), sends the reminder and deletes the timer from database."""
@ -84,10 +86,11 @@ class Timer(DatabasePlugin):
))
self.timers.remove(timer['id'])
self.cur.execute('''
DELETE FROM
timers
WHERE
id = %s
''', [timer['id']])
with self.con.cursor() as cur:
cur.execute('''
DELETE FROM
timers
WHERE
id = %s
''', [timer['id']])
self.con.commit()

View File

@ -143,19 +143,20 @@ class Useless(DatabasePlugin):
%%kill [<nick>]
"""
self.cur.execute('''
SELECT
item
FROM
kills
ORDER BY
random()
LIMIT
1
''')
self.bot.action(target, self.cur.fetchone()['item'].format(
nick=args.get('<nick>', mask.nick),
))
with self.con.cursor() as cur:
cur.execute('''
SELECT
item
FROM
kills
ORDER BY
random()
LIMIT
1
''')
self.bot.action(target, cur.fetchone()['item'].format(
nick=args.get('<nick>', mask.nick),
))
@command
def yiff(self, mask: IrcString, target: IrcString, args: Dict):
@ -163,20 +164,21 @@ class Useless(DatabasePlugin):
%%yiff [<nick>]
"""
self.cur.execute('''
SELECT
item
FROM
yiffs
ORDER BY
random()
LIMIT
1
''')
self.bot.action(target, self.cur.fetchone()['item'].format(
nick=args.get('<nick>', mask.nick),
yiffer=mask.nick,
))
with self.con.cursor() as cur:
cur.execute('''
SELECT
item
FROM
yiffs
ORDER BY
random()
LIMIT
1
''')
self.bot.action(target, cur.fetchone()['item'].format(
nick=args.get('<nick>', mask.nick),
yiffer=mask.nick,
))
@command
def waifu(self, mask: IrcString, target: IrcString, args: Dict):
@ -477,14 +479,15 @@ class Useless(DatabasePlugin):
nick = nick[1:]
try:
self.cur.execute('''
INSERT INTO
users (nick, {0})
VALUES
(lower(%s), %s)
ON CONFLICT (nick) DO UPDATE SET
{0} = excluded.{0}
'''.format(field), [mask.nick, nick])
with self.con.cursor() as cur:
cur.execute('''
INSERT INTO
users (nick, {0})
VALUES
(lower(%s), %s)
ON CONFLICT (nick) DO UPDATE SET
{0} = excluded.{0}
'''.format(field), [mask.nick, nick])
self.con.commit()
self.bot.notice(mask.nick, '{} set to: {}'.format(field.title(), nick))
@ -492,15 +495,16 @@ class Useless(DatabasePlugin):
self.log.error(ex)
self.con.rollback()
else:
self.cur.execute('''
SELECT
{}
FROM
users
WHERE
lower(nick) = lower(%s)
'''.format(field), [nick])
result = self.cur.fetchone()
with self.con.cursor() as cur:
cur.execute('''
SELECT
{}
FROM
users
WHERE
lower(nick) = lower(%s)
'''.format(field), [nick])
result = cur.fetchone()
if result and result[field]:
return '\x02[{}]\x02 {}: {}'.format(field.title(), nick, result[field])