all tests pass with postgres now
This commit is contained in:
@@ -4,26 +4,24 @@ from sqlalchemy.databases import postgres as sa_base
|
|||||||
|
|
||||||
PGSchemaGenerator = sa_base.PGSchemaGenerator
|
PGSchemaGenerator = sa_base.PGSchemaGenerator
|
||||||
|
|
||||||
class PGColumnGenerator(PGSchemaGenerator,ansisql.ANSIColumnGenerator):
|
class PGSchemaGeneratorMixin(object):
|
||||||
def _do_quote_table_identifier(self, identifier):
|
|
||||||
return identifier
|
|
||||||
|
|
||||||
class PGColumnDropper(ansisql.ANSIColumnDropper):
|
|
||||||
def _do_quote_table_identifier(self, identifier):
|
|
||||||
return identifier
|
|
||||||
|
|
||||||
class PGSchemaChanger(ansisql.ANSISchemaChanger):
|
|
||||||
def _do_quote_table_identifier(self, identifier):
|
def _do_quote_table_identifier(self, identifier):
|
||||||
return identifier
|
return identifier
|
||||||
def _do_quote_column_identifier(self, identifier):
|
def _do_quote_column_identifier(self, identifier):
|
||||||
return '"%s"'%identifier
|
return '"%s"'%identifier
|
||||||
|
|
||||||
|
class PGColumnGenerator(PGSchemaGenerator,ansisql.ANSIColumnGenerator, PGSchemaGeneratorMixin):
|
||||||
|
pass
|
||||||
|
|
||||||
class PGConstraintGenerator(ansisql.ANSIConstraintGenerator):
|
class PGColumnDropper(ansisql.ANSIColumnDropper, PGSchemaGeneratorMixin):
|
||||||
def _do_quote_table_identifier(self, identifier):
|
pass
|
||||||
return identifier
|
|
||||||
|
|
||||||
class PGConstraintDropper(ansisql.ANSIConstraintDropper):
|
class PGSchemaChanger(ansisql.ANSISchemaChanger, PGSchemaGeneratorMixin):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class PGConstraintGenerator(ansisql.ANSIConstraintGenerator, PGSchemaGeneratorMixin):
|
||||||
|
pass
|
||||||
|
class PGConstraintDropper(ansisql.ANSIConstraintDropper, PGSchemaGeneratorMixin):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class PGDialect(ansisql.ANSIDialect):
|
class PGDialect(ansisql.ANSIDialect):
|
||||||
|
@@ -40,7 +40,11 @@ def is_supported(url,supported,not_supported):
|
|||||||
ret = True
|
ret = True
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
#we make the engines global, which should make the tests run a bit faster
|
||||||
|
urls = readurls()
|
||||||
|
engines=dict([(url,create_engine(url)) for url in urls])
|
||||||
|
|
||||||
|
|
||||||
def usedb(supported=None,not_supported=None):
|
def usedb(supported=None,not_supported=None):
|
||||||
"""Decorates tests to be run with a database connection
|
"""Decorates tests to be run with a database connection
|
||||||
These tests are run once for each available database
|
These tests are run once for each available database
|
||||||
@@ -55,10 +59,9 @@ def usedb(supported=None,not_supported=None):
|
|||||||
msg = "Can't specify both supported and not_supported in fixture.db()"
|
msg = "Can't specify both supported and not_supported in fixture.db()"
|
||||||
assert False, msg
|
assert False, msg
|
||||||
|
|
||||||
urls = DB.urls
|
my_urls = [url for url in urls if is_supported(url,supported,not_supported)]
|
||||||
urls = [url for url in urls if is_supported(url,supported,not_supported)]
|
|
||||||
def dec(func):
|
def dec(func):
|
||||||
for url in urls:
|
for url in my_urls:
|
||||||
def entangle(self):
|
def entangle(self):
|
||||||
self._setup(url)
|
self._setup(url)
|
||||||
yield func, self
|
yield func, self
|
||||||
@@ -68,6 +71,7 @@ def usedb(supported=None,not_supported=None):
|
|||||||
return entangle
|
return entangle
|
||||||
return dec
|
return dec
|
||||||
|
|
||||||
|
|
||||||
class DB(Base):
|
class DB(Base):
|
||||||
# Constants: connection level
|
# Constants: connection level
|
||||||
NONE=0 # No connection; just set self.url
|
NONE=0 # No connection; just set self.url
|
||||||
@@ -75,9 +79,6 @@ class DB(Base):
|
|||||||
TXN=2 # Everything in a transaction
|
TXN=2 # Everything in a transaction
|
||||||
|
|
||||||
level=TXN
|
level=TXN
|
||||||
urls=readurls()
|
|
||||||
# url: engine
|
|
||||||
engines=dict([(url,create_engine(url)) for url in urls])
|
|
||||||
|
|
||||||
def shortDescription(self,*p,**k):
|
def shortDescription(self,*p,**k):
|
||||||
"""List database connection info with description of the test"""
|
"""List database connection info with description of the test"""
|
||||||
@@ -100,7 +101,7 @@ class DB(Base):
|
|||||||
|
|
||||||
def _connect(self,url):
|
def _connect(self,url):
|
||||||
self.url = url
|
self.url = url
|
||||||
self.engine = self.engines[url]
|
self.engine = engines[url]
|
||||||
self.meta = MetaData(bind=self.engine)
|
self.meta = MetaData(bind=self.engine)
|
||||||
if self.level < self.CONNECT:
|
if self.level < self.CONNECT:
|
||||||
return
|
return
|
||||||
@@ -120,11 +121,6 @@ class DB(Base):
|
|||||||
#if hasattr(self,'conn'):
|
#if hasattr(self,'conn'):
|
||||||
# self.conn.close()
|
# self.conn.close()
|
||||||
|
|
||||||
def ___run(self,*p,**k):
|
|
||||||
"""Run one test for each connection string"""
|
|
||||||
for url in self.urls:
|
|
||||||
self._run_one(url,*p,**k)
|
|
||||||
|
|
||||||
def _supported(self,url):
|
def _supported(self,url):
|
||||||
db = url.split(':',1)[0]
|
db = url.split(':',1)[0]
|
||||||
func = getattr(self,self._TestCase__testMethodName)
|
func = getattr(self,self._TestCase__testMethodName)
|
||||||
@@ -136,15 +132,6 @@ class DB(Base):
|
|||||||
return True
|
return True
|
||||||
def _not_supported(self,url):
|
def _not_supported(self,url):
|
||||||
return not self._supported(url)
|
return not self._supported(url)
|
||||||
|
|
||||||
def _run_one(self,url,*p,**k):
|
|
||||||
if self._not_supported(url):
|
|
||||||
return
|
|
||||||
self._connect(url)
|
|
||||||
try:
|
|
||||||
super(DB,self).run(*p,**k)
|
|
||||||
finally:
|
|
||||||
self._disconnect()
|
|
||||||
|
|
||||||
def refresh_table(self,name=None):
|
def refresh_table(self,name=None):
|
||||||
"""Reload the table from the database
|
"""Reload the table from the database
|
||||||
|
Reference in New Issue
Block a user