   1 #-*- coding: utf-8 -*-
   3 """Threadsafe database wrapper library over psycopg
   5 Very simple---- Oops, not so simple. File descriptors are copied on forks, so
   6 the thread-safe variant must take pid into account.
>>> db = pgdb.get_cursor('host=localhost dbname=eyedb '
...                      'user=eyeinvoice password=Kthulhu4presidente '
...                      'encoding=UNICODE')
  11 >>> 
  12 >>> db.execute("SELECT COUNT(*) FROM accounting.invoices")
  13 >>> db.dictfetchall()
  14 [{'count': 138145L}]
  16 @author:       Anders Eurenius <>
  17 @author:       Ulf Renman <>
  18 @organization: Favoptic Glasögondirekt AB
  19 @requires:     config : dbstring, dbparams
  20 @var con:     The current database connection
  21 @type con:    psycopg connection
  22 """
  24 import sys, os, time, thread
  26 v = None
  27 try:
  28     import psycopg2        as pg
  29     import psycopg2.extras as ex
  30     v = 2
  31 except:
  32     import psycopg         as pg
  33     v = 1
  35 try:
  36     import config                # import global config if there is one
  37     config.dsn = getattr(config, 'dsn', '')
  38     config.bug = getattr(config, 'bug', {})
  39 except:
  40     class old: pass              # make our own, reachable from outside.
  41     config = old()               # we don't insert it because that might
  42     del old                      # interfere with loading that module later.
  43     config.dsn = ''              # We delete it, because it's junk.
  44     config.bug = {}
  46 if '--dsn' in sys.argv and len(sys.argv) < sys.index('--dsn')+1:
  47     config.dsn = sys.pop(index('--dsn')+1)
  48     sys.remove('--dsn')
  50 def init():
  51     global con, thr, bug, thr_lvl, iso_lvl
  53     thr_lvl = ['Threads may not share the module.',
  54                'Threads may share the module, but not connections.',
  55                'Threads may share the module and connections.',
  56                'Threads may share the module, connections and cursors.']
  57     iso_lvl = ['Autocommit','Read committed','Read uncommitted',
  58                'Repeatable read', 'Serializable']
  59     con, thr, bug = None, {}, getattr(config, 'bug', {})
  60     if bug:
  61         print 'psycopg:      ', pg.__version__
  62         print 'API level:    ', pg.apilevel
  63         print 'Param style:  ', pg.paramstyle
  64         print 'Thread safety:', thr_lvl[pg.threadsafety]
  66 init()
  68 def dsn_clean(dsn=''):
  69     l = dsn.split('=')                              # split kvskv to k,vsk,v
  70     d, k = {}, l[0]                                 #                k
  71     for vk in l[1:-1]: d[k], k = vk.rsplit(' ', 1)  #                 =v,k
  72     d[k] = l[-1]                                    #                     =v
  73     dsn = ''
  74     for k,v in d.items():
  75         v = v.rstrip()
  76         if v:
  77             v = v[0] == v[-1] and v[0] in '\'"' and v[1:-1] or v
  78             d[k] = v
  79     for k in 'dbname','host','port','user','password','sslmode':
  80         if k in d:
  81             dsn = dsn + (k.lower() +'='+ d[k])
  82     return dsn, d
  84 def connect(dsn):
  85     dsn, d = dsn_clean(dsn)
  86     return pg.connect(dsn), d
  88 def get_cursor(dsn=''):
  89     """Creates a new connection (if necessary) and a new cursor"""
  90     global con
  91     cxn = con
  92     if       dsn: cxn, d = connect(dsn)
  93     elif not con or con.closed:
  94         con, d = connect(config.dsn)
  95         cxn = con
  96     else: d = {}
  98     if   v == 1: cur = cxn.cursor()
  99     elif v == 2: cur = cxn.cursor(cursor_factory = ex.DictCursor)
 101     if d.get('autocommit'):
 102         if   v == 1: cxn.autocommit(1)
 103         elif v == 2: cxn.set_isolation_level(0)
 104     if d.get('encoding'):
 105         if   v == 1: cur.execute("SET CLIENT_ENCODING=%s", (d['encoding'],))
 106         elif v == 2: cxn.set_client_encoding(d['encoding'])
 107     return cur
 109 def get_thread_cursor(dsn='', pid=[-1]):
 110     """Utility to create separate cursors for different (pid x thread)s."""
 111     global thr, con
 113     if pid[0] != os.getpid(): con, pid[0], thr = None, os.getpid(), {}
 114     me = (dsn, thread.get_ident())
 115     return thr.setdefault(me, get_cursor(dsn))
 117 class Query( object ):
 118     """Utility for reusing a query in a safe and convenient way.
 120     The instance is created with
 121       1. the query string,
 122       2. an ordered list of pairs of parameters and their casting (or
 123          conversion, or..) functions. (optional)
 124       3. a dictionary containing default values.
 126     When called the argument defaults to an empty dictionary. If a dict is
 127     given on the other hand, the defaults are copied, the copy is then updated
 128     with the argument dict. A list of query paramenters is then constructed by
 129     picking the dict items according to the keys list and mapping them with
 130     their corresponding functions.
 132     @note: The same dict key can be used more than once in the query.
 134     @note: If the query fails because of a db restart, it tries to reconnect.
 136     Although the class is crafted for that use,
 137       1. The query need not be a select,
 138       2. The keys need not be strings and
 139       3. The functions need not be constructors; Notably, they can return None
 140     """
 141     def __init__(self, sql, keys=(), defaults={}):
 142         self.sql, self.keys, self.defaults = sql, keys, defaults
 143     def __repr__(self):
 144         return ('Query("""%s""" x (%s) x {%s})' %
 145                 (self.sql,
 146                  ', '.join([ k for k, f in self.keys]),
 147                  ', '.join(  self.defaults.keys()    ) ))
 148     def __call__(self, d={}, retry=1):
 149         e = dict(self.defaults)                   # (shallow) copy the defaults
 150         if isinstance( d, dict ):
 151             e.update(d)                           # args override defaults
 152             l = [ f(e[k]) for k, f in self.keys ] # map out the args
 153         else:                                     # same, but assume object
 154             l = [ f( getattr(d, k, e.get(k)) ) for k, f in self.keys ]
 155         try:
 156             tc = get_thread_cursor()
 157             if bug.get('pgdb') > 2: print self.sql % l
 158             tc.execute(self.sql, l)
 159         except pg.OperationalError, x:
 160             if bug.get('pgdb') > 1: print x
 161             if retry:
 162                 global thr
 163                 if bug.get('pgdb') > 0: print '--- RECONNECTING ---'
 164                 make_connection()
 165                 thr = {}
 166                 return self(d, retry-1)
 167         except Exception, x:
 168             if bug.get('pgdb') > 0: print x
 169         if tc.statusmessage == 'SELECT':
 170             return [ dict(x) for x in tc.fetchall() ] or []
 171         else:
 172             return tc.rowcount
 175 class CachedQuery( list ):
 176     """Result caching list class that refreshes itself.
 178     The instance is given
 179       1. a function or functor, (L{Query}, hint, hint.)
 180       2. an optional cache time in seconds
 181       3. a function the query results are mapped through
 183     The point of the exercise is to get stuff from the db while balancing
 184       1. not doing a query every time
 185       2. allowing change without restarting the application
 187     It can be refreshed manually with C{refresh}, so if you like, you can set
 188     the cache time to 2**64 and refresh it explicitly.
 190     >>> cq = CachedQuery(Query("SELECT * FROM pg_catalog.pg_class"),
 191     ...                  f=lambda x: (x['start'].strftime('%F'), x['uid']))
 193     @warning: Failure semantics are not so well thought out. (If you have a
 194     better idea, tell me.)
 195     """
 196     def __init__(self, q, to=300, f=lambda x: x):
 197         self.q,, self.f, self.t = q, to, f, 0
 198         self.refresh()
 199     def refresh(self):
 200         t = time.time()
 201         if t - self.t >
 202             if bug.get('pgdb') > 1: print 'REFRESH: ', self.q
 203             try:    self.t, self[:] = t, map(self.f, self.q())
 204             except: self[:] = []
 205     def __getitem__(self, *x): self.refresh();return list.__getitem__(self,*x)
 206     def __getslice__(self,*x): self.refresh();return list.__getslice__(self,*x)
 207     def __repr__(self,*x):     self.refresh();return list.__repr__(self, *x)
 208     def __str__(self,*x):      self.refresh();return list.__str__(self, *x)
 209     def __len__(self,*x):      self.refresh();return list.__len__(self, *x)

