Add transactions and some tests for them.
This commit is contained in:
@@ -29,3 +29,7 @@ Public API
|
||||
|
||||
A :class:`~kazoo.protocol.states.KazooState` attribute indicating
|
||||
the current higher-level connection state.
|
||||
|
||||
.. autoclass:: TransactionRequest
|
||||
:members:
|
||||
:member-order: bysource
|
||||
|
||||
148
kazoo/client.py
148
kazoo/client.py
@@ -22,6 +22,7 @@ from kazoo.protocol.paths import normpath
|
||||
from kazoo.protocol.paths import _prefix_root
|
||||
from kazoo.protocol.serialization import (
|
||||
Auth,
|
||||
CheckVersion,
|
||||
Close,
|
||||
Create,
|
||||
Delete,
|
||||
@@ -32,7 +33,8 @@ from kazoo.protocol.serialization import (
|
||||
SetACL,
|
||||
GetData,
|
||||
SetData,
|
||||
Sync
|
||||
Sync,
|
||||
Transaction
|
||||
)
|
||||
from kazoo.protocol.states import KazooState
|
||||
from kazoo.protocol.states import KeeperState
|
||||
@@ -936,6 +938,20 @@ class KazooClient(object):
|
||||
async_result)
|
||||
return async_result
|
||||
|
||||
def transaction(self):
|
||||
"""Create and return a :class:`TransactionRequest` object
|
||||
|
||||
Creates a :class:`TransactionRequest` object. A Transaction can
|
||||
consist of multiple operations which can be committed as a
|
||||
single atomic unit. Either all of the operations will succeed
|
||||
or none of them.
|
||||
|
||||
:returns: A TransactionRequest.
|
||||
:rtype: :class:`TransactionRequest`
|
||||
|
||||
"""
|
||||
return TransactionRequest(self)
|
||||
|
||||
def delete(self, path, version=-1, recursive=False):
|
||||
"""Delete a node.
|
||||
|
||||
@@ -1009,3 +1025,133 @@ class KazooClient(object):
|
||||
self.delete(path)
|
||||
except NoNodeError: # pragma: nocover
|
||||
pass
|
||||
|
||||
|
||||
class TransactionRequest(object):
|
||||
"""A Zookeeper Transaction Request
|
||||
|
||||
A Transaction provides a builder object that can be used to
|
||||
construct and commit an atomic set of operations. The transaction
|
||||
must be committed before its sent.
|
||||
|
||||
Transactions are not thread-safe and should not be accessed from
|
||||
multiple threads at once.
|
||||
|
||||
"""
|
||||
def __init__(self, client):
|
||||
self.client = client
|
||||
self.operations = []
|
||||
self.committed = False
|
||||
|
||||
def create(self, path, value="", acl=None, ephemeral=False,
|
||||
sequence=False):
|
||||
"""Add a create ZNode to the transaction. Takes the same
|
||||
arguments as :meth:`KazooClient.create`, with the exception
|
||||
of `makepath`.
|
||||
|
||||
:returns: None
|
||||
|
||||
"""
|
||||
if acl is None and self.client.default_acl:
|
||||
acl = self.client.default_acl
|
||||
|
||||
if not isinstance(path, basestring):
|
||||
raise TypeError("path must be a string")
|
||||
if acl and not isinstance(acl, (tuple, list)):
|
||||
raise TypeError("acl must be a tuple/list of ACL's")
|
||||
if not isinstance(value, str):
|
||||
raise TypeError("value must be a byte string")
|
||||
if not isinstance(ephemeral, bool):
|
||||
raise TypeError("ephemeral must be a bool")
|
||||
if not isinstance(sequence, bool):
|
||||
raise TypeError("sequence must be a bool")
|
||||
|
||||
flags = 0
|
||||
if ephemeral:
|
||||
flags |= 1
|
||||
if sequence:
|
||||
flags |= 2
|
||||
if acl is None:
|
||||
acl = OPEN_ACL_UNSAFE
|
||||
|
||||
self._add(Create(_prefix_root(self.client.chroot, path), value, acl,
|
||||
flags), None)
|
||||
|
||||
def delete(self, path, version=-1):
|
||||
"""Add a delete ZNode to the transaction. Takes the same
|
||||
arguments as :meth:`KazooClient.delete`, with the exception of
|
||||
`recursive`.
|
||||
|
||||
"""
|
||||
if not isinstance(path, basestring):
|
||||
raise TypeError("path must be a string")
|
||||
if not isinstance(version, int):
|
||||
raise TypeError("version must be an int")
|
||||
self._add(Delete(_prefix_root(self.client.chroot, path), version))
|
||||
|
||||
def set_data(self, path, data, version=-1):
|
||||
"""Add a set ZNode value to the transaction. Takes the same
|
||||
arguments as :meth:`KazooClient.set`.
|
||||
|
||||
"""
|
||||
if not isinstance(path, basestring):
|
||||
raise TypeError("path must be a string")
|
||||
if not isinstance(data, basestring):
|
||||
raise TypeError("data must be a string")
|
||||
if not isinstance(version, int):
|
||||
raise TypeError("version must be an int")
|
||||
self._add(SetData(_prefix_root(self.client.chroot, path), data,
|
||||
version))
|
||||
|
||||
def check(self, path, version):
|
||||
"""Add a Check Version to the transaction.
|
||||
|
||||
This command will fail and abort a transaction if the path
|
||||
does not match the specified version.
|
||||
|
||||
"""
|
||||
if not isinstance(path, basestring):
|
||||
raise TypeError("path must be a string")
|
||||
if not isinstance(version, int):
|
||||
raise TypeError("version must be an int")
|
||||
self._add(CheckVersion(_prefix_root(self.client.chroot, path),
|
||||
version))
|
||||
|
||||
def commit_async(self):
|
||||
"""Commit the transaction asynchronously
|
||||
|
||||
:rtype: :class:`~kazoo.interfaces.IAsyncResult`
|
||||
|
||||
"""
|
||||
self._check_tx_state()
|
||||
self.committed = True
|
||||
async_object = self.client.handler.async_result()
|
||||
self.client._call(Transaction(self.operations), async_object)
|
||||
return async_object
|
||||
|
||||
def commit(self):
|
||||
"""Commit the transaction
|
||||
|
||||
:returns: A list of the results for each operation in the
|
||||
transaction.
|
||||
|
||||
"""
|
||||
return self.commit_async().get()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_tb):
|
||||
"""Commit and cleanup accumulated transaction data"""
|
||||
if not exc_type:
|
||||
self.commit()
|
||||
|
||||
def _check_tx_state(self):
|
||||
if self.committed:
|
||||
raise ValueError('Transaction already committed')
|
||||
|
||||
def _add(self, request, post_processor=None):
|
||||
self._check_tx_state()
|
||||
if self.client.log_debug:
|
||||
log.debug('Added %r to %r', request, self)
|
||||
self.operations.append(request)
|
||||
|
||||
@@ -13,7 +13,7 @@ class ZookeeperError(KazooException):
|
||||
|
||||
|
||||
class CancelledError(KazooException):
|
||||
"""Raised when a process is cancelled by another thread"""
|
||||
"""Raised when a process is canceled by another thread"""
|
||||
|
||||
|
||||
class ConfigurationError(KazooException):
|
||||
|
||||
@@ -22,6 +22,7 @@ from kazoo.protocol.serialization import (
|
||||
GetChildren,
|
||||
Ping,
|
||||
ReplyHeader,
|
||||
Transaction,
|
||||
Watch,
|
||||
int_struct
|
||||
)
|
||||
@@ -297,6 +298,11 @@ class ConnectionHandler(object):
|
||||
async_object.set_exception(exc)
|
||||
return
|
||||
log.debug('Received response: %r', response)
|
||||
|
||||
# We special case a Transaction as we have to unchroot things
|
||||
if request.type == Transaction.type:
|
||||
response = Transaction.unchroot(client, response)
|
||||
|
||||
async_object.set(response)
|
||||
|
||||
# Determine if watchers should be registered
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from collections import namedtuple
|
||||
import struct
|
||||
|
||||
from kazoo.exceptions import EXCEPTIONS
|
||||
from kazoo.protocol.states import ZnodeStat
|
||||
from kazoo.security import ACL
|
||||
from kazoo.security import Id
|
||||
@@ -13,6 +14,7 @@ int_int_struct = struct.Struct('!ii')
|
||||
int_int_long_struct = struct.Struct('!iiq')
|
||||
|
||||
int_long_int_long_struct = struct.Struct('!iqiq')
|
||||
multiheader_struct = struct.Struct('!iBi')
|
||||
reply_header_struct = struct.Struct('!iqi')
|
||||
stat_struct = struct.Struct('!qqqqiiiqiiq')
|
||||
|
||||
@@ -287,6 +289,62 @@ class GetChildren2(namedtuple('GetChildren2', 'path watcher')):
|
||||
return children, stat
|
||||
|
||||
|
||||
class CheckVersion(namedtuple('CheckVersion', 'path version')):
|
||||
type = 13
|
||||
|
||||
def serialize(self):
|
||||
b = bytearray()
|
||||
b.extend(write_string(self.path))
|
||||
b.extend(int_struct.pack(self.version))
|
||||
return b
|
||||
|
||||
|
||||
class Transaction(namedtuple('Transaction', 'operations')):
|
||||
type = 14
|
||||
|
||||
def serialize(self):
|
||||
b = bytearray()
|
||||
for op in self.operations:
|
||||
b.extend(MultiHeader(op.type, False, -1).serialize() +
|
||||
op.serialize())
|
||||
return b + multiheader_struct.pack(-1, True, -1)
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, bytes, offset):
|
||||
header = MultiHeader(None, False, None)
|
||||
results = []
|
||||
response = None
|
||||
while not header.done:
|
||||
if header.type == Create.type:
|
||||
response, offset = read_string(bytes, offset)
|
||||
elif header.type == Delete.type:
|
||||
response = True
|
||||
elif header.type == SetData.type:
|
||||
response = ZnodeStat._make(
|
||||
stat_struct.unpack_from(bytes, offset))
|
||||
offset += stat_struct.size
|
||||
elif header.type == CheckVersion.type:
|
||||
response = True
|
||||
elif header.type == -1:
|
||||
err = int_struct.unpack_from(bytes, offset)[0]
|
||||
offset += int_struct.size
|
||||
response = EXCEPTIONS[err]()
|
||||
if response:
|
||||
results.append(response)
|
||||
header, offset = MultiHeader.deserialize(bytes, offset)
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def unchroot(client, response):
|
||||
resp = []
|
||||
for result in response:
|
||||
if isinstance(result, unicode):
|
||||
resp.append(client.unchroot(result))
|
||||
else:
|
||||
resp.append(result)
|
||||
return resp
|
||||
|
||||
|
||||
class Auth(namedtuple('Auth', 'auth_type scheme auth')):
|
||||
type = 100
|
||||
|
||||
@@ -297,20 +355,35 @@ class Auth(namedtuple('Auth', 'auth_type scheme auth')):
|
||||
|
||||
class Watch(namedtuple('Watch', 'type state path')):
|
||||
@classmethod
|
||||
def deserialize(cls, buffer, offset):
|
||||
"""Given a buffer and the current buffer offset, return the
|
||||
def deserialize(cls, bytes, offset):
|
||||
"""Given bytes and the current bytes offset, return the
|
||||
type, state, path, and new offset"""
|
||||
type, state = int_int_struct.unpack_from(buffer, offset)
|
||||
type, state = int_int_struct.unpack_from(bytes, offset)
|
||||
offset += int_int_struct.size
|
||||
path, offset = read_string(buffer, offset)
|
||||
path, offset = read_string(bytes, offset)
|
||||
return cls(type, state, path), offset
|
||||
|
||||
|
||||
class ReplyHeader(namedtuple('ReplyHeader', 'xid, zxid, err')):
|
||||
@classmethod
|
||||
def deserialize(cls, buffer, offset):
|
||||
"""Given a buffer and the current buffer offset, return a
|
||||
def deserialize(cls, bytes, offset):
|
||||
"""Given bytes and the current bytes offset, return a
|
||||
:class:`ReplyHeader` instance and the new offset"""
|
||||
new_offset = offset + reply_header_struct.size
|
||||
return cls._make(
|
||||
reply_header_struct.unpack_from(buffer, offset)), new_offset
|
||||
reply_header_struct.unpack_from(bytes, offset)), new_offset
|
||||
|
||||
|
||||
class MultiHeader(namedtuple('MultiHeader', 'type done err')):
|
||||
def serialize(self):
|
||||
b = bytearray()
|
||||
b.extend(int_struct.pack(self.type))
|
||||
b.extend([1 if self.done else 0])
|
||||
b.extend(int_struct.pack(self.err))
|
||||
return b
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, bytes, offset):
|
||||
t, done, err = multiheader_struct.unpack_from(bytes, offset)
|
||||
offset += multiheader_struct.size
|
||||
return cls(t, done is 1, err), offset
|
||||
|
||||
@@ -617,6 +617,11 @@ class TestClient(KazooTestCase):
|
||||
client._safe_close()
|
||||
testit()
|
||||
|
||||
def test_client_state(self):
|
||||
from kazoo.protocol.states import KeeperState
|
||||
eq_(self.client.client_state, KeeperState.CONNECTED)
|
||||
|
||||
|
||||
dummy_dict = {
|
||||
'aversion': 1, 'ctime': 0, 'cversion': 1,
|
||||
'czxid': 110, 'dataLength': 1, 'ephemeralOwner': 'ben',
|
||||
@@ -624,6 +629,54 @@ dummy_dict = {
|
||||
}
|
||||
|
||||
|
||||
class TestClientTransactions(KazooTestCase):
|
||||
def test_basic_create(self):
|
||||
t = self.client.transaction()
|
||||
t.create('/freddy')
|
||||
t.create('/fred', ephemeral=True)
|
||||
t.create('/smith', sequence=True)
|
||||
results = t.commit()
|
||||
eq_(results[0], '/freddy')
|
||||
eq_(len(results), 3)
|
||||
self.assertTrue(results[2].startswith('/smith0'))
|
||||
|
||||
def test_bad_creates(self):
|
||||
args_list = [(True,), ('/smith', 0), ('/smith', '', 'bleh'),
|
||||
('/smith', '', None, 'fred'),
|
||||
('/smith', '', None, True, 'fred')]
|
||||
|
||||
@raises(TypeError)
|
||||
def testit(args):
|
||||
t = self.client.transaction()
|
||||
t.create(*args)
|
||||
|
||||
for args in args_list:
|
||||
testit(args)
|
||||
|
||||
def test_default_acl(self):
|
||||
from kazoo.security import make_digest_acl
|
||||
username = uuid.uuid4().hex
|
||||
password = uuid.uuid4().hex
|
||||
|
||||
digest_auth = "%s:%s" % (username, password)
|
||||
acl = make_digest_acl(username, password, all=True)
|
||||
|
||||
self.client.add_auth("digest", digest_auth)
|
||||
self.client.default_acl = (acl,)
|
||||
|
||||
t = self.client.transaction()
|
||||
t.create('/freddy')
|
||||
results = t.commit()
|
||||
eq_(results[0], '/freddy')
|
||||
|
||||
def test_basic_delete(self):
|
||||
self.client.create('/fred')
|
||||
t = self.client.transaction()
|
||||
t.delete('/fred')
|
||||
results = t.commit()
|
||||
eq_(results[0], True)
|
||||
|
||||
|
||||
class TestCallbacks(unittest.TestCase):
|
||||
def test_session_callback_states(self):
|
||||
from kazoo.protocol.states import KazooState, KeeperState
|
||||
|
||||
Reference in New Issue
Block a user