From c5916450b6767d99ad0db508169620fc152ed871 Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 1 Oct 2013 01:55:52 +0900 Subject: [PATCH] Add con.select_db() method. (fixes #80) --- pymysql/connections.py | 5 +++++ pymysql/tests/test_connection.py | 13 +++++++++++++ 2 files changed, 18 insertions(+) diff --git a/pymysql/connections.py b/pymysql/connections.py index 9dc38d7..fec4128 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -706,6 +706,11 @@ class Connection(object): self._execute_command(COM_QUERY, "ROLLBACK") self._read_ok_packet() + def select_db(self, db): + '''Set current db''' + self._execute_command(COM_INIT_DB, db) + self._read_ok_packet() + def escape(self, obj): ''' Escape whatever value you pass to it ''' if isinstance(obj, str_type): diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index 67b685f..2e78aa8 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -41,6 +41,19 @@ class TestConnection(base.PyMySQLTestCase): cur.execute("SELECT @@AUTOCOMMIT") self.assertEqual(cur.fetchone()[0], 0) + def test_select_db(self): + con = self.connections[0] + current_db = self.databases[0]['db'] + other_db = self.databases[1]['db'] + + cur = con.cursor() + cur.execute('SELECT database()') + self.assertEqual(cur.fetchone()[0], current_db) + + con.select_db(other_db) + cur.execute('SELECT database()') + self.assertEqual(cur.fetchone()[0], other_db) + if __name__ == "__main__": try: