From f7e88199b4a19d445831d7be2978a46f013f19fc Mon Sep 17 00:00:00 2001
From: percious17
Date: Tue, 2 Dec 2008 15:25:12 +0000
Subject: [PATCH] all tests pass with postgres now

---
 migrate/changeset/databases/postgres.py | 24 +++++++++----------
 test/fixture/database.py                | 31 +++++++------------------
 2 files changed, 20 insertions(+), 35 deletions(-)

diff --git a/migrate/changeset/databases/postgres.py b/migrate/changeset/databases/postgres.py
index 533a5f4..9ea564f 100644
--- a/migrate/changeset/databases/postgres.py
+++ b/migrate/changeset/databases/postgres.py
@@ -4,26 +4,24 @@
 from sqlalchemy.databases import postgres as sa_base
 
 PGSchemaGenerator = sa_base.PGSchemaGenerator
 
-class PGColumnGenerator(PGSchemaGenerator,ansisql.ANSIColumnGenerator):
-    def _do_quote_table_identifier(self, identifier):
-        return identifier
-
-class PGColumnDropper(ansisql.ANSIColumnDropper):
-    def _do_quote_table_identifier(self, identifier):
-        return identifier
-
-class PGSchemaChanger(ansisql.ANSISchemaChanger):
+class PGSchemaGeneratorMixin(object):
     def _do_quote_table_identifier(self, identifier):
         return identifier
     def _do_quote_column_identifier(self, identifier):
         return '"%s"'%identifier
 
+class PGColumnGenerator(PGSchemaGenerator,ansisql.ANSIColumnGenerator, PGSchemaGeneratorMixin):
+    pass
-class PGConstraintGenerator(ansisql.ANSIConstraintGenerator):
-    def _do_quote_table_identifier(self, identifier):
-        return identifier
+class PGColumnDropper(ansisql.ANSIColumnDropper, PGSchemaGeneratorMixin):
+    pass
 
-class PGConstraintDropper(ansisql.ANSIConstraintDropper):
+class PGSchemaChanger(ansisql.ANSISchemaChanger, PGSchemaGeneratorMixin):
+    pass
+
+class PGConstraintGenerator(ansisql.ANSIConstraintGenerator, PGSchemaGeneratorMixin):
+    pass
+class PGConstraintDropper(ansisql.ANSIConstraintDropper, PGSchemaGeneratorMixin):
     pass
 
 class PGDialect(ansisql.ANSIDialect):
diff --git a/test/fixture/database.py b/test/fixture/database.py
index a5bbac5..add3732 100644
--- a/test/fixture/database.py
+++ b/test/fixture/database.py
@@ -40,7 +40,11 @@ def is_supported(url,supported,not_supported):
         ret = True
     return ret
 
-
+#we make the engines global, which should make the tests run a bit faster
+urls = readurls()
+engines=dict([(url,create_engine(url)) for url in urls])
+
+
 def usedb(supported=None,not_supported=None):
     """Decorates tests to be run with a database connection
     These tests are run once for each available database
@@ -55,10 +59,9 @@ def usedb(supported=None,not_supported=None):
         msg = "Can't specify both supported and not_supported in fixture.db()"
         assert False, msg
 
-    urls = DB.urls
-    urls = [url for url in urls if is_supported(url,supported,not_supported)]
+    my_urls = [url for url in urls if is_supported(url,supported,not_supported)]
     def dec(func):
-        for url in urls:
+        for url in my_urls:
             def entangle(self):
                 self._setup(url)
                 yield func, self
@@ -68,6 +71,7 @@ def usedb(supported=None,not_supported=None):
         return entangle
     return dec
 
+
 class DB(Base):
     # Constants: connection level
     NONE=0 # No connection; just set self.url
@@ -75,9 +79,6 @@ class DB(Base):
     TXN=2 # Everything in a transaction
 
     level=TXN
-    urls=readurls()
-    # url: engine
-    engines=dict([(url,create_engine(url)) for url in urls])
 
     def shortDescription(self,*p,**k):
         """List database connection info with description of the test"""
@@ -100,7 +101,7 @@ class DB(Base):
 
     def _connect(self,url):
         self.url = url
-        self.engine = self.engines[url]
+        self.engine = engines[url]
         self.meta = MetaData(bind=self.engine)
         if self.level < self.CONNECT:
             return
@@ -120,11 +121,6 @@ class DB(Base):
         #if hasattr(self,'conn'):
         #    self.conn.close()
 
-    def ___run(self,*p,**k):
-        """Run one test for each connection string"""
-        for url in self.urls:
-            self._run_one(url,*p,**k)
-
     def _supported(self,url):
         db = url.split(':',1)[0]
         func = getattr(self,self._TestCase__testMethodName)
@@ -136,15 +132,6 @@ class DB(Base):
             return True
     def _not_supported(self,url):
         return not self._supported(url)
-
-    def _run_one(self,url,*p,**k):
-        if self._not_supported(url):
-            return
-        self._connect(url)
-        try:
-            super(DB,self).run(*p,**k)
-        finally:
-            self._disconnect()
 
     def refresh_table(self,name=None):
         """Reload the table from the database