mirror of
https://github.com/MISP/misp-galaxy.git
synced 2024-11-26 16:57:18 +00:00
324 lines
12 KiB
Python
324 lines
12 KiB
Python
|
import threading
|
||
|
import time as mod_time
|
||
|
import uuid
|
||
|
from types import SimpleNamespace, TracebackType
|
||
|
from typing import Optional, Type
|
||
|
|
||
|
from redis.exceptions import LockError, LockNotOwnedError
|
||
|
from redis.typing import Number
|
||
|
|
||
|
|
||
|
class Lock:
|
||
|
"""
|
||
|
A shared, distributed Lock. Using Redis for locking allows the Lock
|
||
|
to be shared across processes and/or machines.
|
||
|
|
||
|
It's left to the user to resolve deadlock issues and make sure
|
||
|
multiple clients play nicely together.
|
||
|
"""
|
||
|
|
||
|
lua_release = None
|
||
|
lua_extend = None
|
||
|
lua_reacquire = None
|
||
|
|
||
|
# KEYS[1] - lock name
|
||
|
# ARGV[1] - token
|
||
|
# return 1 if the lock was released, otherwise 0
|
||
|
LUA_RELEASE_SCRIPT = """
|
||
|
local token = redis.call('get', KEYS[1])
|
||
|
if not token or token ~= ARGV[1] then
|
||
|
return 0
|
||
|
end
|
||
|
redis.call('del', KEYS[1])
|
||
|
return 1
|
||
|
"""
|
||
|
|
||
|
# KEYS[1] - lock name
|
||
|
# ARGV[1] - token
|
||
|
# ARGV[2] - additional milliseconds
|
||
|
# ARGV[3] - "0" if the additional time should be added to the lock's
|
||
|
# existing ttl or "1" if the existing ttl should be replaced
|
||
|
# return 1 if the locks time was extended, otherwise 0
|
||
|
LUA_EXTEND_SCRIPT = """
|
||
|
local token = redis.call('get', KEYS[1])
|
||
|
if not token or token ~= ARGV[1] then
|
||
|
return 0
|
||
|
end
|
||
|
local expiration = redis.call('pttl', KEYS[1])
|
||
|
if not expiration then
|
||
|
expiration = 0
|
||
|
end
|
||
|
if expiration < 0 then
|
||
|
return 0
|
||
|
end
|
||
|
|
||
|
local newttl = ARGV[2]
|
||
|
if ARGV[3] == "0" then
|
||
|
newttl = ARGV[2] + expiration
|
||
|
end
|
||
|
redis.call('pexpire', KEYS[1], newttl)
|
||
|
return 1
|
||
|
"""
|
||
|
|
||
|
# KEYS[1] - lock name
|
||
|
# ARGV[1] - token
|
||
|
# ARGV[2] - milliseconds
|
||
|
# return 1 if the locks time was reacquired, otherwise 0
|
||
|
LUA_REACQUIRE_SCRIPT = """
|
||
|
local token = redis.call('get', KEYS[1])
|
||
|
if not token or token ~= ARGV[1] then
|
||
|
return 0
|
||
|
end
|
||
|
redis.call('pexpire', KEYS[1], ARGV[2])
|
||
|
return 1
|
||
|
"""
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
redis,
|
||
|
name: str,
|
||
|
timeout: Optional[Number] = None,
|
||
|
sleep: Number = 0.1,
|
||
|
blocking: bool = True,
|
||
|
blocking_timeout: Optional[Number] = None,
|
||
|
thread_local: bool = True,
|
||
|
):
|
||
|
"""
|
||
|
Create a new Lock instance named ``name`` using the Redis client
|
||
|
supplied by ``redis``.
|
||
|
|
||
|
``timeout`` indicates a maximum life for the lock in seconds.
|
||
|
By default, it will remain locked until release() is called.
|
||
|
``timeout`` can be specified as a float or integer, both representing
|
||
|
the number of seconds to wait.
|
||
|
|
||
|
``sleep`` indicates the amount of time to sleep in seconds per loop
|
||
|
iteration when the lock is in blocking mode and another client is
|
||
|
currently holding the lock.
|
||
|
|
||
|
``blocking`` indicates whether calling ``acquire`` should block until
|
||
|
the lock has been acquired or to fail immediately, causing ``acquire``
|
||
|
to return False and the lock not being acquired. Defaults to True.
|
||
|
Note this value can be overridden by passing a ``blocking``
|
||
|
argument to ``acquire``.
|
||
|
|
||
|
``blocking_timeout`` indicates the maximum amount of time in seconds to
|
||
|
spend trying to acquire the lock. A value of ``None`` indicates
|
||
|
continue trying forever. ``blocking_timeout`` can be specified as a
|
||
|
float or integer, both representing the number of seconds to wait.
|
||
|
|
||
|
``thread_local`` indicates whether the lock token is placed in
|
||
|
thread-local storage. By default, the token is placed in thread local
|
||
|
storage so that a thread only sees its token, not a token set by
|
||
|
another thread. Consider the following timeline:
|
||
|
|
||
|
time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
|
||
|
thread-1 sets the token to "abc"
|
||
|
time: 1, thread-2 blocks trying to acquire `my-lock` using the
|
||
|
Lock instance.
|
||
|
time: 5, thread-1 has not yet completed. redis expires the lock
|
||
|
key.
|
||
|
time: 5, thread-2 acquired `my-lock` now that it's available.
|
||
|
thread-2 sets the token to "xyz"
|
||
|
time: 6, thread-1 finishes its work and calls release(). if the
|
||
|
token is *not* stored in thread local storage, then
|
||
|
thread-1 would see the token value as "xyz" and would be
|
||
|
able to successfully release the thread-2's lock.
|
||
|
|
||
|
In some use cases it's necessary to disable thread local storage. For
|
||
|
example, if you have code where one thread acquires a lock and passes
|
||
|
that lock instance to a worker thread to release later. If thread
|
||
|
local storage isn't disabled in this case, the worker thread won't see
|
||
|
the token set by the thread that acquired the lock. Our assumption
|
||
|
is that these cases aren't common and as such default to using
|
||
|
thread local storage.
|
||
|
"""
|
||
|
self.redis = redis
|
||
|
self.name = name
|
||
|
self.timeout = timeout
|
||
|
self.sleep = sleep
|
||
|
self.blocking = blocking
|
||
|
self.blocking_timeout = blocking_timeout
|
||
|
self.thread_local = bool(thread_local)
|
||
|
self.local = threading.local() if self.thread_local else SimpleNamespace()
|
||
|
self.local.token = None
|
||
|
self.register_scripts()
|
||
|
|
||
|
def register_scripts(self) -> None:
|
||
|
cls = self.__class__
|
||
|
client = self.redis
|
||
|
if cls.lua_release is None:
|
||
|
cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT)
|
||
|
if cls.lua_extend is None:
|
||
|
cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT)
|
||
|
if cls.lua_reacquire is None:
|
||
|
cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT)
|
||
|
|
||
|
def __enter__(self) -> "Lock":
|
||
|
if self.acquire():
|
||
|
return self
|
||
|
raise LockError(
|
||
|
"Unable to acquire lock within the time specified",
|
||
|
lock_name=self.name,
|
||
|
)
|
||
|
|
||
|
def __exit__(
|
||
|
self,
|
||
|
exc_type: Optional[Type[BaseException]],
|
||
|
exc_value: Optional[BaseException],
|
||
|
traceback: Optional[TracebackType],
|
||
|
) -> None:
|
||
|
self.release()
|
||
|
|
||
|
def acquire(
|
||
|
self,
|
||
|
sleep: Optional[Number] = None,
|
||
|
blocking: Optional[bool] = None,
|
||
|
blocking_timeout: Optional[Number] = None,
|
||
|
token: Optional[str] = None,
|
||
|
):
|
||
|
"""
|
||
|
Use Redis to hold a shared, distributed lock named ``name``.
|
||
|
Returns True once the lock is acquired.
|
||
|
|
||
|
If ``blocking`` is False, always return immediately. If the lock
|
||
|
was acquired, return True, otherwise return False.
|
||
|
|
||
|
``blocking_timeout`` specifies the maximum number of seconds to
|
||
|
wait trying to acquire the lock.
|
||
|
|
||
|
``token`` specifies the token value to be used. If provided, token
|
||
|
must be a bytes object or a string that can be encoded to a bytes
|
||
|
object with the default encoding. If a token isn't specified, a UUID
|
||
|
will be generated.
|
||
|
"""
|
||
|
if sleep is None:
|
||
|
sleep = self.sleep
|
||
|
if token is None:
|
||
|
token = uuid.uuid1().hex.encode()
|
||
|
else:
|
||
|
encoder = self.redis.get_encoder()
|
||
|
token = encoder.encode(token)
|
||
|
if blocking is None:
|
||
|
blocking = self.blocking
|
||
|
if blocking_timeout is None:
|
||
|
blocking_timeout = self.blocking_timeout
|
||
|
stop_trying_at = None
|
||
|
if blocking_timeout is not None:
|
||
|
stop_trying_at = mod_time.monotonic() + blocking_timeout
|
||
|
while True:
|
||
|
if self.do_acquire(token):
|
||
|
self.local.token = token
|
||
|
return True
|
||
|
if not blocking:
|
||
|
return False
|
||
|
next_try_at = mod_time.monotonic() + sleep
|
||
|
if stop_trying_at is not None and next_try_at > stop_trying_at:
|
||
|
return False
|
||
|
mod_time.sleep(sleep)
|
||
|
|
||
|
def do_acquire(self, token: str) -> bool:
|
||
|
if self.timeout:
|
||
|
# convert to milliseconds
|
||
|
timeout = int(self.timeout * 1000)
|
||
|
else:
|
||
|
timeout = None
|
||
|
if self.redis.set(self.name, token, nx=True, px=timeout):
|
||
|
return True
|
||
|
return False
|
||
|
|
||
|
def locked(self) -> bool:
|
||
|
"""
|
||
|
Returns True if this key is locked by any process, otherwise False.
|
||
|
"""
|
||
|
return self.redis.get(self.name) is not None
|
||
|
|
||
|
def owned(self) -> bool:
|
||
|
"""
|
||
|
Returns True if this key is locked by this lock, otherwise False.
|
||
|
"""
|
||
|
stored_token = self.redis.get(self.name)
|
||
|
# need to always compare bytes to bytes
|
||
|
# TODO: this can be simplified when the context manager is finished
|
||
|
if stored_token and not isinstance(stored_token, bytes):
|
||
|
encoder = self.redis.get_encoder()
|
||
|
stored_token = encoder.encode(stored_token)
|
||
|
return self.local.token is not None and stored_token == self.local.token
|
||
|
|
||
|
def release(self) -> None:
|
||
|
"""
|
||
|
Releases the already acquired lock
|
||
|
"""
|
||
|
expected_token = self.local.token
|
||
|
if expected_token is None:
|
||
|
raise LockError("Cannot release an unlocked lock", lock_name=self.name)
|
||
|
self.local.token = None
|
||
|
self.do_release(expected_token)
|
||
|
|
||
|
def do_release(self, expected_token: str) -> None:
|
||
|
if not bool(
|
||
|
self.lua_release(keys=[self.name], args=[expected_token], client=self.redis)
|
||
|
):
|
||
|
raise LockNotOwnedError(
|
||
|
"Cannot release a lock that's no longer owned",
|
||
|
lock_name=self.name,
|
||
|
)
|
||
|
|
||
|
def extend(self, additional_time: int, replace_ttl: bool = False) -> bool:
|
||
|
"""
|
||
|
Adds more time to an already acquired lock.
|
||
|
|
||
|
``additional_time`` can be specified as an integer or a float, both
|
||
|
representing the number of seconds to add.
|
||
|
|
||
|
``replace_ttl`` if False (the default), add `additional_time` to
|
||
|
the lock's existing ttl. If True, replace the lock's ttl with
|
||
|
`additional_time`.
|
||
|
"""
|
||
|
if self.local.token is None:
|
||
|
raise LockError("Cannot extend an unlocked lock", lock_name=self.name)
|
||
|
if self.timeout is None:
|
||
|
raise LockError("Cannot extend a lock with no timeout", lock_name=self.name)
|
||
|
return self.do_extend(additional_time, replace_ttl)
|
||
|
|
||
|
def do_extend(self, additional_time: int, replace_ttl: bool) -> bool:
|
||
|
additional_time = int(additional_time * 1000)
|
||
|
if not bool(
|
||
|
self.lua_extend(
|
||
|
keys=[self.name],
|
||
|
args=[self.local.token, additional_time, "1" if replace_ttl else "0"],
|
||
|
client=self.redis,
|
||
|
)
|
||
|
):
|
||
|
raise LockNotOwnedError(
|
||
|
"Cannot extend a lock that's no longer owned",
|
||
|
lock_name=self.name,
|
||
|
)
|
||
|
return True
|
||
|
|
||
|
def reacquire(self) -> bool:
|
||
|
"""
|
||
|
Resets a TTL of an already acquired lock back to a timeout value.
|
||
|
"""
|
||
|
if self.local.token is None:
|
||
|
raise LockError("Cannot reacquire an unlocked lock", lock_name=self.name)
|
||
|
if self.timeout is None:
|
||
|
raise LockError(
|
||
|
"Cannot reacquire a lock with no timeout",
|
||
|
lock_name=self.name,
|
||
|
)
|
||
|
return self.do_reacquire()
|
||
|
|
||
|
def do_reacquire(self) -> bool:
|
||
|
timeout = int(self.timeout * 1000)
|
||
|
if not bool(
|
||
|
self.lua_reacquire(
|
||
|
keys=[self.name], args=[self.local.token, timeout], client=self.redis
|
||
|
)
|
||
|
):
|
||
|
raise LockNotOwnedError(
|
||
|
"Cannot reacquire a lock that's no longer owned",
|
||
|
lock_name=self.name,
|
||
|
)
|
||
|
return True
|