diff --git a/exercise.py b/exercise.py index 25f4696..0971f97 100644 --- a/exercise.py +++ b/exercise.py @@ -10,7 +10,7 @@ from typing import Union import jq import db -from inject_evaluator import eval_data_filtering, eval_query_mirror +from inject_evaluator import eval_data_filtering, eval_query_mirror, eval_query_search import misp_api import config from config import logger @@ -351,11 +351,13 @@ def inject_checker_router(user_id: int, inject_evaluation: dict, data: dict, con return False if inject_evaluation['evaluation_strategy'] == 'data_filtering': - return eval_data_filtering(user_id, inject_evaluation, data_to_validate) + return eval_data_filtering(user_id, inject_evaluation, data_to_validate, context) elif inject_evaluation['evaluation_strategy'] == 'query_mirror': expected_data = data_to_validate['expected_data'] data_to_validate = data_to_validate['data_to_validate'] - return eval_query_mirror(user_id, expected_data, data_to_validate) + return eval_query_mirror(user_id, expected_data, data_to_validate, context) + elif inject_evaluation['evaluation_strategy'] == 'query_search': + return eval_query_search(user_id, inject_evaluation, data_to_validate, context) return False @@ -367,6 +369,8 @@ def get_data_to_validate(user_id: int, inject_evaluation: dict, data: dict) -> U elif inject_evaluation['evaluation_strategy'] == 'query_mirror': perfomed_query = parse_performed_query_from_log(data) data_to_validate = fetch_data_for_query_mirror(user_id, inject_evaluation, perfomed_query) + elif inject_evaluation['evaluation_strategy'] == 'query_search': + data_to_validate = fetch_data_for_query_search(user_id, inject_evaluation) return data_to_validate @@ -438,6 +442,18 @@ def fetch_data_for_query_mirror(user_id: int, inject_evaluation: dict, perfomed_ return data +def fetch_data_for_query_search(user_id: int, inject_evaluation: dict) -> Union[None, dict]: + authkey = db.USER_ID_TO_AUTHKEY_MAPPING[user_id] + if 'evaluation_context' not in inject_evaluation and 'query_context' not in inject_evaluation['evaluation_context']: + return None + query_context = inject_evaluation['evaluation_context']['query_context'] + search_method = query_context['request_method'] + search_url = query_context['url'] + search_payload = inject_evaluation['payload'] + search_data = misp_api.doRestQuery(authkey, search_method, search_url, search_payload) + return search_data + + @debounce_check_active_tasks(debounce_seconds=2) def check_active_tasks(user_id: int, data: dict, context: dict) -> bool: succeeded_once = False diff --git a/exercises/basic-event-creation.json b/exercises/basic-event-creation.json index 15ef1b6..6def47b 100644 --- a/exercises/basic-event-creation.json +++ b/exercises/basic-event-creation.json @@ -137,6 +137,14 @@ "inject_evaluation": [ { "parameters": [ + { + ".Event.user_id": { + "comparison": "equals", + "values": [ + "{{user_id}}" + ] + } + }, { ".Event.info": { "comparison": "contains", @@ -148,9 +156,17 @@ } ], "result": "MISP Event created", - "evaluation_strategy": "data_filtering", + "evaluation_strategy": "query_search", "evaluation_context": { - "request_is_rest": true + "request_is_rest": true, + "query_context": { + "url": "/events/restSearch", + "request_method": "POST", + "payload": { + "timestamp": "10d", + "eventinfo": "%API%" + } + } }, "score_range": [ 0, diff --git a/inject_evaluator.py b/inject_evaluator.py index 53cc1a1..fb66403 100644 --- a/inject_evaluator.py +++ b/inject_evaluator.py @@ -6,7 +6,6 @@ import operator from config import logger -# .Event.Attribute[] | select(.value == "evil.exe") | .Tag def jq_extract(path: str, data: dict, extract_type='first'): query = jq.compile(path).input_value(data) try: @@ -15,28 +14,40 @@ def jq_extract(path: str, data: dict, extract_type='first'): return None +# Replace the substring `{{variable}}` by context[variable] in the provided string +def apply_replacement_from_context(string: str, context: dict) -> str: + replacement_regex = r"{{(\w+)}}" + matches = re.fullmatch(replacement_regex, string, re.MULTILINE) + if not matches: + return string + subst_str = matches.groups()[0] + subst = str(context.get(subst_str, '')) + return re.sub(replacement_regex, subst, string) + + ## ## Data Filtering ## -def condition_satisfied(evaluation_config: dict, data_to_validate: Union[dict, list, str]) -> bool: +def condition_satisfied(evaluation_config: dict, data_to_validate: Union[dict, list, str], context: dict) -> bool: if type(data_to_validate) is bool: data_to_validate = "1" if data_to_validate else "0" if type(data_to_validate) is str: - return eval_condition_str(evaluation_config, data_to_validate) + return eval_condition_str(evaluation_config, data_to_validate, context) elif type(data_to_validate) is list: - return eval_condition_list(evaluation_config, data_to_validate) + return eval_condition_list(evaluation_config, data_to_validate, context) elif type(data_to_validate) is dict: # Not sure how we could have condition on this - return eval_condition_dict(evaluation_config, data_to_validate) + return eval_condition_dict(evaluation_config, data_to_validate, context) return False -def eval_condition_str(evaluation_config: dict, data_to_validate: str) -> bool: +def eval_condition_str(evaluation_config: dict, data_to_validate: str, context: dict) -> bool: comparison_type = evaluation_config['comparison'] values = evaluation_config['values'] if len(values) == 0: return False + values = [apply_replacement_from_context(v, context) for v in values] if comparison_type == 'contains': values = [v.lower() for v in values] @@ -56,7 +67,7 @@ def eval_condition_str(evaluation_config: dict, data_to_validate: str) -> bool: return False -def eval_condition_list(evaluation_config: dict, data_to_validate: str) -> bool: +def eval_condition_list(evaluation_config: dict, data_to_validate: str, context: dict) -> bool: comparison_type = evaluation_config['comparison'] values = evaluation_config['values'] comparators = { @@ -69,7 +80,7 @@ def eval_condition_list(evaluation_config: dict, data_to_validate: str) -> bool: if len(values) == 0: return False - + values = [apply_replacement_from_context(v, context) for v in values] if comparison_type == 'contains' or comparison_type == 'equals': data_to_validate_set = set(data_to_validate) @@ -102,7 +113,7 @@ def eval_condition_list(evaluation_config: dict, data_to_validate: str) -> bool: return False -def eval_condition_dict(evaluation_config: dict, data_to_validate: str) -> bool: +def eval_condition_dict(evaluation_config: dict, data_to_validate: str, context: dict) -> bool: comparison_type = evaluation_config['comparison'] values = evaluation_config['values'] comparators = { @@ -113,6 +124,10 @@ def eval_condition_dict(evaluation_config: dict, data_to_validate: str) -> bool: '=': operator.eq, } + if len(values) == 0: + return False + values = [apply_replacement_from_context(v, context) for v in values] + comparison_type = evaluation_config['comparison'] if comparison_type == 'contains': pass @@ -129,21 +144,31 @@ def eval_condition_dict(evaluation_config: dict, data_to_validate: str) -> bool: return False -def eval_data_filtering(user_id: int, inject_evaluation: dict, data: dict) -> bool: +def eval_data_filtering(user_id: int, inject_evaluation: dict, data: dict, context: dict) -> bool: for evaluation_params in inject_evaluation['parameters']: for evaluation_path, evaluation_config in evaluation_params.items(): + evaluation_path = apply_replacement_from_context(evaluation_path, context) data_to_validate = jq_extract(evaluation_path, data, evaluation_config.get('extract_type', 'first')) if data_to_validate is None: logger.debug('Could not extract data') return False - if not condition_satisfied(evaluation_config, data_to_validate): + if not condition_satisfied(evaluation_config, data_to_validate, context): return False return True ## -## Query comparison +## Query mirror ## -def eval_query_mirror(user_id: int, expected_data, data_to_validate) -> bool: +def eval_query_mirror(user_id: int, expected_data, data_to_validate, context: dict) -> bool: return expected_data == data_to_validate + + + +## +## Query search +## + +def eval_query_search(user_id: int, inject_evaluation: dict, data: dict, context: dict) -> bool: + return eval_data_filtering(user_id, inject_evaluation, data, context) \ No newline at end of file diff --git a/server.py b/server.py index 5c5278d..e6e3210 100755 --- a/server.py +++ b/server.py @@ -51,7 +51,7 @@ zsocket.setsockopt_string(zmq.SUBSCRIBE, '') # Initialize Socket.IO server # sio = socketio.Server(cors_allowed_origins='*', async_mode='eventlet') -sio = socketio.AsyncServer(cors_allowed_origins='*', async_mode='aiohttp') +sio = socketio.AsyncServer(cors_allowed_origins='*', async_mode='aiohttp', logger=True, engineio_logger=True) app = web.Application() sio.attach(app) @@ -146,7 +146,7 @@ async def handleMessage(topic, s, message): user_id = notification_model.get_user_id(data) if user_id is not None: if exercise_model.is_accepted_query(data): - context = get_context(data) + context = get_context(user_id, data) succeeded_once = exercise_model.check_active_tasks(user_id, data, context) if succeeded_once: await sendRefreshScore() @@ -157,8 +157,10 @@ async def sendRefreshScore(): await sio.emit('refresh_score') -def get_context(data: dict) -> dict: - context = {} +def get_context(user_id: int, data: dict) -> dict: + context = { + 'user_id': user_id, + } if 'Log' in data: if 'request_is_rest' in data['Log']: context['request_is_rest'] = data['Log']['request_is_rest'] @@ -200,11 +202,13 @@ async def forward_zmq_to_socketio(): while True: message = await zsocket.recv_string() topic, s, m = message.partition(" ") + await handleMessage(topic, s, m) try: ZMQ_MESSAGE_COUNT += 1 ZMQ_LAST_TIME = time.time() - await handleMessage(topic, s, m) + # await handleMessage(topic, s, m) except Exception as e: + print(e) logger.error('Error handling message %s', e)