new: [backend] Added debounce function to throttle exercise checks

This commit is contained in:
Sami Mokaddem 2024-07-02 11:41:17 +02:00
parent 29ef580dad
commit bbfba0d6e4
3 changed files with 70 additions and 7 deletions

View file

@ -1,10 +1,14 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import functools
import time
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
import json import json
import re import re
from typing import Union from typing import Union
import jq
import db import db
from inject_evaluator import eval_data_filtering, eval_query_comparison from inject_evaluator import eval_data_filtering, eval_query_comparison
import misp_api import misp_api
@ -12,6 +16,26 @@ import config
ACTIVE_EXERCISES_DIR = "active_exercises" ACTIVE_EXERCISES_DIR = "active_exercises"
def debounce_check_active_tasks(debounce_seconds: int = 1):
func_last_execution_time = {}
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
user_id = args[0]
now = time.time()
key = f"{user_id}_{func.__name__}"
if key not in func_last_execution_time:
func_last_execution_time[key] = now
return func(*args, **kwargs)
elif now >= func_last_execution_time[key] + debounce_seconds:
func_last_execution_time[key] = now
return func(*args, **kwargs)
else:
print(f">> Debounced for `{user_id}`")
return None
return wrapper
return decorator
def load_exercises() -> bool: def load_exercises() -> bool:
db.ALL_EXERCISES = read_exercise_dir() db.ALL_EXERCISES = read_exercise_dir()
@ -54,6 +78,17 @@ def is_validate_exercises(exercises: list) -> bool:
return False return False
tasks_uuid.add(t_uuid) tasks_uuid.add(t_uuid)
task_by_uuid[t_uuid] = inject task_by_uuid[t_uuid] = inject
for inject_evaluation in inject.get('inject_evaluation', []):
if inject_evaluation.get('evaluation_strategy', None) == 'data_filtering':
for evaluation in inject_evaluation.get('parameters', []):
jq_path = list(evaluation.keys())[0]
try:
jq.compile(jq_path)
except ValueError as e:
print(f"[{t_uuid} :: {inject['name']}] Could not compile jq path `{jq_path}`\n", e)
return False
return True return True
@ -382,15 +417,15 @@ def fetch_data_for_query_comparison(user_id: int, inject_evaluation: dict, perfo
return data return data
@debounce_check_active_tasks(debounce_seconds=5)
def check_active_tasks(user_id: int, data: dict, context: dict) -> bool: def check_active_tasks(user_id: int, data: dict, context: dict) -> bool:
succeeded_once = False succeeded_once = False
available_tasks = get_available_tasks_for_user(user_id) available_tasks = get_available_tasks_for_user(user_id)
for task_uuid in available_tasks: for task_uuid in available_tasks:
inject = db.INJECT_BY_UUID[task_uuid] inject = db.INJECT_BY_UUID[task_uuid]
if inject['exercise_uuid'] not in db.SELECTED_EXERCISES: if inject['exercise_uuid'] not in db.SELECTED_EXERCISES:
print(f"exercise not active for this inject {inject['name']}")
continue continue
print(f"checking: {inject['name']}") print(f"[{task_uuid}] :: checking: {inject['name']}")
completed = check_inject(user_id, inject, data, context) completed = check_inject(user_id, inject, data, context)
if completed: if completed:
succeeded_once = True succeeded_once = True

View file

@ -22,7 +22,7 @@ def get(url, data={}, api_key=misp_apikey):
try: try:
response = requests.get(full_url, data=data, headers=headers, verify=not misp_skipssl) response = requests.get(full_url, data=data, headers=headers, verify=not misp_skipssl)
except requests.exceptions.ConnectionError as e: except requests.exceptions.ConnectionError as e:
print(e) print('Could not perform request on MISP.', e)
return None return None
return response.json() if response.headers['content-type'].startswith('application/json') else response.text return response.json() if response.headers['content-type'].startswith('application/json') else response.text
@ -38,7 +38,7 @@ def post(url, data={}, api_key=misp_apikey):
try: try:
response = requests.post(full_url, data=json.dumps(data), headers=headers, verify=not misp_skipssl) response = requests.post(full_url, data=json.dumps(data), headers=headers, verify=not misp_skipssl)
except requests.exceptions.ConnectionError as e: except requests.exceptions.ConnectionError as e:
print(e) print('Could not perform request on MISP.', e)
return None return None
return response.json() if response.headers['content-type'].startswith('application/json') else response.text return response.json() if response.headers['content-type'].startswith('application/json') else response.text

View file

@ -1,7 +1,9 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import functools
import json import json
import sys import sys
import time
import zmq import zmq
import socketio import socketio
import eventlet import eventlet
@ -18,6 +20,26 @@ import misp_api
ZMQ_MESSAGE_COUNT = 0 ZMQ_MESSAGE_COUNT = 0
def debounce(debounce_seconds: int = 1):
func_last_execution_time = {}
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
now = time.time()
key = func.__name__
if key not in func_last_execution_time:
func_last_execution_time[key] = now
return func(*args, **kwargs)
elif now >= func_last_execution_time[key] + debounce_seconds:
func_last_execution_time[key] = now
return func(*args, **kwargs)
else:
return None
return wrapper
return decorator
# Initialize ZeroMQ context and subscriber socket # Initialize ZeroMQ context and subscriber socket
context = gzmq.Context() context = gzmq.Context()
zsocket = context.socket(gzmq.SUB) zsocket = context.socket(gzmq.SUB)
@ -117,7 +139,12 @@ def handleMessage(topic, s, message):
context = get_context(data) context = get_context(data)
succeeded_once = exercise_model.check_active_tasks(user_id, data, context) succeeded_once = exercise_model.check_active_tasks(user_id, data, context)
if succeeded_once: if succeeded_once:
sio.emit('refresh_score') sendRefreshScore()
@debounce(debounce_seconds=1)
def sendRefreshScore():
sio.emit('refresh_score')
def get_context(data: dict) -> dict: def get_context(data: dict) -> dict:
@ -153,11 +180,12 @@ def forward_zmq_to_socketio():
while True: while True:
message = zsocket.recv_string() message = zsocket.recv_string()
topic, s, m = message.partition(" ") topic, s, m = message.partition(" ")
handleMessage(topic, s, m)
try: try:
ZMQ_MESSAGE_COUNT += 1 ZMQ_MESSAGE_COUNT += 1
handleMessage(topic, s, m) # handleMessage(topic, s, m)
except Exception as e: except Exception as e:
print(e) print('Error handling message', e)
if __name__ == "__main__": if __name__ == "__main__":