firebot/contrib/seedyb.py

147 lines
3.5 KiB
Python

#! /usr/bin/env python
import cdb
import random
import codecs
_encode = codecs.getencoder('utf-8')
(_encode, _decode, _, _) = codecs.lookup('utf-8')
def encode(str):
return _encode(str)[0]
def decode(str):
return _decode(str)[0]
class Locked(Exception):
pass
class SeedyB:
"""firebot-specific database using cdb.
Why CDB? Because everything else keeps going corrupt.
Notes:
* This doesn't preserve unsynced additions. If you crash before
running sync(), you lose.
* If you set a value to [], is is effectively deleted.
* You can lock something that's not in the database. You might
want to do this for words like 'that', so the bot doesn't pick
up on them.
"""
def __init__(self, filename):
self.filename = filename
self.tempfile = "%s.tmp" % filename
self.db = {}
try:
self.cdb = cdb.init(self.filename)
except cdb.error:
d = cdb.cdbmake(self.filename, self.tempfile)
d.finish()
del d
self.cdb = cdb.init(self.filename)
def __del__(self):
self.sync(force=True)
def __len__(self):
return len(self.cdb) + len(self.db)
def __delitem__(self, key):
self.delete(key)
def __getitem__(self, key):
val = self.get(key)
if val is None:
raise KeyError(key)
def __contains__(self, key):
return (key in self.db) or (key in self.cdb)
def length(self):
return (len(self.cdb), len(self.db))
def sync(self, force=False):
if not self.db:
return
tmp = cdb.cdbmake(self.filename, self.tempfile)
# Copy original
r = self.cdb.each()
while r:
k,v = r
dk = decode(k)
if k not in self.db:
tmp.add(*r)
r = self.cdb.each()
# Add new stuff
for k,l in self.db.iteritems():
for v in l:
try:
tmp.add(k,v)
except:
print (k,v)
raise
tmp.finish()
self.cdb = cdb.init(self.filename)
self.db = {}
def getall(self, key, special=None):
"""Return all values for a key"""
if special:
key = '\016%s:%s\017' % (special, key)
ekey = encode(key)
vals = self.db.get(ekey, None)
if vals is None:
vals = self.cdb.getall(ekey)
return [decode(v) for v in vals]
def get(self, key, default=None, special=None):
"""Get a value at random"""
vals = self.getall(key, special)
if vals:
return random.choice(vals)
else:
return default
def set(self, key, val, special=None):
if special:
key = '\016%s:%s\017' % (special, key)
ekey = encode(key)
if type(val) not in (type([]), type(())):
val = [val]
if not special and self.is_locked(key):
raise Locked()
self.db[ekey] = [encode(v) for v in val]
def delete(self, key, special=None):
val = self.get(key, special=special)
if not val:
raise KeyError(key)
self.set(key, [], special)
##
## Locking
##
def lock(self, key):
self.set(key, ['locked'], special='lock')
def unlock(self, key):
self.set(key, [], special='lock')
def is_locked(self, key):
l = self.get(key, special='lock')
if l is None:
return False
return True
open = SeedyB