import threading
import time as mod_time
import uuid
from types import SimpleNamespace, TracebackType
from typing import Optional, Type

from redis.exceptions import LockError, LockNotOwnedError
from redis.typing import Number


class Lock:
    """
    A shared, distributed Lock. Using Redis for locking allows the Lock
    to be shared across processes and/or machines.

    It's left to the user to resolve deadlock issues and make sure
    multiple clients play nicely together.
    """

    # Script objects produced by ``register_script``. They are stored on the
    # class the first time an instance is created (see ``register_scripts``)
    # and every call passes ``client=self.redis`` explicitly.
    lua_release = None
    lua_extend = None
    lua_reacquire = None

    # KEYS[1] - lock name
    # ARGV[1] - token
    # return 1 if the lock was released, otherwise 0
    LUA_RELEASE_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        redis.call('del', KEYS[1])
        return 1
    """

    # KEYS[1] - lock name
    # ARGV[1] - token
    # ARGV[2] - additional milliseconds
    # ARGV[3] - "0" if the additional time should be added to the lock's
    #           existing ttl or "1" if the existing ttl should be replaced
    # return 1 if the lock's time was extended, otherwise 0
    LUA_EXTEND_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        local expiration = redis.call('pttl', KEYS[1])
        if not expiration then
            expiration = 0
        end
        if expiration < 0 then
            return 0
        end

        local newttl = ARGV[2]
        if ARGV[3] == "0" then
            newttl = ARGV[2] + expiration
        end
        redis.call('pexpire', KEYS[1], newttl)
        return 1
    """

    # KEYS[1] - lock name
    # ARGV[1] - token
    # ARGV[2] - milliseconds
    # return 1 if the lock's time was reacquired, otherwise 0
    LUA_REACQUIRE_SCRIPT = """
        local token = redis.call('get', KEYS[1])
        if not token or token ~= ARGV[1] then
            return 0
        end
        redis.call('pexpire', KEYS[1], ARGV[2])
        return 1
    """

    def __init__(
        self,
        redis,
        name: str,
        timeout: Optional[Number] = None,
        sleep: Number = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[Number] = None,
        thread_local: bool = True,
    ):
        """
        Create a new Lock instance named ``name`` using the Redis client
        supplied by ``redis``.

        ``timeout`` indicates a maximum life for the lock in seconds.
        By default, it will remain locked until release() is called.
        ``timeout`` can be specified as a float or integer, both representing
        a number of seconds. A falsy value (``None`` or ``0``) means the
        lock key is stored without an expiry.

        ``sleep`` indicates the amount of time to sleep in seconds per loop
        iteration when the lock is in blocking mode and another client is
        currently holding the lock.

        ``blocking`` indicates whether calling ``acquire`` should block until
        the lock has been acquired or to fail immediately, causing ``acquire``
        to return False and the lock not being acquired. Defaults to True.
        Note this value can be overridden by passing a ``blocking``
        argument to ``acquire``.

        ``blocking_timeout`` indicates the maximum amount of time in seconds to
        spend trying to acquire the lock. A value of ``None`` indicates
        continue trying forever. ``blocking_timeout`` can be specified as a
        float or integer, both representing the number of seconds to wait.

        ``thread_local`` indicates whether the lock token is placed in
        thread-local storage. By default, the token is placed in thread local
        storage so that a thread only sees its token, not a token set by
        another thread. Consider the following timeline:

            time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
                     thread-1 sets the token to "abc"
            time: 1, thread-2 blocks trying to acquire `my-lock` using the
                     Lock instance.
            time: 5, thread-1 has not yet completed. redis expires the lock
                     key.
            time: 5, thread-2 acquired `my-lock` now that it's available.
                     thread-2 sets the token to "xyz"
            time: 6, thread-1 finishes its work and calls release(). if the
                     token is *not* stored in thread local storage, then
                     thread-1 would see the token value as "xyz" and would be
                     able to successfully release thread-2's lock.

        In some use cases it's necessary to disable thread local storage. For
        example, if you have code where one thread acquires a lock and passes
        that lock instance to a worker thread to release later. If thread
        local storage isn't disabled in this case, the worker thread won't see
        the token set by the thread that acquired the lock. Our assumption
        is that these cases aren't common and as such default to using
        thread local storage.
        """
        self.redis = redis
        self.name = name
        self.timeout = timeout
        self.sleep = sleep
        self.blocking = blocking
        self.blocking_timeout = blocking_timeout
        self.thread_local = bool(thread_local)
        # Per-thread token storage, or a plain shared namespace when
        # thread-local storage is disabled.
        self.local = threading.local() if self.thread_local else SimpleNamespace()
        self.local.token = None
        self.register_scripts()

    def register_scripts(self) -> None:
        """Register the Lua scripts once, caching them on the class."""
        cls = self.__class__
        client = self.redis
        if cls.lua_release is None:
            cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
        if cls.lua_extend is None:
            cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
        if cls.lua_reacquire is None:
            cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)

    def __enter__(self) -> "Lock":
        if self.acquire():
            return self
        raise LockError(
            "Unable to acquire lock within the time specified",
            lock_name=self.name,
        )

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        self.release()

    def acquire(
        self,
        sleep: Optional[Number] = None,
        blocking: Optional[bool] = None,
        blocking_timeout: Optional[Number] = None,
        token: Optional[str] = None,
    ) -> bool:
        """
        Use Redis to hold a shared, distributed lock named ``name``.
        Returns True once the lock is acquired.

        If ``blocking`` is False, always return immediately. If the lock
        was acquired, return True, otherwise return False.

        ``blocking_timeout`` specifies the maximum number of seconds to
        wait trying to acquire the lock.

        ``token`` specifies the token value to be used. If provided, token
        must be a bytes object or a string that can be encoded to a bytes
        object with the default encoding. If a token isn't specified, a UUID
        will be generated.
        """
        if sleep is None:
            sleep = self.sleep
        if token is None:
            token = uuid.uuid1().hex.encode()
        else:
            # Normalise to bytes so owned() can compare bytes to bytes.
            encoder = self.redis.get_encoder()
            token = encoder.encode(token)
        if blocking is None:
            blocking = self.blocking
        if blocking_timeout is None:
            blocking_timeout = self.blocking_timeout
        stop_trying_at = None
        if blocking_timeout is not None:
            stop_trying_at = mod_time.monotonic() + blocking_timeout
        while True:
            if self.do_acquire(token):
                self.local.token = token
                return True
            if not blocking:
                return False
            next_try_at = mod_time.monotonic() + sleep
            # Give up rather than sleep past the blocking deadline.
            if stop_trying_at is not None and next_try_at > stop_trying_at:
                return False
            mod_time.sleep(sleep)

    def do_acquire(self, token: str) -> bool:
        """Attempt a single ``SET NX`` of the lock key; True on success."""
        # A falsy timeout (None or 0) stores the key without an expiry.
        timeout = int(self.timeout * 1000) if self.timeout else None
        return bool(self.redis.set(self.name, token, nx=True, px=timeout))

    def locked(self) -> bool:
        """
        Returns True if this key is locked by any process, otherwise False.
        """
        return self.redis.get(self.name) is not None

    def owned(self) -> bool:
        """
        Returns True if this key is locked by this lock, otherwise False.
        """
        stored_token = self.redis.get(self.name)
        # need to always compare bytes to bytes
        # TODO: this can be simplified when the context manager is finished
        if stored_token and not isinstance(stored_token, bytes):
            encoder = self.redis.get_encoder()
            stored_token = encoder.encode(stored_token)
        return self.local.token is not None and stored_token == self.local.token

    def release(self) -> None:
        """
        Releases the already acquired lock
        """
        expected_token = self.local.token
        if expected_token is None:
            raise LockError("Cannot release an unlocked lock", lock_name=self.name)
        # The local token is cleared even if the server-side release below
        # reports that the lock is no longer owned.
        self.local.token = None
        self.do_release(expected_token)

    def do_release(self, expected_token: str) -> None:
        """Delete the lock key only if it still holds ``expected_token``."""
        if not bool(
            self.lua_release(keys=[self.name], args=[expected_token], client=self.redis)
        ):
            raise LockNotOwnedError(
                "Cannot release a lock that's no longer owned",
                lock_name=self.name,
            )

    def extend(self, additional_time: Number, replace_ttl: bool = False) -> bool:
        """
        Adds more time to an already acquired lock.

        ``additional_time`` can be specified as an integer or a float, both
        representing the number of seconds to add.

        ``replace_ttl`` if False (the default), add `additional_time` to
        the lock's existing ttl. If True, replace the lock's ttl with
        `additional_time`.

        Raises ``LockError`` if the lock is not held or was created without
        a timeout (``None`` or ``0``), and ``LockNotOwnedError`` if the lock
        is no longer owned.
        """
        if self.local.token is None:
            raise LockError("Cannot extend an unlocked lock", lock_name=self.name)
        # do_acquire stores the key without an expiry for any falsy timeout,
        # so a ``timeout=0`` lock has no TTL to extend either.
        if not self.timeout:
            raise LockError(
                "Cannot extend a lock with no timeout", lock_name=self.name
            )
        return self.do_extend(additional_time, replace_ttl)

    def do_extend(self, additional_time: Number, replace_ttl: bool) -> bool:
        """Run the extend script; raise LockNotOwnedError if it fails."""
        additional_time = int(additional_time * 1000)
        if not bool(
            self.lua_extend(
                keys=[self.name],
                args=[self.local.token, additional_time, "1" if replace_ttl else "0"],
                client=self.redis,
            )
        ):
            raise LockNotOwnedError(
                "Cannot extend a lock that's no longer owned",
                lock_name=self.name,
            )
        return True

    def reacquire(self) -> bool:
        """
        Resets a TTL of an already acquired lock back to a timeout value.

        Raises ``LockError`` if the lock is not held or was created without
        a timeout (``None`` or ``0``), and ``LockNotOwnedError`` if the lock
        is no longer owned.
        """
        if self.local.token is None:
            raise LockError("Cannot reacquire an unlocked lock", lock_name=self.name)
        # With timeout=0, do_reacquire would send ``PEXPIRE key 0``, which
        # makes Redis delete the key: the lock would be silently released
        # while this method reported success. Treat 0 like None, matching
        # do_acquire.
        if not self.timeout:
            raise LockError(
                "Cannot reacquire a lock with no timeout",
                lock_name=self.name,
            )
        return self.do_reacquire()

    def do_reacquire(self) -> bool:
        """Run the reacquire script; raise LockNotOwnedError if it fails."""
        timeout = int(self.timeout * 1000)
        if not bool(
            self.lua_reacquire(
                keys=[self.name], args=[self.local.token, timeout], client=self.redis
            )
        ):
            raise LockNotOwnedError(
                "Cannot reacquire a lock that's no longer owned",
                lock_name=self.name,
            )
        return True